From 0495fc54359b7fc38fc4f950d791022efd6a91a5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 09:13:30 +0000 Subject: [PATCH 001/753] Bump ubuntu in /tensorflow/tools/tf_sig_build_dockerfiles Bumps ubuntu from `0950623` to `104ae83`. --- updated-dependencies: - dependency-name: ubuntu dependency-version: '22.04' dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile index 8d35977d14a987..b9d06f956f6d2a 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile +++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile @@ -1,5 +1,5 @@ ################################################################################ -FROM ubuntu:22.04@sha256:09506232a8004baa32c47d68f1e5c307d648fdd59f5e7eaa42aaf87914100db3 as builder +FROM ubuntu:22.04@sha256:104ae83764a5119017b8e8d6218fa0832b09df65aae7d5a6de29a85d813da2fb as builder ################################################################################ # Install devtoolset build dependencies From 800ff3bf850590f12a3940b33a1a196bb98b5d6f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 09:57:00 +0000 Subject: [PATCH 002/753] Bump the github-actions group with 5 updates Bumps the github-actions group with 5 updates: | Package | From | To | | --- | --- | --- | | [actions/checkout](https://github.com/actions/checkout) | `5.0.0` | `6.0.0` | | [google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml](https://github.com/google/osv-scanner-action) | `2.2.4` | `2.3.0` | | [actions/setup-python](https://github.com/actions/setup-python) | `6.0.0` | `6.1.0` | | 
[peter-evans/create-pull-request](https://github.com/peter-evans/create-pull-request) | `7.0.8` | `7.0.9` | | [github/codeql-action](https://github.com/github/codeql-action) | `4.31.2` | `4.31.6` | Updates `actions/checkout` from 5.0.0 to 6.0.0 - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/08c6903cd8c0fde910a37f88322edcfb5dd907a8...1af3b93b6815bc44a9784bd300feb67ff0d1eeb3) Updates `google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml` from 2.2.4 to 2.3.0 - [Release notes](https://github.com/google/osv-scanner-action/releases) - [Commits](https://github.com/google/osv-scanner-action/compare/v2.2.4...v2.3.0) Updates `actions/setup-python` from 6.0.0 to 6.1.0 - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/e797f83bcb11b83ae66e0230d6156d7c80228e7c...83679a892e2d95755f2dac6acb0bfd1e9ac5d548) Updates `peter-evans/create-pull-request` from 7.0.8 to 7.0.9 - [Release notes](https://github.com/peter-evans/create-pull-request/releases) - [Commits](https://github.com/peter-evans/create-pull-request/compare/271a8d0340265f705b14b6d32b9829c1cb33d45e...84ae59a2cdc2258d6fa0732dd66352dddae2a412) Updates `github/codeql-action` from 4.31.2 to 4.31.6 - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/0499de31b99561a6d14a36a5f662c2a54f91beee...fe4161a26a8629af62121b670040955b330f9af2) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: 6.0.0 dependency-type: direct:production update-type: version-update:semver-major dependency-group: github-actions - dependency-name: google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml dependency-version: 2.3.0 
dependency-type: direct:production update-type: version-update:semver-minor dependency-group: github-actions - dependency-name: actions/setup-python dependency-version: 6.1.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: github-actions - dependency-name: peter-evans/create-pull-request dependency-version: 7.0.9 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions - dependency-name: github/codeql-action dependency-version: 4.31.6 dependency-type: direct:production update-type: version-update:semver-patch dependency-group: github-actions ... Signed-off-by: dependabot[bot] --- .github/workflows/arm-cd.yml | 4 ++-- .github/workflows/arm-ci-extended-cpp.yml | 4 ++-- .github/workflows/arm-ci-extended.yml | 4 ++-- .github/workflows/arm-ci.yml | 2 +- .github/workflows/cffconvert.yml | 2 +- .github/workflows/issue-on-pr-rollback.yml | 2 +- .github/workflows/osv-scanner-scheduled.yml | 2 +- .github/workflows/pylint-presubmit.yml | 4 ++-- .github/workflows/release-branch-cherrypick.yml | 4 ++-- .github/workflows/scorecards-analysis.yml | 4 ++-- .github/workflows/update-rbe.yml | 4 ++-- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index 2e3912041d9cf2..5430fc1c8151e8 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -52,12 +52,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . 
-o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: ref: 'nightly' - name: Checkout repository for releases (skipped for nightly) if: ${{ github.event_name == 'push' }} - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Build and test pip wheel shell: bash run: | diff --git a/.github/workflows/arm-ci-extended-cpp.yml b/.github/workflows/arm-ci-extended-cpp.yml index 54903a6998b090..09085e814daba1 100644 --- a/.github/workflows/arm-ci-extended-cpp.yml +++ b/.github/workflows/arm-ci-extended-cpp.yml @@ -50,12 +50,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Build binary and run C++ tests shell: bash run: | diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 2235cfc2d986da..94237fcaa6cca5 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -51,12 +51,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . 
-o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Build binary and run python tests on nightly for all python versions shell: bash run: | diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index a141bdd4676852..12d8ab4a2cf719 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -47,7 +47,7 @@ jobs: shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Build binary and run python tests shell: bash run: | diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 6421e08ccf0839..de578ffec96327 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out a copy of the repository - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@4cf11baa70a673bfdf9dad0acc7ee33b3f4b6084 # v2.0.0 diff --git a/.github/workflows/issue-on-pr-rollback.yml b/.github/workflows/issue-on-pr-rollback.yml index d5e0661a5f356b..1d548e9204e563 100644 --- 
a/.github/workflows/issue-on-pr-rollback.yml +++ b/.github/workflows/issue-on-pr-rollback.yml @@ -33,7 +33,7 @@ jobs: startsWith(github.event.head_commit.message, 'Rollback of PR #') steps: - name: Checkout repo - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Create a new Github Issue uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index 07896a48470753..984dead9db7388 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.2.4" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.3.0" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index 59068d9d86f45d..483cf5bfc0addf 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -28,7 +28,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Get file changes id: get_file_changes uses: trilom/file-changes-action@a6ca26c14274c33b15e6499323aac178af06ad4b # v1.2.4 @@ -38,7 +38,7 @@ jobs: run: | echo Changed files: ${{ steps.get_file_changes.outputs.files }} - name: Set up Python 3.9 - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 with: python-version: "3.9" - name: Install Python 
dependencies diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml index 69e03a040ae1a2..fc643c92d304d1 100644 --- a/.github/workflows/release-branch-cherrypick.yml +++ b/.github/workflows/release-branch-cherrypick.yml @@ -45,7 +45,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks steps: - name: Checkout code - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: ref: ${{ github.event.inputs.release_branch }} - name: Get some helpful info for formatting @@ -58,7 +58,7 @@ jobs: echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")" >> "$GITHUB_OUTPUT" echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT" - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 + uses: peter-evans/create-pull-request@84ae59a2cdc2258d6fa0732dd66352dddae2a412 # v7.0.9 with: title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ steps.cherrypick.outputs.TITLE }}"' committer: TensorFlow Release Automation diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index e635c4cd8ccc88..ce2d7075019b5d 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -41,7 +41,7 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 with: persist-credentials: false @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). 
# Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@0499de31b99561a6d14a36a5f662c2a54f91beee # v3.29.5 + uses: github/codeql-action/upload-sarif@fe4161a26a8629af62121b670040955b330f9af2 # v4.31.6 with: sarif_file: results.sarif diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index a8dba883f5ff14..d2cc83b7f5c2c2 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -30,7 +30,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks steps: - name: Checkout code - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Update the RBE Configs run: | function map() { @@ -130,7 +130,7 @@ jobs: map sigbuild-r2.17-clang-python3.11 2.17-python3.11 map sigbuild-r2.17-clang-python3.12 2.17-python3.12 - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 + uses: peter-evans/create-pull-request@84ae59a2cdc2258d6fa0732dd66352dddae2a412 # v7.0.9 with: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation From 37562c3d83b7366276341bfdbd83b8a7ed5d97ce Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 10:12:39 +0000 Subject: [PATCH 003/753] Bump ubuntu from `66460d5` to `c35e29c` in /tensorflow/tools/gcs_test Bumps ubuntu from `66460d5` to `c35e29c`. --- updated-dependencies: - dependency-name: ubuntu dependency-version: '24.04' dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] --- tensorflow/tools/gcs_test/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile index b5fbef19051f8a..19958cb6478765 100644 --- a/tensorflow/tools/gcs_test/Dockerfile +++ b/tensorflow/tools/gcs_test/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:24.04@sha256:66460d557b25769b102175144d538d88219c077c678a49af4afca6fbfc1b5252 +FROM ubuntu:24.04@sha256:c35e29c9450151419d9448b0fd75374fec4fff364a27f176fb458d472dfc9e54 LABEL maintainer="Shanqing Cai " From 515af6fc2e13085c5d71dd75426f881ae7418a20 Mon Sep 17 00:00:00 2001 From: 1ndig0 <1090891928@qq.com> Date: Tue, 2 Dec 2025 15:52:07 +0800 Subject: [PATCH 004/753] Change begin and size types to include int16 Updated the type annotations for begin and size parameters to include int16. --- tensorflow/python/ops/array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9f6644b4342ada..94dadf91a0e18d 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -982,8 +982,8 @@ def slice(input_, begin, size, name=None): Args: input_: A `Tensor`. - begin: An `int32` or `int64` `Tensor`. - size: An `int32` or `int64` `Tensor`. + begin: An `int16`, `int32` or `int64` `Tensor`. + size: An `int16`, `int32` or `int64` `Tensor`. name: A name for the operation (optional). 
Returns: From 2fb9073f2205e48bf54263e60d846a3e4ab8d39d Mon Sep 17 00:00:00 2001 From: "guozhong.zhuang" Date: Tue, 2 Dec 2025 13:04:41 -0800 Subject: [PATCH 005/753] [oneDNN] Improve oneDNN primitive caching performance --- tensorflow/core/util/BUILD | 1 + tensorflow/core/util/mkl_util.h | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 3acd07c02fadf8..72cd0b7751e2cc 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -298,6 +298,7 @@ filegroup( "mkl_heuristics.h", "mkl_util.h", "onednn_env_vars.h", + "@com_google_absl//absl/container:flat_hash_map", "@local_xla//xla/tsl/util:onednn_util_hdrs", ], visibility = ["//tensorflow/core:__pkg__"], diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index d15ec3034a93c9..a3a5381583a196 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "oneapi/dnnl/dnnl.hpp" #include "oneapi/dnnl/dnnl_threadpool.hpp" #include "tensorflow/core/framework/op_kernel.h" @@ -1963,7 +1964,7 @@ class LRUCache { size_t capacity_; // The cache, a map from string key to a LRU entry. - std::unordered_map cache_; + absl::flat_hash_map cache_; // The LRU list of entries. // The front of the list contains the key of the most recently accessed From ce6f59526c772e345faf09d3d2f60a3078e2e331 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 6 Dec 2025 04:52:02 +0000 Subject: [PATCH 006/753] Bump urllib3 in /ci/official/requirements_updater/numpy1_requirements Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.5.0 to 2.6.0. 
- [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.5.0...2.6.0) --- updated-dependencies: - dependency-name: urllib3 dependency-version: 2.6.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] --- .../numpy1_requirements/requirements_lock_3_10.txt | 6 +++--- .../numpy1_requirements/requirements_lock_3_11.txt | 6 +++--- .../numpy1_requirements/requirements_lock_3_12.txt | 6 +++--- .../numpy1_requirements/requirements_lock_3_9.txt | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt index 1bef2b2f7903df..898ea6c0418532 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt @@ -729,9 +729,9 @@ typing-extensions==4.14.1 \ # -r ci/official/requirements_updater/requirements.in # optree # rich -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt index 7bc734c2624710..eae965757fad3b 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt +++ 
b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt @@ -728,9 +728,9 @@ typing-extensions==4.14.1 \ # via # -r ci/official/requirements_updater/requirements.in # optree -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt index 8d9d9dc47fc5d7..ca6904da19ebbe 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt @@ -728,9 +728,9 @@ typing-extensions==4.14.1 \ # via # -r ci/official/requirements_updater/requirements.in # optree -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt index 41eb61f5557d7f..e34567660cc5f7 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt +++ 
b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt @@ -725,9 +725,9 @@ typing-extensions==4.14.1 \ # -r ci/official/requirements_updater/requirements.in # optree # rich -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ From a027f53ab0dfa6e0e3ee86aa71165e809c990ff3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 00:26:49 -0800 Subject: [PATCH 007/753] Automated Code Change PiperOrigin-RevId: 841630902 --- third_party/xla/xla/tools/ptx_opt/BUILD | 1 + third_party/xla/xla/tools/ptx_opt/ptx_opt.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/xla/xla/tools/ptx_opt/BUILD b/third_party/xla/xla/tools/ptx_opt/BUILD index 262a448e9e2f9e..dcdbb9cc1162b4 100644 --- a/third_party/xla/xla/tools/ptx_opt/BUILD +++ b/third_party/xla/xla/tools/ptx_opt/BUILD @@ -22,6 +22,7 @@ xla_cc_binary( ], deps = [ "//xla:debug_options_flags", + "//xla:xla_proto_cc", "//xla/service/gpu/llvm_gpu_backend:load_ir_module", "//xla/service/gpu/llvm_gpu_backend:nvptx_backend", "//xla/stream_executor:device_description", diff --git a/third_party/xla/xla/tools/ptx_opt/ptx_opt.cc b/third_party/xla/xla/tools/ptx_opt/ptx_opt.cc index df00cb8039c253..64114733254f84 100644 --- a/third_party/xla/xla/tools/ptx_opt/ptx_opt.cc +++ b/third_party/xla/xla/tools/ptx_opt/ptx_opt.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/tsl/util/command_line_flags.h" +#include "xla/xla.pb.h" #include "tsl/platform/init_main.h" namespace xla::gpu::nvptx { From bfa9f8bf5e3557d5544b606f9d730162db448b8c Mon Sep 17 00:00:00 2001 From: Shaogang Wang Date: Mon, 8 Dec 2025 00:42:02 -0800 Subject: [PATCH 008/753] PR #34802: [XLA:GPU] Add buffer type information for GpuExecutable memory allocation profile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34802 📝 Summary of Changes Add buffer type information for GpuExecutable memory allocation profile 🎯 Justification Gives us insight into which allocations change address frequently, which is useful for command buffer optimization. 🚀 Kind of Contribution 📚 Documentation Copybara import of the project: -- 8930cae81077e508dc69c3bf367f03d0439c6205 by Shawn Wang : Update stable address profile to include allocation type Merging this change closes #34802 PiperOrigin-RevId: 841634857 --- third_party/xla/xla/service/gpu/gpu_executable.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index 96d1b4ca13c2d9..d72fdeddab4fe3 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -1062,7 +1062,14 @@ absl::Status GpuExecutable::ExecuteThunks( } module_allocations_[executor][i] = buffer_allocations.GetDeviceAddress(i); - VLOG(5) << "Gpu address changed for module " << module_name_; + const BufferAllocation& allocation = + buffer_assignment_->GetAllocation(i); + const char* allocation_type = + allocation.is_entry_computation_parameter() ? "parameter" + : allocation.maybe_live_out() ? 
"live-out" + : "temp"; + VLOG(5) << "Gpu address changed for module " << module_name_ + << ", allocation " << i << " (" << allocation_type << ")"; } } } From ea054af52030b0d5af9e5b5dfd66c5d853affa7f Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 8 Dec 2025 00:53:54 -0800 Subject: [PATCH 009/753] Remove HloProfilePrinterData and HloProfileIndexMap from Executable. This change removes the HloProfilePrinterData and HloProfileIndexMap members from the Executable base class and updates all derived classes and call sites to reflect this change. The profiling data is no longer stored within the Executable. PiperOrigin-RevId: 841638347 --- .../backends/interpreter/executable_base.cc | 3 +- third_party/xla/xla/service/BUILD | 1 - .../service/cpu/cpu_aot_compilation_result.cc | 4 +-- .../service/cpu/cpu_aot_compilation_result.h | 18 ---------- .../xla/xla/service/cpu/cpu_compiler.cc | 23 +++--------- .../xla/xla/service/cpu/cpu_executable.cc | 21 ++++------- .../xla/xla/service/cpu/cpu_executable.h | 4 --- third_party/xla/xla/service/executable.h | 35 ------------------- 8 files changed, 14 insertions(+), 95 deletions(-) diff --git a/third_party/xla/xla/backends/interpreter/executable_base.cc b/third_party/xla/xla/backends/interpreter/executable_base.cc index d8a9ac91c7d39f..7ba92f41d87701 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.cc +++ b/third_party/xla/xla/backends/interpreter/executable_base.cc @@ -55,8 +55,7 @@ namespace interpreter { InterpreterExecutableBase::InterpreterExecutableBase( std::unique_ptr hlo_module) - : Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + : Executable(std::move(hlo_module)) {} absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 252abf3c4d8f98..fb4b60aedea3f1 100644 --- 
a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1526,7 +1526,6 @@ cc_library( deps = [ ":buffer_assignment", ":computation_layout", - ":hlo_execution_profile", ":hlo_module_config", ":hlo_proto_cc", ":maybe_owning_device_memory", diff --git a/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc b/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc index e422891c24ec34..31ca1d590cd292 100644 --- a/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc +++ b/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.cc @@ -196,8 +196,8 @@ CpuAotCompilationResult::LoadExecutable( cpu_executable, CpuExecutable::Create(std::move(function_library_), std::move(buffer_assignment), std::move(module), - std::move(*thunks), std::move(constants), nullptr, - nullptr, target_machine_options)); + std::move(*thunks), std::move(constants), + target_machine_options)); // Dump computation proto state and buffer assignment for // GetCompiledMemoryStats results. diff --git a/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.h b/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.h index e6589fb1787da5..4817200999814f 100644 --- a/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.h +++ b/third_party/xla/xla/service/cpu/cpu_aot_compilation_result.h @@ -110,24 +110,6 @@ class CpuAotCompilationResult : public AotCompilationResult { TargetMachineOptionsProto target_machine_options = TargetMachineOptionsProto()); - [[deprecated( - "HloProfilePrinterData is not used anymore. 
Use the other Create " - "method instead.")]] static absl:: - StatusOr> - Create(const HloModule* hlo_module, - const BufferAssignment* buffer_assignment, - absl::string_view function_name, - std::vector obj_files, - std::vector symbols, const ThunkSequence& thunks, - std::unique_ptr function_library, - std::unique_ptr hlo_profile_printer_data, - TargetMachineOptionsProto target_machine_options = - TargetMachineOptionsProto()) { - return Create(hlo_module, buffer_assignment, function_name, - std::move(obj_files), std::move(symbols), thunks, - std::move(function_library), target_machine_options); - } - ~CpuAotCompilationResult() override = default; absl::StatusOr SerializeAsString() const override { diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 895657789c891f..a6117a8169ddc3 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -2019,11 +2019,10 @@ CpuCompiler::CompileCpuExecutable( TF_ASSIGN_OR_RETURN( auto cpu_executable, - CpuExecutable::Create( - std::move(function_library), std::move(assignment), std::move(module), - std::move(thunks), std::move(constants), - std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map), - std::move(target_machine_options))); + CpuExecutable::Create(std::move(function_library), std::move(assignment), + std::move(module), std::move(thunks), + std::move(constants), + std::move(target_machine_options))); // Save object files to be able to export them to AOT compilation // result. @@ -2243,12 +2242,6 @@ CpuCompiler::CompileAheadOfTimeThunks( const ThunkSequence& thunk_sequence = cpu_executable->thunks().thunk_sequence(); - std::unique_ptr executable_hlo_profile_printer_data = - cpu_executable->module().config().hlo_profiling_enabled() - ? 
std::make_unique( - cpu_executable->hlo_profile_printer_data()) - : nullptr; - if (cpu_executable->obj_files().size() > 1) { return Internal( "Expected at most one object file for AOT compilation, but got %d", @@ -2266,7 +2259,6 @@ CpuCompiler::CompileAheadOfTimeThunks( cpu_executable->module_name(), std::move(obj_files), cpu_executable->get_compiled_symbols_proto(), thunk_sequence, std::move(*cpu_executable).consume_function_library(), - std::move(executable_hlo_profile_printer_data), cpu_executable->target_machine_options().ToProto()); } @@ -2299,12 +2291,6 @@ absl::StatusOr> CpuCompiler::Export( std::vector compiled_symbols_proto = cpu_executable->get_compiled_symbols_proto(); - std::unique_ptr executable_hlo_profile_printer_data = - cpu_executable->module().config().hlo_profiling_enabled() - ? std::make_unique( - cpu_executable->hlo_profile_printer_data()) - : nullptr; - TF_ASSIGN_OR_RETURN(auto compiled_symbols, GetCompiledSymbolsFromProto(compiled_symbols_proto)); @@ -2319,7 +2305,6 @@ absl::StatusOr> CpuCompiler::Export( cpu_executable->module_name(), std::move(obj_files), std::move(compiled_symbols_proto), *thunk_sequence, std::move(function_library), - std::move(executable_hlo_profile_printer_data), cpu_executable->target_machine_options().ToProto()); } diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index 6e0cf855e34f97..6bb3a695e9523e 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -88,16 +88,13 @@ absl::StatusOr> CpuExecutable::Create( std::unique_ptr assignment, std::unique_ptr hlo_module, ThunkSequence thunks, std::vector constants, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map, TargetMachineOptions target_machine_options) { VLOG(2) << "Create CpuExecutable from a thunk sequence; module=" << hlo_module->name() << ", constants=" << constants.size(); - std::unique_ptr 
executable(new CpuExecutable( - std::move(hlo_module), std::move(hlo_profile_printer_data), - std::move(hlo_profile_index_map), std::move(assignment), - std::move(target_machine_options))); + std::unique_ptr executable( + new CpuExecutable(std::move(hlo_module), std::move(assignment), + std::move(target_machine_options))); executable->function_library_ = std::move(function_library); ThunkExecutor::Options thunk_executor_options; @@ -129,14 +126,10 @@ absl::StatusOr> CpuExecutable::Create( return executable; } -CpuExecutable::CpuExecutable( - std::unique_ptr hlo_module, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map, - std::unique_ptr assignment, - TargetMachineOptions target_machine_options) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), - std::move(hlo_profile_index_map)), +CpuExecutable::CpuExecutable(std::unique_ptr hlo_module, + std::unique_ptr assignment, + TargetMachineOptions target_machine_options) + : Executable(std::move(hlo_module)), assignment_(std::move(assignment)), target_machine_options_(std::move(target_machine_options)) { if (assignment_ && has_module()) { diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index ee590e472dbf83..ebb97baf217e47 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -62,8 +62,6 @@ class CpuExecutable : public Executable { std::unique_ptr assignment, std::unique_ptr hlo_module, ThunkSequence thunks, std::vector constants, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map, TargetMachineOptions target_machine_options); ~CpuExecutable() override; @@ -246,8 +244,6 @@ class CpuExecutable : public Executable { std::string entry_function_name_; CpuExecutable(std::unique_ptr hlo_module, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map, std::unique_ptr assignment, 
TargetMachineOptions target_machine_options); CpuExecutable(const CpuExecutable&) = delete; diff --git a/third_party/xla/xla/service/executable.h b/third_party/xla/xla/service/executable.h index db444230abe342..e59ac39a932d44 100644 --- a/third_party/xla/xla/service/executable.h +++ b/third_party/xla/xla/service/executable.h @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include @@ -37,7 +36,6 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_execution_profile.h" #include "xla/service/hlo_module_config.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/service/service_executable_run_options.h" @@ -265,20 +263,6 @@ class Executable { // doesn't need it for execution. explicit Executable(std::shared_ptr hlo_module) : hlo_module_(std::move(hlo_module)) {} - - // TODO(b/172012028): Remove this constructor. - // The hlo_module parameter may be nullptr, if the given executable type - // doesn't need it for execution. 
- explicit Executable( - std::shared_ptr hlo_module, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map) - : hlo_module_(std::move(hlo_module)), - hlo_profile_printer_data_(std::move(hlo_profile_printer_data)), - hlo_profile_index_map_(std::move(hlo_profile_index_map)) { - CHECK_EQ(hlo_profile_printer_data_.get() == nullptr, - hlo_profile_index_map_.get() == nullptr); - } virtual ~Executable() = default; // Enqueues the compilation result on the provided stream, passing the given @@ -344,22 +328,6 @@ class Executable { const ServiceExecutableRunOptions* run_options, std::vector arguments); - const HloProfilePrinterData& hlo_profile_printer_data() const { - CHECK(hlo_profiling_enabled()); - return *hlo_profile_printer_data_; - } - - const HloProfileIndexMap& hlo_profile_index_map() const { - CHECK(hlo_profiling_enabled()); - return *hlo_profile_index_map_; - } - - // Returns whether this executable was compiled with HLO profilings support - // enabled. If not, the caller should not expect an hlo_execution_profile - // passed to ExecuteOnStream above to be populated during execution. - bool hlo_profiling_enabled() const { - return hlo_profile_printer_data_ != nullptr; - } HloModule& module() const { CHECK(hlo_module_ != nullptr); @@ -477,9 +445,6 @@ class Executable { // execution. int64_t execution_count_ = 0; - std::unique_ptr hlo_profile_printer_data_; - std::unique_ptr hlo_profile_index_map_; - // A map from kernel name to relevant kernel stats. 
ModuleStats module_stats_; From a6e0e6fea1a7910728a32763f3e8a4722d5491d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eusebio=20Dur=C3=A1n=20Monta=C3=B1a?= Date: Mon, 8 Dec 2025 00:58:04 -0800 Subject: [PATCH 010/753] Add missing BUILD dependencies, and remove unused ones PiperOrigin-RevId: 841639400 --- third_party/xla/xla/backends/autotuner/BUILD | 1 - third_party/xla/xla/backends/cpu/BUILD | 3 -- .../xla/xla/backends/cpu/autotuner/BUILD | 4 --- .../xla/xla/backends/cpu/codegen/BUILD | 6 ---- .../xla/xla/backends/cpu/codegen/dot/BUILD | 2 -- .../xla/backends/cpu/codegen/elemental/BUILD | 2 -- .../xla/backends/cpu/codegen/emitters/BUILD | 1 - .../cpu/codegen/tiled/transforms/BUILD | 1 - .../xla/xla/backends/cpu/collectives/BUILD | 5 --- .../xla/xla/backends/cpu/runtime/BUILD | 35 ++++--------------- .../xla/backends/cpu/runtime/xnnpack/BUILD | 2 -- .../xla/backends/cpu/runtime/ynnpack/BUILD | 8 ----- .../xla/xla/backends/cpu/testlib/BUILD | 3 -- third_party/xla/xla/backends/cpu/tests/BUILD | 8 ----- .../xla/xla/backends/gpu/autotuner/BUILD | 4 --- .../xla/xla/backends/gpu/codegen/BUILD | 1 - .../backends/gpu/codegen/emitters/ir/BUILD | 1 - .../xla/xla/backends/gpu/codegen/tools/BUILD | 2 -- .../xla/xla/backends/gpu/codegen/triton/BUILD | 17 --------- .../xla/xla/backends/gpu/runtime/BUILD | 12 ------- .../xla/xla/backends/interpreter/BUILD | 1 - .../xla/xla/backends/profiler/gpu/BUILD | 8 ----- third_party/xla/xla/codegen/BUILD | 8 ----- third_party/xla/xla/codegen/emitters/BUILD | 6 ---- third_party/xla/xla/codegen/tiling/BUILD | 6 ---- 25 files changed, 7 insertions(+), 140 deletions(-) diff --git a/third_party/xla/xla/backends/autotuner/BUILD b/third_party/xla/xla/backends/autotuner/BUILD index 87fa53104ca953..b9c1637e90d582 100644 --- a/third_party/xla/xla/backends/autotuner/BUILD +++ b/third_party/xla/xla/backends/autotuner/BUILD @@ -181,7 +181,6 @@ xla_cc_test( "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:statusor", - 
"//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@com_google_protobuf//:any_cc_proto", diff --git a/third_party/xla/xla/backends/cpu/BUILD b/third_party/xla/xla/backends/cpu/BUILD index b124f4f72962ea..7d9e7ad114aba0 100644 --- a/third_party/xla/xla/backends/cpu/BUILD +++ b/third_party/xla/xla/backends/cpu/BUILD @@ -54,7 +54,6 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", ], @@ -182,7 +181,6 @@ cc_library( srcs = ["xnn_gemm_config.cc"], hdrs = ["xnn_gemm_config.h"], deps = [ - "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/codegen:target_machine_features", "//xla/backends/cpu/runtime:dot_dims", @@ -208,7 +206,6 @@ cc_library( "//xla/service:pattern_matcher", "//xla/tsl/platform:statusor", "@XNNPACK", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/backends/cpu/autotuner/BUILD b/third_party/xla/xla/backends/cpu/autotuner/BUILD index 89b7f87c5cdce7..e5bc7335a1a082 100644 --- a/third_party/xla/xla/backends/cpu/autotuner/BUILD +++ b/third_party/xla/xla/backends/cpu/autotuner/BUILD @@ -88,7 +88,6 @@ cc_library( hdrs = ["xnnpack_backend.h"], deps = [ ":cpu_codegen_backend", - "//xla:status_macros", "//xla:util", "//xla/backends/autotuner:codegen_backend", "//xla/backends/cpu:xnn_fusion_options_proto_cc", @@ -102,7 +101,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:casts", ], ) @@ -203,7 +201,5 @@ xla_cc_test( "//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", 
"@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/backends/cpu/codegen/BUILD b/third_party/xla/xla/backends/cpu/codegen/BUILD index 4753a80c341de5..db9b0a2cfbd267 100644 --- a/third_party/xla/xla/backends/cpu/codegen/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/BUILD @@ -94,7 +94,6 @@ cc_library( srcs = ["ir_compiler.cc"], hdrs = ["ir_compiler.h"], deps = [ - ":cpu_features", ":kernel_api_ir_builder", ":polynomial_approximations", "//xla:util", @@ -133,7 +132,6 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:TargetParser", - "@local_tsl//tsl/platform:platform_port", ], ) @@ -351,7 +349,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", @@ -557,9 +554,7 @@ cc_library( hdrs = ["object_loader.h"], deps = [ ":compiled_function_library", - ":contiguous_section_memory_manager", ":execution_engine", - ":jit_memory_mapper", "//xla/backends/cpu/runtime:function_library", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", @@ -776,7 +771,6 @@ py_strict_test( "//third_party/py/numpy", "//xla/backends/cpu/testlib", "//xla/codegen/testlib", - "//xla/python:xla_extension", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", ], diff --git a/third_party/xla/xla/backends/cpu/codegen/dot/BUILD b/third_party/xla/xla/backends/cpu/codegen/dot/BUILD index 067d7d21b6df1d..97db03ab0cfcaa 100644 --- a/third_party/xla/xla/backends/cpu/codegen/dot/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/dot/BUILD @@ -15,7 +15,6 @@ cc_library( "//xla:util", 
"//xla/backends/cpu/codegen:kernel_api_ir_builder", "//xla/backends/cpu/codegen:target_machine_features", - "//xla/codegen:kernel_definition", "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", "//xla/codegen:llvm_kernel_source", @@ -25,7 +24,6 @@ cc_library( "//xla/service:hlo_module_config", "//xla/service/cpu:dot_op_emitter", "//xla/service/llvm_ir:ir_array", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD b/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD index bdce35a2907d7e..f31b161c01e632 100644 --- a/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/elemental/BUILD @@ -35,7 +35,6 @@ cc_library( "//xla/service/cpu:backend_config_proto_cc", "//xla/service/cpu:ir_emitter", "//xla/service/llvm_ir:ir_array", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -104,7 +103,6 @@ xla_cc_test( ":elemental_kernel_emitter", "//xla:xla_data_proto_cc", "//xla/codegen:kernel_definition", - "//xla/codegen:kernel_emitter", "//xla/codegen:llvm_kernel_source", "//xla/hlo/analysis:alias_info", "//xla/hlo/analysis:hlo_ordering", diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD index d0ac6c73095272..a7f171019e92cf 100644 --- a/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/BUILD @@ -124,6 +124,5 @@ xla_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMToLLVMIRTranslation", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:casts", ], ) diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD index d81c0557967508..ee920c56af1ac2 
100644 --- a/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/tiled/transforms/BUILD @@ -114,7 +114,6 @@ cc_library( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:VectorDialect", ], ) diff --git a/third_party/xla/xla/backends/cpu/collectives/BUILD b/third_party/xla/xla/backends/cpu/collectives/BUILD index abec1d790f9b90..c1103b1457180a 100644 --- a/third_party/xla/xla/backends/cpu/collectives/BUILD +++ b/third_party/xla/xla/backends/cpu/collectives/BUILD @@ -76,13 +76,11 @@ xla_cc_test( ":cpu_clique_key", ":cpu_cliques", ":in_process_collectives", - "//xla:util", "//xla/core/collectives:rank_id", "//xla/runtime:device_id", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], ) @@ -144,7 +142,6 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:rendezvous", "//xla/stream_executor:device_address", - "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/math:math_util", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", @@ -272,7 +269,6 @@ cc_library( "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_address", - "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", @@ -338,7 +334,6 @@ cc_library( "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_address", - "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 1ee4f81b7db5f6..8d68edd2d6e12e 100644 --- 
a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -1,6 +1,6 @@ load("//xla:xla.default.bzl", "xla_cc_test", "xla_internal") load("//xla/service/cpu:build_defs.bzl", "runtime_copts") -load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "if_windows", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") load("//xla/tsl/platform:build_config.bzl", "tf_proto_library") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") @@ -105,14 +105,11 @@ cc_library( "//xla:util", "//xla/runtime:work_group", "//xla/stream_executor:device_address", - "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", ], @@ -370,13 +367,10 @@ cc_library( "//xla/core/collectives:communicator", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], @@ -496,14 +490,11 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", 
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], @@ -524,12 +515,10 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], @@ -550,13 +539,10 @@ cc_library( "//xla/core/collectives:communicator", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], @@ -580,13 +566,10 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -752,6 +735,11 @@ cc_library( "dot_lib_s8.cc", ], hdrs = ["dot_lib.h"], + tags = if_google([ + # Prevent build_cleaner from adding a dependency on eigen_contraction_kernel.h, see comment + # on `:dot_lib_onednn` below. 
+ "ignore_for_dep=third_party/tensorflow/compiler/xla/tsl/framework/contraction/eigen_contraction_kernel.h", + ]), deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", @@ -799,14 +787,11 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@eigen_archive//:eigen3", ], ) @@ -967,7 +952,6 @@ cc_library( "//xla/runtime:work_group", "//xla/service:buffer_assignment", "//xla/stream_executor:device_address", - "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", @@ -995,7 +979,6 @@ xla_cc_test( deps = [ ":buffer_allocations", ":function_library", - ":kernel", ":kernel_c_api", ":kernel_thunk", ":thunk", @@ -1003,7 +986,7 @@ xla_cc_test( "//xla:literal_util", "//xla/runtime:work_group", "//xla/service:buffer_assignment", - "//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", @@ -1099,7 +1082,6 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1227,7 +1209,6 @@ cc_library( "//xla/stream_executor:device_address", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/memory", 
"@com_google_absl//absl/status:statusor", ], @@ -1424,8 +1405,6 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", - "//xla/tsl/platform:logging", - "//xla/tsl/platform:test", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", diff --git a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD index 9cd96a5e90b63b..dc32dac687585b 100644 --- a/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/xnnpack/BUILD @@ -160,7 +160,6 @@ cc_library( deps = [ ":xnn_interop", "//xla:shape_util", - "//xla:util", "//xla/backends/cpu/runtime:thunk", "//xla/runtime:buffer_use", "//xla/runtime:object_pool", @@ -172,7 +171,6 @@ cc_library( "//xla/tsl/platform:statusor", "@XNNPACK", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", diff --git a/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD b/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD index 5172563e23ca0c..f5b22b22c19d8c 100644 --- a/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/ynnpack/BUILD @@ -22,10 +22,7 @@ cc_library( "//xla/backends/cpu/runtime:work_queue", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", "@local_tsl//tsl/profiler/lib:traceme", @@ -72,15 +69,10 @@ cc_library( ":slinky_threadpool", ":ynn_interop", "@XNNPACK//ynnpack:ynnpack_h", - 
"@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", - "@slinky//slinky/base:thread_pool", ], ) diff --git a/third_party/xla/xla/backends/cpu/testlib/BUILD b/third_party/xla/xla/backends/cpu/testlib/BUILD index e4d57e372c0ef0..e926426a322d37 100644 --- a/third_party/xla/xla/backends/cpu/testlib/BUILD +++ b/third_party/xla/xla/backends/cpu/testlib/BUILD @@ -159,7 +159,6 @@ tsl_pybind_extension( "//xla/codegen:llvm_kernel_source", "//xla/codegen:mlir_kernel_source", "//xla/codegen/testlib:kernel_runner", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/runtime:work_group", "//xla/service:buffer_assignment", @@ -192,7 +191,6 @@ xla_cc_test( "//xla/codegen:llvm_kernel_source", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor:launch_dim", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", @@ -200,7 +198,6 @@ xla_cc_test( "@com_google_googletest//:gtest", "@llvm-project//llvm:JITLink", "@llvm-project//llvm:ir_headers", - "@local_tsl//tsl/platform:casts", ], ) diff --git a/third_party/xla/xla/backends/cpu/tests/BUILD b/third_party/xla/xla/backends/cpu/tests/BUILD index 0e63cc568a6459..435200a13e65d2 100644 --- a/third_party/xla/xla/backends/cpu/tests/BUILD +++ b/third_party/xla/xla/backends/cpu/tests/BUILD @@ -20,14 +20,9 @@ xla_test( tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", "//xla/hlo/parser:hlo_parser", - "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", - "//xla/tests:client_library_test_runner_mixin", "//xla/tests:hlo_pjrt_interpreter_reference_mixin", "//xla/tests:hlo_pjrt_test_base", - "//xla/tests:literal_test_util", 
"//xla/tsl/platform:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -41,14 +36,11 @@ xla_test( tags = ["test_migrated_to_hlo_runner_pjrt"], deps = [ "//xla:error_spec", - "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", "//xla/tests:hlo_pjrt_interpreter_reference_mixin", "//xla/tests:hlo_pjrt_test_base", "//xla/tsl/platform:test", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:platform_port", ], ) diff --git a/third_party/xla/xla/backends/gpu/autotuner/BUILD b/third_party/xla/xla/backends/gpu/autotuner/BUILD index d9eb72d2b71296..3d1f9c93001508 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/BUILD +++ b/third_party/xla/xla/backends/gpu/autotuner/BUILD @@ -40,7 +40,6 @@ xla_cc_test( srcs = ["gpu_codegen_backend_test.cc"], deps = [ ":gpu_codegen_backend", - "//xla:xla_proto_cc", "@com_google_googletest//:gtest_main", ], ) @@ -225,7 +224,6 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", @@ -573,7 +571,6 @@ xla_test( ":native_emitter", "//xla/backends/autotuner:codegen_backend", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service:compiler", "//xla/service:executable", @@ -607,7 +604,6 @@ cc_library( ":fission_backend", ":triton", "//xla/backends/autotuner:codegen_backend", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/service:compiler", "//xla/service/gpu/transforms:dot_algorithm_rewriter", diff --git a/third_party/xla/xla/backends/gpu/codegen/BUILD b/third_party/xla/xla/backends/gpu/codegen/BUILD index 
2be8a0247fbe27..62f99605665307 100644 --- a/third_party/xla/xla/backends/gpu/codegen/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/BUILD @@ -295,7 +295,6 @@ cc_library( "//xla/backends/gpu/codegen/emitters:transpose", "//xla/backends/gpu/codegen/triton:fusion", "//xla/codegen:ir_emission_utils", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_traversal", "//xla/service:buffer_assignment", diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD index 20b5dec1c59a74..4c9961033ebc17 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/ir/BUILD @@ -116,7 +116,6 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", - "@llvm-project//mlir:CallOpInterfaces", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", diff --git a/third_party/xla/xla/backends/gpu/codegen/tools/BUILD b/third_party/xla/xla/backends/gpu/codegen/tools/BUILD index dfe8f7f9a8100f..df59b5ebed0cec 100644 --- a/third_party/xla/xla/backends/gpu/codegen/tools/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/tools/BUILD @@ -32,7 +32,6 @@ cc_library( "//xla:status_macros", "//xla/backends/gpu/codegen:fusions", "//xla/backends/gpu/codegen/emitters:emitter_base", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", @@ -55,7 +54,6 @@ xla_cc_binary( deps = [ ":test_lib", "//xla/codegen/tools:test_lib", - "//xla/hlo/analysis:symbolic_expr", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 8e365aba6a0b59..19587ed10c4d4a 100644 --- 
a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -259,7 +259,6 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/backends/gpu/codegen/triton/ir:triton_xla", "//xla/codegen/emitters:elemental_hlo_to_mlir", "//xla/codegen/emitters/ir:xla", "//xla/codegen/tiling:symbolic_tile_analysis", @@ -559,7 +558,6 @@ xla_test( "//xla:error_spec", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:pattern_matcher_gmock", @@ -612,19 +610,13 @@ xla_test( "//xla:autotuning_proto_cc", "//xla:error_spec", "//xla:xla_proto_cc", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/service/gpu/transforms:nest_gemm_fusion", "//xla/stream_executor:device_description", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", @@ -711,7 +703,6 @@ xla_test( "no_mac", ], deps = [ - ":fusion_emitter", ":support", ":test_utils", ":xtile_compiler", @@ -724,12 +715,6 @@ xla_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/codegen/tiling:symbolic_tile_analysis", - "//xla/codegen/tiling:tiled_hlo_computation", - "//xla/codegen/tiling:tiled_hlo_instruction", - "//xla/codegen/tiling:tiled_hlo_schedule", - "//xla/codegen/tiling:tiling_specification", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:verified_hlo_module", @@ -787,7 +772,6 @@ cc_library( "//xla/service/gpu:gpu_float_support", 
"//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:target_constants", "//xla/service/gpu/model:block_level_parameters", "//xla/service/gpu/model:triton_emitter_constraints", "//xla/stream_executor:device_description", @@ -1010,7 +994,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", - "@llvm-project//mlir:LLVMDialect", ], ) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 1ba07e7410f2c2..192c99de9902a6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -1667,10 +1667,8 @@ cc_library( hdrs = ["collective_params.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":collective_clique_requests", "//xla:executable_run_options", "//xla:util", - "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/runtime:device_id", "//xla/service:computation_placer", @@ -1681,9 +1679,7 @@ cc_library( "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -1700,7 +1696,6 @@ cc_library( "//xla/backends/gpu/collectives:gpu_clique", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/backends/gpu/collectives:gpu_cliques", - "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/runtime:device_id", @@ -1710,9 +1705,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", 
"@com_google_absl//absl/synchronization", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -2153,11 +2146,8 @@ cc_library( "//xla:executable_run_options", "//xla:status_macros", "//xla:util", - "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_cliques", "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", - "//xla/core/collectives:rank_id", "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", "//xla/runtime:buffer_use", @@ -2171,7 +2161,6 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/gtl:int_type", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -2179,7 +2168,6 @@ cc_library( "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index 6b9515523cf0c1..6bf5957323a49c 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -168,7 +168,6 @@ cc_library( "//xla/stream_executor:stream_executor_common", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_stream", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 30e6a1cfa72cb4..fefd3b9b992862 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -68,7 +68,6 @@ cc_library( 
":rocm_tracer_utils", "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/platform:env_time", - "//xla/tsl/platform:errors", "//xla/tsl/profiler/backends/cpu:annotation_stack", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -201,9 +200,7 @@ xla_test( ":cupti_wrapper", ":mock_cupti", "//xla/tsl/profiler/utils:time_utils", - "@com_google_absl//absl/memory", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test", ], ) @@ -362,7 +359,6 @@ cc_library( ], deps = [ ":cupti_collector", - ":cupti_interface", "@com_google_absl//absl/status", "@com_google_absl//absl/time", ], @@ -756,11 +752,7 @@ xla_cc_test( ":ondevice_event_exporter", "//xla/tsl/profiler/backends/gpu:ondevice_event_receiver", "//xla/tsl/profiler/backends/gpu:ondevice_trace_event", - "//xla/tsl/profiler/utils:xplane_builder", - "//xla/tsl/profiler/utils:xplane_schema", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/third_party/xla/xla/codegen/BUILD b/third_party/xla/xla/codegen/BUILD index f58bf888d460bb..ddcc8810df0a6f 100644 --- a/third_party/xla/xla/codegen/BUILD +++ b/third_party/xla/xla/codegen/BUILD @@ -57,8 +57,6 @@ cc_library( srcs = ["llvm_kernel_source.cc"], hdrs = ["llvm_kernel_source.h"], deps = [ - ":kernel_definition", - ":kernel_emitter", ":kernel_source", "//xla/service/llvm_ir:llvm_util", "@llvm-project//llvm:Core", @@ -77,7 +75,6 @@ cc_library( deps = [ ":kernel_source", ":kernel_spec", - "//xla/tsl/platform:logging", ], ) @@ -86,8 +83,6 @@ cc_library( srcs = ["mlir_kernel_source.cc"], hdrs = ["mlir_kernel_source.h"], deps = [ - ":kernel_definition", - ":kernel_emitter", ":kernel_source", "//xla:util", "//xla/hlo/analysis:symbolic_expr", @@ -117,7 +112,6 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", - 
"@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], @@ -136,8 +130,6 @@ xla_cc_test( "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:status_matchers", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/codegen/emitters/BUILD b/third_party/xla/xla/codegen/emitters/BUILD index 3bf04998441f0b..5fc4861574b9ed 100644 --- a/third_party/xla/xla/codegen/emitters/BUILD +++ b/third_party/xla/xla/codegen/emitters/BUILD @@ -257,7 +257,6 @@ cc_library( "//xla/codegen:kernel_spec", "//xla/codegen/emitters/ir:xla", "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/runtime:work_dimensions", "//xla/runtime:work_group", @@ -287,7 +286,6 @@ xla_cc_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/analysis:symbolic_expr", "//xla/runtime:work_cluster", "//xla/runtime:work_dimensions", "//xla/runtime:work_group", @@ -311,7 +309,6 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/codegen:hlo_fusion_spec", - "//xla/codegen:kernel_definition", "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", "//xla/codegen:mlir_kernel_source", @@ -327,7 +324,6 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -365,7 +361,6 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/codegen:hlo_fusion_spec", - "//xla/codegen:kernel_definition", "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", "//xla/codegen:mlir_kernel_source", @@ -425,7 +420,6 @@ cc_library( "//xla:util", "//xla/codegen:hlo_fusion_spec", "//xla/codegen:ir_emission_utils", - 
"//xla/codegen:kernel_definition", "//xla/codegen:kernel_emitter", "//xla/codegen:kernel_spec", "//xla/codegen:mlir_kernel_source", diff --git a/third_party/xla/xla/codegen/tiling/BUILD b/third_party/xla/xla/codegen/tiling/BUILD index a373ed7c1ced60..07ebadad917774 100644 --- a/third_party/xla/xla/codegen/tiling/BUILD +++ b/third_party/xla/xla/codegen/tiling/BUILD @@ -306,12 +306,7 @@ cc_library( hdrs = ["tiling_specification.h"], deps = [ ":constraint_expression", - ":symbolic_tiled_hlo_instruction", - ":tiled_hlo_computation", - "//xla/hlo/analysis:indexing_analysis", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_traversal", - "//xla/service:instruction_fusion", "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -321,7 +316,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:IR", ], ) From 21f33aec5ab61b897851145109f476c17bb64e94 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 01:05:38 -0800 Subject: [PATCH 011/753] Update GraphDef version to 2435. PiperOrigin-RevId: 841641925 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 97ae6af69c56ae..5448bf12c3dcfe 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2434 // Updated: 2025/12/7 +#define TF_GRAPH_DEF_VERSION 2435 // Updated: 2025/12/8 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From e5a79a127258429ea35a27e3b4c16f511887cdfc Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 8 Dec 2025 01:06:32 -0800 Subject: [PATCH 012/753] compat: Update forward compatibility horizon to 2025-12-08 PiperOrigin-RevId: 841642257 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 77eb63a7551ed6..019f2360af662e 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 12, 7) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 12, 8) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From a79c3d3a39db173699271673c6f33eb6d4209cff Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 01:12:17 -0800 Subject: [PATCH 013/753] Automated Code Change PiperOrigin-RevId: 841644074 --- .../xla/xla/backends/gpu/codegen/emitters/emitter_base.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc b/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc index 91b52d4d011ee6..f171a5cb6b4f33 100644 --- a/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc +++ b/third_party/xla/xla/backends/gpu/codegen/emitters/emitter_base.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" From 59e985cfa334b8bb32d07b0f74d0503eb0c03bf6 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 8 Dec 2025 01:14:14 -0800 Subject: [PATCH 014/753] [XLA:GPU] Filter out multi-gpu tests since current GPU L4 action has only single GPU. PiperOrigin-RevId: 841644586 --- third_party/xla/build_tools/ci/build.py | 130 ++++-------------- .../xla/build_tools/ci/golden_commands.txt | 28 ++-- .../xla/xla/stream_executor/cuda/BUILD | 1 - third_party/xla/xla/tests/BUILD | 15 -- .../xla/xla/tests/collective_ops_e2e_test.cc | 8 ++ .../xla/tests/collective_ops_e2e_test_base.h | 5 + 6 files changed, 53 insertions(+), 134 deletions(-) diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index 20d77da3bee540..0b6e4dbcbf8822 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -273,6 +273,17 @@ def _tag_filters_for_compute_capability( return tag_filters +nvidia_gpu_filters = ( + "-no_oss", + "requires-gpu-nvidia", + "gpu", + "-rocm-only", + "-oneapi-only", +) + +single_nvidia_gpu_filters = nvidia_gpu_filters + ("-multi_gpu",) + + def nvidia_gpu_build_with_compute_capability( *, type_: BuildType, @@ -285,21 +296,8 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, configs=configs, - test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ) - + extra_gpu_tags, - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + test_tag_filters=single_nvidia_gpu_filters + extra_gpu_tags, + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", "//xla/tsl:ci_build": True, @@ -510,21 +508,9 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", target_patterns=_XLA_GPU_PRESUBMIT_BENCHMARKS_DEFAULT_TARGET_PATTERNS, configs=("warnings", "rbe_linux_cuda_nvcc", "hermetic_cuda_umd"), - test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - 
"-rocm-only", - "-oneapi-only", - ) + test_tag_filters=single_nvidia_gpu_filters + _tag_filters_for_compute_capability(compute_capability=75), - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", "//xla/tsl:ci_build": True, @@ -542,21 +528,9 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", target_patterns=_XLA_GPU_PRESUBMIT_BENCHMARKS_DEFAULT_TARGET_PATTERNS, configs=("warnings", "rbe_linux_cuda_nvcc", "hermetic_cuda_umd"), - test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ) + test_tag_filters=single_nvidia_gpu_filters + _tag_filters_for_compute_capability(compute_capability=75), - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", "//xla/tsl:ci_build": True, @@ -575,21 +549,9 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", configs=("warnings", "rbe_linux_cuda_nvcc", "hermetic_cuda_umd"), target_patterns=_XLA_GPU_PRESUBMIT_BENCHMARKS_DEFAULT_TARGET_PATTERNS, - test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ) + test_tag_filters=single_nvidia_gpu_filters + _tag_filters_for_compute_capability(compute_capability=75), - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", "//xla/tsl:ci_build": True, @@ -607,21 +569,9 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", configs=("warnings", "rbe_linux_cuda_nvcc", "hermetic_cuda_umd"), target_patterns=_XLA_GPU_PRESUBMIT_BENCHMARKS_DEFAULT_TARGET_PATTERNS, - 
test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ) + test_tag_filters=single_nvidia_gpu_filters + _tag_filters_for_compute_capability(compute_capability=75), - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", "//xla/tsl:ci_build": True, @@ -640,21 +590,9 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", configs=(), target_patterns=_XLA_GPU_PRESUBMIT_BENCHMARKS_DEFAULT_TARGET_PATTERNS, - test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ) + test_tag_filters=single_nvidia_gpu_filters + _tag_filters_for_compute_capability(compute_capability=100), - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", # Use User Mode and Kernel Mode Drivers pre-installed on the system. @@ -675,21 +613,9 @@ def nvidia_gpu_build_with_compute_capability( repo="openxla/xla", configs=(), target_patterns=_XLA_GPU_PRESUBMIT_BENCHMARKS_DEFAULT_TARGET_PATTERNS, - test_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ) + test_tag_filters=single_nvidia_gpu_filters + _tag_filters_for_compute_capability(compute_capability=100), - build_tag_filters=( - "-no_oss", - "requires-gpu-nvidia", - "gpu", - "-rocm-only", - "-oneapi-only", - ), + build_tag_filters=single_nvidia_gpu_filters, options={ "run_under": "//build_tools/ci:parallel_gpu_execute", # Use User Mode and Kernel Mode Drivers pre-installed on the system. 
@@ -932,11 +858,7 @@ def nvidia_gpu_build_with_compute_capability( Build( type_=BuildType.TENSORFLOW_LINUX_X86_GPU_L4_GITHUB_ACTIONS, repo="tensorflow/tensorflow", - configs=( - "release_gpu_linux", - "rbe_linux_cuda", - "hermetic_cuda_umd" - ), + configs=("release_gpu_linux", "rbe_linux_cuda", "hermetic_cuda_umd"), target_patterns=( "//tensorflow/compiler/...", "-//tensorflow/compiler/tf2tensorrt/...", diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt index f5e914157ec888..e067ee9ecc80dd 100644 --- a/third_party/xla/build_tools/ci/golden_commands.txt +++ b/third_party/xla/build_tools/ci/golden_commands.txt @@ -55,44 +55,44 @@ bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_CPU_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_A4_224_VCPU_BENCHMARK_PRESUBMIT_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu -bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only 
--test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu 
--test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_A4_224_VCPU_BENCHMARK_PRESUBMIT_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_A4_224_VCPU_PRESUBMIT_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu -bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only 
--test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 --repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm100-only,requires-gpu-sm60,requires-gpu-sm70,requires-gpu-sm80,requires-gpu-sm90,requires-gpu-sm100,-requires-gpu-amd,-requires-gpu-intel --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=10 --repo_env=HERMETIC_CUDA_VERSION=12.8.0 
--repo_env=HERMETIC_CUDNN_VERSION=9.8.0 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_A4_224_VCPU_PRESUBMIT_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_L4_16_VCPU_BENCHMARK_PRESUBMIT_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu -bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc 
--config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 
--run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_L4_16_VCPU_BENCHMARK_PRESUBMIT_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_L4_16_VCPU_PRESUBMIT_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu -bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd 
--repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going 
--nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_L4_16_VCPU_PRESUBMIT_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_L4_48_VCPU_BENCHMARK_PRESUBMIT_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu -bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build 
--@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --@local_config_cuda//cuda:include_cuda_libs=False --color=yes 
--test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_L4_48_VCPU_BENCHMARK_PRESUBMIT_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_L4_48_VCPU_PRESUBMIT_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu -bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors 
--verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu +bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- 
//xla/tools/multihost_hlo_runner:hlo_runner_main_gpu //xla/tools:compute_xspace_stats_main_gpu bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_L4_48_VCPU_PRESUBMIT_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_L4_GITHUB_ACTIONS nvidia-smi -parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... -bazel test --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... 
+parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... +bazel test --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,-oneapi-only,-multi_gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-sm100,-requires-gpu-sm100-only,-requires-gpu-amd,-requires-gpu-intel --config=warnings --config=rbe_linux_cuda_nvcc --config=hermetic_cuda_umd --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --run_under=//build_tools/ci:parallel_gpu_execute --//xla/tsl:ci_build --color=yes --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... 
bazel analyze-profile profile.json.gz # END BuildType.XLA_LINUX_X86_GPU_L4_GITHUB_ACTIONS # BEGIN BuildType.XLA_LINUX_X86_GPU_ONEAPI_GITHUB_ACTIONS diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 24a305d31a76ce..f173369799956f 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1248,7 +1248,6 @@ xla_test( backend_tags = { "gpu": [ "multi_gpu", - "no_oss", ], }, backends = ["gpu"], diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index b8f0ea10cad89b..aff7b7e1abfcdd 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -2916,10 +2916,6 @@ xla_test( "gpu": [ "multi_gpu", ], - "nvgpu_any": [ - "broken", - "no_oss", - ], }, backends = [ "gpu", @@ -2971,9 +2967,6 @@ xla_test( "gpu": [ "multi_gpu", ], - "nvgpu_any": [ - "no_oss", - ], }, backends = ["gpu"], deps = [ @@ -3016,10 +3009,6 @@ xla_test( "gpu": [ "multi_gpu", ], - "nvgpu_any": [ - "broken", - "no_oss", - ], }, backends = [ "gpu", @@ -3043,10 +3032,6 @@ xla_test( "gpu": [ "multi_gpu", ], - "nvgpu_any": [ - "broken", - "no_oss", - ], }, backends = [ "gpu", diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 1191df40032c41..872492ffecbdeb 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -2396,6 +2396,14 @@ class AllReduceTest /*memory_size=*/32 * kMB, /*collectives_memory_size=*/0) {} + void SetUp() override { + CollectiveOpsE2ETestBase::SetUp(); + if (!IsAmpereAndHigher()) { + GTEST_SKIP() << "Test requires Ampere or newer architecture since it's " + "using triton."; + } + } + protected: DebugOptions GetDebugOptionsForTest() const override { DebugOptions opts = CollectiveOpsWithFlagsBase::GetDebugOptionsForTest(); diff --git 
a/third_party/xla/xla/tests/collective_ops_e2e_test_base.h b/third_party/xla/xla/tests/collective_ops_e2e_test_base.h index 93190cc0e7c85c..8cf62249adca78 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test_base.h +++ b/third_party/xla/xla/tests/collective_ops_e2e_test_base.h @@ -80,6 +80,11 @@ class CollectiveOpsE2ETestBase : public HloHardwareIndependentTestBase { Capability().cuda_compute_capability()->IsAtLeastHopper(); } + bool IsAmpereAndHigher() { + return Capability().IsCuda() && + Capability().cuda_compute_capability()->IsAtLeastAmpere(); + } + protected: std::unique_ptr hlo_runner_; std::unique_ptr reference_hlo_runner_; From a380e6c5e155b4ae40c28a4c4bdc13cb4b74b941 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 8 Dec 2025 01:38:49 -0800 Subject: [PATCH 015/753] Remove accidental const keyword. The alias_info_ member should not be const. PiperOrigin-RevId: 841652983 --- third_party/xla/xla/service/buffer_assignment_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index 31a0cc1c4d5e32..7627b0afd76bab 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -399,7 +399,7 @@ class BufferAssignmentTest : public HloHardwareIndependentTestBase { Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10}); Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_}); Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_}); - const AliasInfo alias_info_; + AliasInfo alias_info_; }; // Returns true if the buffers assigned to instructions in "a" are distinct From 325a5b2649bcdf4522a1a58ea354268399f14488 Mon Sep 17 00:00:00 2001 From: Henning Becker Date: Mon, 8 Dec 2025 02:09:09 -0800 Subject: [PATCH 016/753] Store ThunkProto in GpuExecutable. 
This change modifies GpuExecutable to generate and store the ThunkProto during creation, before running thunk passes. The stored proto is then used when serializing the GpuExecutable to a proto, instead of generating it on demand after thunk passes ran. This is a temporary measure to make debug dumping of GPU executables possible. Long term we want to split GpuExecutable into 2 entities - one that is being produced by the compiler and doesn't depend on runtime facilities, and a second one which gets generated from the first one and has all the execution code. But this is unfortunately a bigger refactoring, therefore we need a quicker way. PiperOrigin-RevId: 841662771 --- .../xla/xla/service/gpu/gpu_executable.cc | 18 +++-- .../xla/xla/service/gpu/gpu_executable.h | 7 +- .../xla/service/gpu/gpu_executable_test.cc | 72 +++++++++++++++++++ 3 files changed, 92 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index d72fdeddab4fe3..df3767982ce5f9 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -88,6 +88,7 @@ limitations under the License. #include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/kernel_stats.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" @@ -238,6 +239,10 @@ absl::StatusOr> GpuExecutable::Create( GpuExecutableThunkPassBufferAllocator allocator(next_idx); + // TODO(b/461380690): Remove this once we have a better way to distinguish + // between compiler-generated and runtime-loaded GPU executables. 
+ absl::StatusOr thunk_proto = params.executable->ToProto(); + TF_RETURN_IF_ERROR(RunThunkPasses( params.debug_options, params.device_description, params.executable.get(), params.debug_module.get(), allocator)); @@ -251,7 +256,7 @@ absl::StatusOr> GpuExecutable::Create( std::move(allocator.MutableAllocations()), std::move(params.alias_info), std::move(params.debug_options), std::move(params.constants), std::move(params.output_info), params.enable_debug_info_manager, - std::move(params.module_stats))); + std::move(params.module_stats), std::move(thunk_proto))); } // Implementation note: HLO profiling is always enabled for GPU executables, @@ -268,7 +273,8 @@ GpuExecutable::GpuExecutable( std::unique_ptr alias_info, DebugOptions debug_options, std::vector constants, absl::flat_hash_map output_info, - bool enable_debug_info_manager, ModuleStats module_stats) + bool enable_debug_info_manager, ModuleStats module_stats, + absl::StatusOr thunk_proto) : Executable(std::move(debug_module)), text_(std::move(asm_text)), binary_(std::move(binary)), @@ -288,7 +294,8 @@ GpuExecutable::GpuExecutable( debug_options.xla_debug_buffer_assignment_show_max()), constants_(std::move(constants)), output_info_(std::move(output_info)), - enable_debug_info_manager_(enable_debug_info_manager) { + enable_debug_info_manager_(enable_debug_info_manager), + thunk_proto_(std::move(thunk_proto)) { if (gpu_version_.IsRocm()) { // ROCm uses hsaco hashes to distinguish between modules. // Bad things happen if multiple modules with identical code are loaded. @@ -1230,7 +1237,10 @@ absl::StatusOr GpuExecutable::ToProto() const { *proto.mutable_gpu_compute_capability() = gpu_version_.ToProto(); - TF_ASSIGN_OR_RETURN(*proto.mutable_thunk(), thunks_->ToProto()); + // TODO(b/461380690): Generate the proto on-the-fly once we have a better way + // to distinguish between compiler-generated and runtime-loaded GPU + // executables. 
+ TF_ASSIGN_OR_RETURN(*proto.mutable_thunk(), thunk_proto_); proto.set_module_name(module_name_); *proto.mutable_program_shape() = program_shape_.ToProto(); diff --git a/third_party/xla/xla/service/gpu/gpu_executable.h b/third_party/xla/xla/service/gpu/gpu_executable.h index ce1a5eff0bb591..867dbf2275fb4c 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.h +++ b/third_party/xla/xla/service/gpu/gpu_executable.h @@ -247,7 +247,8 @@ class GpuExecutable : public Executable { std::unique_ptr alias_info, DebugOptions debug_options, std::vector constants, absl::flat_hash_map output_info, - bool enable_debug_info_manager, ModuleStats module_stats); + bool enable_debug_info_manager, ModuleStats module_stats, + absl::StatusOr thunk_proto); // GpuExecutable check with either AMD's ISA version, or Nvidia's major minor // version for compute capability, depending on the hardware. @@ -369,6 +370,10 @@ class GpuExecutable : public Executable { GpuExecutable(const GpuExecutable&) = delete; GpuExecutable& operator=(const GpuExecutable&) = delete; + + // Stores the thunk graph as a proto from before running the thunk pass. + // Might contain an error if the given thunk graph is not serializable. 
+ absl::StatusOr thunk_proto_; }; absl::StatusOr> diff --git a/third_party/xla/xla/service/gpu/gpu_executable_test.cc b/third_party/xla/xla/service/gpu/gpu_executable_test.cc index 1d5d68823b970f..33483843b616cc 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable_test.cc @@ -649,5 +649,77 @@ TEST(GpuExecutableTest, FromProtoWithSymbolResolver) { EXPECT_EQ(symbol_resolver_invocations, 1); } +TEST(GpuExecutableTest, ToProtoReturnsUnchangedThunkGraph) { + DebugOptions debug_options; + debug_options.set_xla_gpu_graph_min_graph_size(1); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); + + auto create_executable = [&]() { + ThunkSequence thunk_sequence; + thunk_sequence.push_back(std::make_unique( + ThunkInfoWithId(1), + /*kernel_name=*/"test_kernel_0", + /*kernel_arguments=*/emitters::KernelArguments({}), + /*launch_dimensions=*/LaunchDimensions(), + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0, + /*tma_metadata=*/se::gpu::TmaMetadata())); + thunk_sequence.push_back(std::make_unique( + ThunkInfoWithId(2), + /*kernel_name=*/"test_kernel_1", + /*kernel_arguments=*/emitters::KernelArguments({}), + /*launch_dimensions=*/LaunchDimensions(), + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0, + /*tma_metadata=*/se::gpu::TmaMetadata())); + thunk_sequence.push_back(std::make_unique( + ThunkInfoWithId(3), + /*kernel_name=*/"test_kernel_2", + /*kernel_arguments=*/emitters::KernelArguments({}), + /*launch_dimensions=*/LaunchDimensions(), + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0, + /*tma_metadata=*/se::gpu::TmaMetadata())); + thunk_sequence.push_back(std::make_unique( + ThunkInfoWithId(4), + /*kernel_name=*/"test_kernel_3", + /*kernel_arguments=*/emitters::KernelArguments({}), + /*launch_dimensions=*/LaunchDimensions(), + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0, + /*tma_metadata=*/se::gpu::TmaMetadata())); + thunk_sequence.push_back(std::make_unique( + 
ThunkInfoWithId(5), + /*kernel_name=*/"test_kernel_4", + /*kernel_arguments=*/emitters::KernelArguments({}), + /*launch_dimensions=*/LaunchDimensions(), + /*cluster_dim=*/std::nullopt, + /*shmem_bytes=*/0, + /*tma_metadata=*/se::gpu::TmaMetadata())); + + GpuExecutable::Params params; + params.executable = std::make_unique( + ThunkInfoWithId(20), std::move(thunk_sequence)); + params.debug_options = debug_options; + + params.module_name = "test_module"; + return GpuExecutable::Create(std::move(params)); + }; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + create_executable()); + + // We expect our 5 kernel launches got wrapped in a command buffer thunk. + // If this assertion fails, you might need to either adjust the thunk graph or + // the debug options such that we do some kind of thunk graph transformation + // that we can test for. + ASSERT_THAT(executable->GetThunk().thunks(), SizeIs(1)); + + // The proto should be a straight dump of the thunk graph, without any + // transformation. + TF_ASSERT_OK_AND_ASSIGN(GpuExecutableProto proto, executable->ToProto()); + ASSERT_TRUE(proto.thunk().has_sequential_thunk()); + EXPECT_THAT(proto.thunk().sequential_thunk().thunks(), SizeIs(5)); +} + } // namespace } // namespace xla::gpu From 8ee821c813471aaea6a80f33a8bfd319aa90cee2 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 8 Dec 2025 03:19:12 -0800 Subject: [PATCH 017/753] [XLA:CPU] Use loop emitter rather than copy thunk for sub-byte types. Confirmed that the test fails before this change. 
PiperOrigin-RevId: 841683617 --- third_party/xla/xla/service/cpu/tests/BUILD | 16 ++++++ .../xla/service/cpu/tests/cpu_copy_test.cc | 54 +++++++++++++++++++ .../xla/xla/service/cpu/thunk_emitter.cc | 5 +- 3 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/cpu/tests/cpu_copy_test.cc diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index 3ea7a53b3b6206..dd27f5eb8a38e9 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -442,3 +442,19 @@ xla_cc_test( "@local_tsl//tsl/platform:platform_port", ], ) + +xla_cc_test( + name = "cpu_copy_test", + srcs = ["cpu_copy_test.cc"], + deps = [ + ":cpu_codegen_test_main", + "//xla:literal", + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) diff --git a/third_party/xla/xla/service/cpu/tests/cpu_copy_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_copy_test.cc new file mode 100644 index 00000000000000..20de31fc7be8fd --- /dev/null +++ b/third_party/xla/xla/service/cpu/tests/cpu_copy_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/service/cpu/tests/cpu_codegen_test.h" +#include "xla/tsl/platform/statusor.h" + +namespace xla::cpu { +namespace { + +TEST_F(CpuCodegenTest, SubByteCopy) { + const std::string hlo_text = R"hlo( +HloModule module + +ENTRY entry { + in = u2[20,20]{1,0:E(2)} iota(), iota_dimension=1 + transpose = u2[20,20]{0,1:E(2)} transpose(in), dimensions={1,0} + copy = u2[20,20]{1,0:E(2)} copy(transpose) + ROOT out = u8[20,20]{1,0} convert(copy) +} +)hlo"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + const Literal result, + Execute(std::move(module), {}, /*run_hlo_passes=*/false)); + + absl::Span result_data = result.data(); + for (int64_t row = 0; row < 20; ++row) { + for (int64_t col = 0; col < 20; ++col) { + EXPECT_EQ(result_data[row * 20 + col], row % 4); + } + } +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 3f9932fc78bc43..0506620c3bb7a6 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -486,7 +486,10 @@ absl::StatusOr ThunkEmitter::EmitHloInstruction( return EmitConvolutionThunk(instruction); case HloOpcode::kCopy: { - if (options_.compile_copy_as_llvm_kernel) { + // The copy thunk does not support sub-byte data types. 
+ bool has_byte_strides = + ShapeUtil::ByteStrides(instruction->shape()).has_value(); + if (!has_byte_strides || options_.compile_copy_as_llvm_kernel) { return EmitElementalKernelThunk(instruction); } return EmitCopyThunk(instruction); From 7c95198b02af894ca3177991c90ea42fe904ed02 Mon Sep 17 00:00:00 2001 From: Thomas Joerg Date: Mon, 8 Dec 2025 03:35:20 -0800 Subject: [PATCH 018/753] [XLA:GPU] Return early in `CalculateBitcastOfTransposeImpl` if indices are empty. This avoids hitting an assert later. This is a stop gap solution until support for size-1 dims in bitcasts has been added. PiperOrigin-RevId: 841687992 --- .../xla/xla/service/gpu/transforms/nest_gemm_fusion.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index 0a8633fd5e7b4d..3383d936a0c931 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -679,6 +679,12 @@ absl::StatusOr CalculateBitcastOfTransposeImpl( indices.push_back(index); }; + if (indices.empty()) { + return absl::InvalidArgumentError( + absl::StrCat("Cannot hoist bitcast across ", transpose->ToString(), + " because size-1 dims in bitcasts are not yet supported " + "(b/466065483).")); + } if (indices.back() - indices.front() >= transpose_to - transpose_from || !absl::c_is_sorted(indices)) { return absl::InvalidArgumentError( From 144a6cb45fcdcc39f743825ea3ebcbff2a24d7c5 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 8 Dec 2025 03:45:30 -0800 Subject: [PATCH 019/753] Make use of inserted_window_dims attribute of scatter. This makes the logic a bit easier. Right now, ScatterSimplifier will turn this into the original expanded update shape with 1-sized update window dimensions. But once our scatter emitter can handle it, we might avoid the reshape. 
PiperOrigin-RevId: 841690525 --- .../expanders/permutation_sort_expander.cc | 11 ++--------- .../expanders/permutation_sort_expander_test.cc | 13 ++++++++----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander.cc index 5af2a3e0056679..cfb796aef6c898 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander.cc @@ -127,13 +127,6 @@ absl::StatusOr PermutationSortExpander::ExpandInstruction( instruction->AddInstruction(HloInstruction::CreateBroadcast( update_shape, zero, /*broadcast_dimensions=*/{})); - // Construct the updates operand of scatter. - for (int64_t i = 0; i < rank; ++i) { - ShapeUtil::AppendMinorDimension(1, &update_shape); - } - HloInstruction* scatter_updates = instruction->AddInstruction( - HloInstruction::CreateReshape(update_shape, values)); - // Construct the updates computation, which simply replaces the operand // values with the update values. 
HloComputation::Builder b("update_replace_computation"); @@ -149,12 +142,12 @@ absl::StatusOr PermutationSortExpander::ExpandInstruction( ScatterDimensionNumbers dim_numbers; dim_numbers.set_index_vector_dim(rank); for (int64_t i = 0; i < rank; ++i) { - dim_numbers.add_update_window_dims(rank + i); + dim_numbers.add_inserted_window_dims(i); dim_numbers.add_scatter_dims_to_operand_dims(i); } HloInstruction* scatter = instruction->AddInstruction(HloInstruction::CreateScatter( - values->shape(), scatter_operand, scatter_indices, scatter_updates, + update_shape, scatter_operand, scatter_indices, values, update_replace_computation, dim_numbers, /*indices_are_sorted=*/false, /*unique_indices=*/true)); return instruction->AddInstruction(HloInstruction::CreateTuple( diff --git a/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander_test.cc index 0351221da62b0d..329df8c95e9bd6 100644 --- a/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/permutation_sort_expander_test.cc @@ -63,11 +63,14 @@ TEST_F(PermutationSortExpanderTest, ReplacePermutationSortWithScatter) { EXPECT_THAT(PermutationSortExpander().Run(module.get()), IsOkAndHolds(true)); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, - op::Tuple(op::Iota(), - op::Scatter(op::Broadcast(op::Constant()), - op::Concatenate(op::Iota(), op::Reshape()), - op::Reshape()))); + EXPECT_THAT( + root, op::Tuple(op::Iota(), + op::Scatter( + op::Broadcast(op::Constant()), + op::Concatenate(op::Iota(), + op::Reshape(op::GetTupleElement( + op::Sort(), /*tuple_index=*/1))), + op::Iota()))); } TEST_F(PermutationSortExpanderTest, DontReplaceIfWrongComparisonDirection) { From a693e6d76f42ee719cc623b60d6b2e9fb17c9f1e Mon Sep 17 00:00:00 2001 From: Shanbin Ke Date: Mon, 8 Dec 2025 04:05:54 -0800 Subject: [PATCH 020/753] PR #34789: 
[XLA:GPU] Fix cuDNN SDPA test to use 0 as workspace size to work universally on all archs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34789 📝 Summary of Changes use 0 as default workspace size and query later so it works universally on all archs, cuDNN paged attention reference doesn't do this like other cuDNN sdpa tests, it fails on B200 in NV internal CI. Therefore the fix. 🎯 Justification use 0 as default workspace size and query later so it works universally on all archs, cuDNN paged attention reference doesn't do this like other cuDNN sdpa tests, it fails on B200 in NV internal CI. Therefore the fix. 🚀 Kind of Contribution 🐛 Bug Fix 📊 Benchmark (for Performance Improvements) None 🧪 Unit Tests: None 🧪 Execution Tests: None Copybara import of the project: -- 7c53e935fcb424970da1ffed4c18a95e08835d57 by Cjkkkk : use 0 as workspace to work universally on all arch Merging this change closes #34789 PiperOrigin-RevId: 841696508 --- third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index 3e984387001e6f..8a0329b6a1fbf6 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1372,7 +1372,7 @@ class FlashAttentionPagedAttention : public MultiHeadedAttentionTest { ENTRY %main.7 (Arg_0.1: bf16[1,128,2,128], Arg_1.2: bf16[1,128,2,128]) -> bf16[1,128,2,128] { %Arg_1.2 = bf16[1,128,2,128]{3,2,1,0} parameter(1) %Arg_0.1 = bf16[1,128,2,128]{3,2,1,0} parameter(0) - %custom-call.3 = (bf16[1,2,128,128]{3,1,2,0}, u8[256]{0}) custom-call(%Arg_0.1, %Arg_1.2, %Arg_1.2), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[1,128,2,128]{3,2,1,0}, bf16[1,128,2,128]{3,2,1,0}, 
bf16[1,128,2,128]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "workspace_size": "0"}, "fmha_scale": 1.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["1", "2", "128", "128"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1, "is_paged_attention": false}} + %custom-call.3 = (bf16[1,2,128,128]{3,1,2,0}, u8[0]{0}) custom-call(%Arg_0.1, %Arg_1.2, %Arg_1.2), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[1,128,2,128]{3,2,1,0}, bf16[1,128,2,128]{3,2,1,0}, bf16[1,128,2,128]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "workspace_size": "0"}, "fmha_scale": 1.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["1", "2", "128", "128"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], 
"element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1, "is_paged_attention": false}} %get-tuple-element.4.0 = bf16[1,2,128,128]{3,1,2,0} get-tuple-element(%custom-call.3), index=0 ROOT %bitcast.6.0 = bf16[1,128,2,128]{3,2,1,0} bitcast(%get-tuple-element.4.0) } From f745573a52d8348648048d42817c66579536338c Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 8 Dec 2025 04:07:20 -0800 Subject: [PATCH 021/753] [XLA:CPU/GPU][XTile] Fix not instruction for non-pred types. 
PiperOrigin-RevId: 841696951 --- .../xla/xla/backends/gpu/codegen/triton/BUILD | 1 + .../gpu/codegen/triton/emitter_helpers.cc | 14 +++++++++- .../gpu/codegen/triton/emitter_helpers.h | 5 +--- .../triton/fusion_emitter_device_test.cc | 26 +++++++++++++++++++ 4 files changed, 41 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 19587ed10c4d4a..ce2212b64621e9 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -132,6 +132,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Support", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index c81bfdb35696c5..f73c72bcf7873a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" @@ -156,6 +157,17 @@ absl::StatusOr EmitNestedFusion( return EmitScope(b, to_emit, region_values); } + +// Get a constant with all high bits of the same type as provided. 
+mlir::Value OnesLike(mlir::ImplicitLocOpBuilder& b, mlir::Type type) { + mlir::Type element_type = mlir::getElementTypeOrSelf(type); + CHECK(element_type.isInteger()) << "OnesLike only supports integer types."; + + int64_t width = element_type.getIntOrFloatBitWidth(); + mlir::APInt all_ones = mlir::APInt::getAllOnes(width); + return mlir::createScalarOrSplatConstant(b, b.getLoc(), type, all_ones); +} + } // namespace SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { @@ -425,7 +437,7 @@ absl::StatusOr EmitElementwise(mlir::ImplicitLocOpBuilder& b, case HloOpcode::kFloor: return mm::FloorOp::create(b, inputs[0]); case HloOpcode::kNot: - return ma::XOrIOp::create(b, inputs[0], OnesLike(b, inputs[0])); + return ma::XOrIOp::create(b, inputs[0], OnesLike(b, inputs[0].getType())); case HloOpcode::kNegate: // NegFOp is not supported by Triton. return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h index 89ed1ef978bb52..5d1dfee338123f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.h @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" @@ -197,10 +198,6 @@ inline mlir::Value ZerosLike(mlir::ImplicitLocOpBuilder& b, mlir::Value x) { return ConstLike(b, x, 0); } -inline mlir::Value OnesLike(mlir::ImplicitLocOpBuilder& b, mlir::Value x) { - return ConstLike(b, x, 1); -} - bool IsFp8Type(mlir::Type t); // Triton type conversions. 
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index 07605fb6d34987..99112cd3bf51b3 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -420,6 +420,32 @@ CHECK: arith.divsi {{.*}} : i32 EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); } +TEST_F(TritonEmitterTest, BitwiseNotIsEmittedCorrectly) { + constexpr absl::string_view kHloText = R"( +HloModule m + +fused_not { + param_0 = s32[100] parameter(0) + ROOT not = s32[100] not(param_0) +} + +ENTRY main { + p0 = s32[100] parameter(0) + ROOT not = s32[100] fusion(p0), kind=kCustom, calls=fused_not, + backend_config={"fusion_backend_config":{ + "kind":"__triton", + "block_level_fusion_config":{ + "num_warps":"1","output_tiles":[{"sizes":[100]}], + "num_ctas":1,"num_stages":1,"is_tma_allowed":false}}} +} +)"; + TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "fused_not", R"( +CHECK: arith.constant dense<-1> +CHECK: arith.xori +)")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, kExactMatch)); +} + TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { constexpr absl::string_view kHloText = R"( HloModule m From de1e81c3c327409d551af8973ff5eaadf4acf9a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eusebio=20Dur=C3=A1n=20Monta=C3=B1a?= Date: Mon, 8 Dec 2025 04:08:12 -0800 Subject: [PATCH 022/753] Add missing BUILD dependencies, and remove unused ones PiperOrigin-RevId: 841697193 --- third_party/xla/xla/backends/cpu/BUILD | 1 - third_party/xla/xla/backends/cpu/autotuner/BUILD | 3 --- .../backends/cpu/codegen/emitters/transforms/BUILD | 2 -- third_party/xla/xla/backends/cpu/runtime/BUILD | 2 -- third_party/xla/xla/backends/gpu/codegen/llvm/BUILD | 1 - third_party/xla/xla/backends/gpu/collectives/BUILD | 3 --- 
third_party/xla/xla/backends/gpu/runtime/BUILD | 9 --------- .../xla/xla/backends/profiler/subprocess/BUILD | 4 ---- .../xla/xla/codegen/emitters/transforms/BUILD | 2 -- third_party/xla/xla/codegen/xtile/ir/BUILD | 1 - third_party/xla/xla/pjrt/distributed/BUILD | 13 ------------- third_party/xla/xla/service/BUILD | 1 - third_party/xla/xla/service/gpu/BUILD | 1 - third_party/xla/xla/stream_executor/BUILD | 3 --- third_party/xla/xla/stream_executor/cuda/BUILD | 1 - third_party/xla/xla/stream_executor/gpu/BUILD | 1 - third_party/xla/xla/tsl/framework/BUILD | 1 - third_party/xla/xla/tsl/profiler/rpc/BUILD | 6 ------ third_party/xla/xla/util/BUILD | 1 - 19 files changed, 56 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/BUILD b/third_party/xla/xla/backends/cpu/BUILD index 7d9e7ad114aba0..05df8d4e5fd66d 100644 --- a/third_party/xla/xla/backends/cpu/BUILD +++ b/third_party/xla/xla/backends/cpu/BUILD @@ -67,7 +67,6 @@ cc_library( "//xla/ffi", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", - "@com_google_absl//absl/base:core_headers", ], ) diff --git a/third_party/xla/xla/backends/cpu/autotuner/BUILD b/third_party/xla/xla/backends/cpu/autotuner/BUILD index e5bc7335a1a082..16640e22a3a8f5 100644 --- a/third_party/xla/xla/backends/cpu/autotuner/BUILD +++ b/third_party/xla/xla/backends/cpu/autotuner/BUILD @@ -24,13 +24,10 @@ cc_library( "//xla/service:compiler", "//xla/service:executable", "//xla/stream_executor:platform_manager", - "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/platform:errors", - "//xla/tsl/platform:status", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], diff --git a/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD b/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD index 1f74e57abe45a5..123e0394dcd297 100644 --- 
a/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/emitters/transforms/BUILD @@ -61,9 +61,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:UBDialect", "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:VectorUtils", ], ) diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 8d68edd2d6e12e..e1e85a5cd4675a 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -986,8 +986,6 @@ xla_cc_test( "//xla:literal_util", "//xla/runtime:work_group", "//xla/service:buffer_assignment", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test", diff --git a/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD b/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD index 43cd45f395a9cc..4f7e246607e1c0 100644 --- a/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/llvm/BUILD @@ -86,7 +86,6 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/tsl/platform:errors", - "//xla/tsl/platform:status", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/backends/gpu/collectives/BUILD b/third_party/xla/xla/backends/gpu/collectives/BUILD index 975166a9b64b34..6859a56e88c0fa 100644 --- a/third_party/xla/xla/backends/gpu/collectives/BUILD +++ b/third_party/xla/xla/backends/gpu/collectives/BUILD @@ -150,7 +150,6 @@ cc_library( "//xla/service:rendezvous", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/log", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", 
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -166,11 +165,9 @@ xla_cc_test( ":gpu_clique_rendezvous", "//xla/core/collectives:rank_id", "//xla/runtime:device_id", - "//xla/service:rendezvous", "//xla/tsl/platform:env", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", - "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", ], ) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 192c99de9902a6..2c79ce55668f95 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -87,7 +87,6 @@ cc_library( "//xla:types", "//xla:util", "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/ffi:attribute_map", "//xla/ffi:call_frame", @@ -1723,7 +1722,6 @@ cc_library( "//xla:util", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:communicator", - "//xla/core/collectives:rank_id", "//xla/runtime:device_id", "//xla/service:collective_ops_utils", "//xla/tsl/platform:statusor", @@ -1739,7 +1737,6 @@ cc_library( srcs = ["collective_thunk.cc"], hdrs = ["collective_thunk.h"], deps = [ - ":collective_cliques", ":collective_execution", ":collective_params", ":thunk", @@ -1751,14 +1748,12 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/gpu/collectives:gpu_clique_key", - "//xla/backends/gpu/collectives:gpu_collectives", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/hlo/ir:collective_op_group_mode", "//xla/hlo/ir:hlo", "//xla/runtime:device_id", "//xla/service:buffer_assignment", - "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:rendezvous", "//xla/service/gpu:buffer_allocations", @@ -2025,8 +2020,6 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", 
"//xla/stream_executor/gpu:collective_kernel_metadata", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:multicast_memory", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", @@ -2151,7 +2144,6 @@ cc_library( "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", "//xla/runtime:buffer_use", - "//xla/runtime:device_id", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service/gpu:backend_configs_cc", @@ -3608,7 +3600,6 @@ cc_library( "//xla/stream_executor:device_address", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_args", - "//xla/stream_executor:kernel_argument_packing_spec", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:statusor", diff --git a/third_party/xla/xla/backends/profiler/subprocess/BUILD b/third_party/xla/xla/backends/profiler/subprocess/BUILD index 1a62071beafea5..4ca882ae468219 100644 --- a/third_party/xla/xla/backends/profiler/subprocess/BUILD +++ b/third_party/xla/xla/backends/profiler/subprocess/BUILD @@ -59,7 +59,6 @@ cc_library( "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -87,7 +86,6 @@ xla_cc_test( ":subprocess_registry", "//xla/backends/profiler:profiler_backends", # buildcleaner: keep "//xla/tsl/platform:env", - "//xla/tsl/platform:resource_loader", "//xla/tsl/platform:subprocess", "//xla/tsl/platform:test", "//xla/tsl/platform:test_main", @@ -100,8 +98,6 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/lib:traceme", diff --git 
a/third_party/xla/xla/codegen/emitters/transforms/BUILD b/third_party/xla/xla/codegen/emitters/transforms/BUILD index 429feecc9eb982..4f910f30935a01 100644 --- a/third_party/xla/xla/codegen/emitters/transforms/BUILD +++ b/third_party/xla/xla/codegen/emitters/transforms/BUILD @@ -82,7 +82,6 @@ cc_library( "//xla/codegen/intrinsic:tanh", "//xla/codegen/intrinsic:type", "//xla/hlo/analysis:indexing_analysis", - "//xla/hlo/analysis:symbolic_expr", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", "//xla/service/gpu:ir_emission_utils", @@ -91,7 +90,6 @@ cc_library( "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/stream_executor/rocm:rocm_compute_capability", "//xla/tsl/platform:logging", - "//xla/tsl/platform:status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/codegen/xtile/ir/BUILD b/third_party/xla/xla/codegen/xtile/ir/BUILD index de986615da6c0e..58f8c97439da40 100644 --- a/third_party/xla/xla/codegen/xtile/ir/BUILD +++ b/third_party/xla/xla/codegen/xtile/ir/BUILD @@ -105,7 +105,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:InliningUtils", - "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 710dfc9f0208a5..4f5e5356f6a343 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -20,8 +20,6 @@ cc_library( srcs = ["service.cc"], hdrs = ["service.h"], deps = [ - ":topology_util", - ":util", "//xla:types", "//xla:util", "//xla/tsl/distributed_runtime/coordination:coordination_service", @@ -31,14 +29,11 @@ cc_library( "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/log", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:random", ], ) @@ -70,7 +65,6 @@ cc_library( ], deps = [ ":key_value_store_interface", - ":util", "//xla/runtime:device_id", "//xla/tsl/distributed_runtime/coordination:coordination_client", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", @@ -87,8 +81,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:statusor", ], ) @@ -153,7 +145,6 @@ xla_cc_test( ":topology_util", "//xla:status_macros", "//xla/runtime:device_id", - "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", @@ -170,10 +161,6 @@ xla_cc_test( "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index fb4b60aedea3f1..ddf008d75b3c80 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3973,7 +3973,6 @@ cc_library( "//xla:literal", "//xla:shape_util", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings:string_view", ], ) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 038abc6f3c5dd6..98e1675e56ab9a 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3265,7 +3265,6 @@ cc_library( "//xla/hlo/ir:hlo", 
"//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", - "//xla/service:collective_permute_decomposer", "//xla/service:hlo_cost_analysis", "//xla/service:latency_hiding_scheduler", "//xla/service:profile_guided_latency_estimator", diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index f00f783257bce6..195ec82c4d2777 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -77,7 +77,6 @@ cc_library( name = "device_address", hdrs = ["device_address.h"], deps = [ - "//xla/tsl/platform:logging", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", ], @@ -716,8 +715,6 @@ cc_library( name = "kernel", hdrs = ["kernel.h"], deps = [ - ":device_address", - ":device_memory", ":kernel_args", ":kernel_metadata", ":launch_dim", diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index f173369799956f..ba0403fbed0832 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -1035,7 +1035,6 @@ xla_cc_test( "notsan", ], deps = [ - ":compilation_provider", ":cuda_compute_capability", ":nvjitlink", ":nvjitlink_support", diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index ddabc85bce9692..52e926f2befb72 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -171,7 +171,6 @@ cc_library( hdrs = ["gpu_executor.h"], deps = [ ":multicast_memory", - "//xla/stream_executor:device_address", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_common", "//xla/stream_executor:stream_executor_h", diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index 6983cd1250f1ba..2c058abf55f2dd 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ 
b/third_party/xla/xla/tsl/framework/BUILD @@ -413,7 +413,6 @@ tsl_cc_test( ":cancellation", "//xla/tsl/platform:env", "//xla/tsl/platform:env_impl", # buildcleaner: keep - "//xla/tsl/platform:status", "//xla/tsl/platform:test", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/tsl/profiler/rpc/BUILD b/third_party/xla/xla/tsl/profiler/rpc/BUILD index 523db019d51d6f..fa081c98557a7e 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/BUILD @@ -35,21 +35,15 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", "//xla/tsl/platform:macros", - "//xla/tsl/platform:status", - "//xla/tsl/platform:types", "//xla/tsl/profiler/rpc/client:save_profile", - "//xla/tsl/profiler/utils:file_system_utils", "//xla/tsl/profiler/utils:math_utils", "//xla/tsl/profiler/utils:profiler_options_util", "//xla/tsl/profiler/utils:time_utils", "//xla/tsl/profiler/utils:xplane_utils", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/profiler/lib:profiler_session", "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", diff --git a/third_party/xla/xla/util/BUILD b/third_party/xla/xla/util/BUILD index 3e503bd2dded2e..a83c965f8c16af 100644 --- a/third_party/xla/xla/util/BUILD +++ b/third_party/xla/xla/util/BUILD @@ -28,7 +28,6 @@ cc_library( deps = [ "//xla:util", "//xla:xla_data_proto_cc", - "//xla/tsl/platform:statusor", "@com_google_absl//absl/status:statusor", "@dlpack", ], From 0dc10cd4e296a41f229a8b9196e127c7576ed30f Mon Sep 17 00:00:00 2001 From: Chenhao Jiang Date: Mon, 8 Dec 2025 04:10:46 -0800 Subject: [PATCH 023/753] PR #34917: Turn on the scatter 
determinism expander by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34917 📝 Summary of Changes Enable xla_gpu_enable_scatter_determinism_expander flag by default (change from false to true). 🎯 Justification The scatter determinism expander provides significant performance improvements for deterministic scatter operations (up to 9000x speedup for certain input sizes compared to the sequential while-loop approach). With recent fixes for batched scatter support and proper handling of scatter_dims_to_operand_dims, the pass is now robust enough to be enabled by default. Users who experience issues can still disable it with --xla_gpu_enable_scatter_determinism_expander=false. 🚀 Kind of Contribution ⚡️ Performance Improvement 🧪 Unit Tests All existing scatter tests pass with the flag enabled by default: //xla/service:scatter_determinism_expander_test //xla/tests:scatter_test Copybara import of the project: -- 0bb296398991b3de5a4d15f45fd4e80f52880852 by Chenhao Jiang : Turn on the scatter determinism expander by default Merging this change closes #34917 PiperOrigin-RevId: 841697954 --- third_party/xla/xla/debug_options_flags.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index c435337dda368a..fe8c14f18dbd8c 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -441,7 +441,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_fast_math(false); opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1); opts.set_xla_pjrt_allow_auto_layout_in_hlo(false); - opts.set_xla_gpu_enable_scatter_determinism_expander(false); + opts.set_xla_gpu_enable_scatter_determinism_expander(true); opts.set_xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(false); 
opts.set_xla_gpu_unsupported_use_all_reduce_one_shot_kernel(false); opts.set_xla_gpu_unsupported_use_ragged_all_to_all_one_shot_kernel(true); From bb8f2561f9e137b93a405525b888687971748611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eetu=20Sj=C3=B6blom?= Date: Mon, 8 Dec 2025 04:12:13 -0800 Subject: [PATCH 024/753] PR #34956: [ROCm] flush rocprofiler buffer when disabling RocmTracer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34956 Fixes a bug in RocmTracer where events not reaching the rocprofiler watermark are not captured when disabling the tracer. Happens e.g. when the workload is very small. Added an explicit buffer flush in and a relevant test that fails without the flush: `//xla/backends/profiler/gpu:rocm_tracer_test` 🚀 Kind of Contribution 🐛 Bug Fix Copybara import of the project: -- 7d27ae5615c5dd1ba244e6b55b16200ff7f45d2c by Eetu Sjöblom : flush rocprofiler buffer when disabling RocmTracer Merging this change closes #34956 PiperOrigin-RevId: 841698424 --- .../xla/xla/backends/profiler/gpu/BUILD | 1 + .../xla/backends/profiler/gpu/rocm_tracer.cc | 4 ++ .../backends/profiler/gpu/rocm_tracer_test.cc | 69 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index fefd3b9b992862..1a559898f65e3e 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -533,6 +533,7 @@ xla_cc_test( "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@local_config_rocm//rocm:hip", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc index 40f0e0e96cfbe9..a15f2e4bb690d1 100644 --- 
a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc @@ -499,6 +499,10 @@ void RocmTracer::toolFinalize(void* tool_data) { } void RocmTracer::Disable() { + rocprofiler_status_t status = rocprofiler_flush_buffer(buffer_); + if (status != ROCPROFILER_STATUS_SUCCESS) { + LOG(WARNING) << "rocprofiler_flush_buffer failed with error " << status; + } absl::MutexLock lock(collector_mutex_); collector_->Flush(); collector_ = nullptr; diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer_test.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer_test.cc index d8ad1392738d20..d03bb15dc80527 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer_test.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include #include +#include #include #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "rocm/include/hip/hip_runtime.h" #include "xla/backends/profiler/gpu/rocm_collector.h" #include "xla/backends/profiler/gpu/rocm_tracer_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" @@ -124,6 +126,73 @@ TEST(RocmTracerTest, AnnotationMapWorks) { EXPECT_EQ(result, annotation); } +// Simple collector that tracks received events for verification. 
+class EventCapturingCollector : public RocmTraceCollector { + public: + EventCapturingCollector() : RocmTraceCollector(MakeCollectorOptions()) {} + + void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) override { + event_count_++; + } + + void OnEventsDropped(const std::string& reason, + uint32_t num_events) override {} + void Flush() override {} + void Export(tsl::profiler::XSpace* space) override {} + + int event_count() const { return event_count_; } + + private: + static RocmTraceCollectorOptions MakeCollectorOptions() { + RocmTraceCollectorOptions options; + options.max_callback_api_events = 2 * 1024 * 1024; + options.max_activity_api_events = 2 * 1024 * 1024; + options.max_annotation_strings = 1024 * 1024; + options.num_gpus = RocmTracer::GetRocmTracerSingleton().NumGpus(); + return options; + } + int event_count_ = 0; +}; + +std::unique_ptr CreateEventCapturingCollector() { + return std::make_unique(); +} + +TEST(RocmTracerTest, CapturesHipEvents) { +#define HIP_ASSERT_OK(expr) ASSERT_EQ((expr), hipSuccess) << #expr " failed" + + int device_count = 0; + HIP_ASSERT_OK(hipGetDeviceCount(&device_count)); + ASSERT_GT(device_count, 0) << "No HIP devices available"; + + auto collector = CreateEventCapturingCollector(); + EventCapturingCollector* collector_ptr = collector.get(); + + RocmTracer& tracer = RocmTracer::GetRocmTracerSingleton(); + RocmTracerOptions tracer_options{/*max_annotation_strings=*/1024 * 1024}; + tracer.Enable(tracer_options, collector.get()); + + constexpr size_t kNumFloats = 1024; + constexpr size_t kSize = kNumFloats * sizeof(float); + std::vector host_data(kNumFloats, 1.0f); + void* device_data = nullptr; + + HIP_ASSERT_OK(hipMalloc(&device_data, kSize)); + HIP_ASSERT_OK( + hipMemcpy(device_data, host_data.data(), kSize, hipMemcpyHostToDevice)); + HIP_ASSERT_OK( + hipMemcpy(host_data.data(), device_data, kSize, hipMemcpyDeviceToHost)); + HIP_ASSERT_OK(hipDeviceSynchronize()); + + tracer.Disable(); + hipFree(device_data); + +#undef 
HIP_ASSERT_OK + + EXPECT_GT(collector_ptr->event_count(), 0) + << "Expected to capture at least one trace event"; +} + } // namespace } // namespace profiler } // namespace xla From c768540f431880e29c481a2609b31c2a70c69b1e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 06:28:49 -0800 Subject: [PATCH 025/753] Automated Code Change PiperOrigin-RevId: 841737217 --- third_party/xla/xla/hlo/pass/BUILD | 2 ++ third_party/xla/xla/hlo/pass/hlo_pass_interface.cc | 2 ++ third_party/xla/xla/hlo/pass/hlo_pass_interface.h | 1 + 3 files changed, 5 insertions(+) diff --git a/third_party/xla/xla/hlo/pass/BUILD b/third_party/xla/xla/hlo/pass/BUILD index 5a64c36e4596ad..4dff3438a1cd05 100644 --- a/third_party/xla/xla/hlo/pass/BUILD +++ b/third_party/xla/xla/hlo/pass/BUILD @@ -38,6 +38,7 @@ cc_library( "//xla/hlo/ir:hlo_module_group", "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", @@ -118,6 +119,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_interface.cc b/third_party/xla/xla/hlo/pass/hlo_pass_interface.cc index bec1de8aaaa219..be0eb44c285037 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_interface.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_interface.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "xla/hlo/pass/hlo_pass_interface.h" +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_interface.h b/third_party/xla/xla/hlo/pass/hlo_pass_interface.h index cfbe9723201e1f..fb43ac39280e8a 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_interface.h +++ b/third_party/xla/xla/hlo/pass/hlo_pass_interface.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/base/attributes.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" From 181d057de8d1327e88f3531a99bf08fd0e0f6373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eusebio=20Dur=C3=A1n=20Monta=C3=B1a?= Date: Mon, 8 Dec 2025 06:30:50 -0800 Subject: [PATCH 026/753] Set up internal presubmit for unused/extra dependencies. PiperOrigin-RevId: 841737824 --- third_party/xla/xla/BUILD | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 151c46c8df3408..d07d329e7c2dca 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1,6 +1,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("//third_party/compute_library:build_defs.bzl", "if_enable_acl") +# copybara:uncomment load("//devtools/build_cleaner/skylark:action_config_test.bzl", "action_config_test") # copybara:uncomment load("@rules_python//python:proto.bzl", "py_proto_library") load("//xla:package_groups.bzl", "xla_package_groups") load("//xla:xla.default.bzl", "xla_bzl_library", "xla_cc_test", "xla_py_proto_library") @@ -1286,7 +1287,6 @@ xla_cc_test( "//xla/tsl/platform:env", "//xla/tsl/platform:test", "//xla/tsl/util:command_line_flags", - "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", @@ -1408,6 +1408,11 @@ cc_library( # 
visibility = internal_visibility([":friends"]), # deps = [":xla_proto"], # ) +# +# action_config_test( +# name = "build_cleaner_spec_test", +# src = "build_cleaner_spec.textproto", +# ) # copybara:uncomment_end cc_library( From 9981c805dcf2114836c8ea5eae4dc589ff4855fc Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 8 Dec 2025 06:36:09 -0800 Subject: [PATCH 027/753] [XLA:GPU] Improve error messages when we fail to tile a fusion. PiperOrigin-RevId: 841739377 --- .../xla/xla/codegen/tiling/symbolic_tile_analysis.cc | 4 ++-- .../xla/xla/service/gpu/transforms/nest_gemm_fusion.cc | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/third_party/xla/xla/codegen/tiling/symbolic_tile_analysis.cc b/third_party/xla/xla/codegen/tiling/symbolic_tile_analysis.cc index af166ae57df5cf..0ac831f03114a3 100644 --- a/third_party/xla/xla/codegen/tiling/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/codegen/tiling/symbolic_tile_analysis.cc @@ -459,8 +459,8 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( SymbolicTile::FromIndexingMap(reshape_indexing_map); if (!reshape_symbolic_tile.has_value()) { - return FusionDecision::Forbid("Bailing out on reshape ") - << hlo->ToString() << " with indexing map " + return FusionDecision::Forbid("Bailing out on reshape") + << " " << hlo->ToString() << " with indexing map " << ToString(reshape_indexing_map); } } diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index 3383d936a0c931..dc972406913fc3 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -1248,13 +1248,13 @@ absl::StatusOr FindBlockLevelParameters( SymbolicTileAnalysis::AnalyzeComputation( *computation, ctx, TritonEmitterConstraints::GetBuilder(device_description)); - if (std::holds_alternative(analysis_or)) { + + if (const auto* fusion_decision = 
std::get_if(&analysis_or)) { std::unique_ptr extracted_computation_module = ExtractInstructionIntoNewModule(*computation->FusionInstruction()); - return absl::InternalError( - absl::StrCat("Failed to analyze the computation (", - std::get(analysis_or).Explain(), - "): ", extracted_computation_module->ToString())); + return absl::InternalError(absl::StrCat( + "Failed to analyze the computation (", fusion_decision->Explain(), + "):\n", extracted_computation_module->ToString())); } auto& analysis = std::get(analysis_or); From 9baba425e7bfdd4b20ff35a8526abdf9488fdbba Mon Sep 17 00:00:00 2001 From: Michael Kuperstein Date: Mon, 8 Dec 2025 07:14:46 -0800 Subject: [PATCH 028/753] [XLA] Print backend config in HloPrintOptions::ShortParsable() The backend config carries semantic information. While "ShortParsable" is intended to be compact, it should be semantically equivalent to the default print style. PiperOrigin-RevId: 841750733 --- third_party/xla/xla/hlo/ir/hlo_print_options.h | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_print_options.h b/third_party/xla/xla/hlo/ir/hlo_print_options.h index cd4a72bb176dd8..87eca03d396113 100644 --- a/third_party/xla/xla/hlo/ir/hlo_print_options.h +++ b/third_party/xla/xla/hlo/ir/hlo_print_options.h @@ -96,7 +96,6 @@ class HloPrintOptions { .set_print_large_constants(true) .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) .set_print_metadata(false) - .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_operand_index_annotation_interval(0) .set_print_program_shape(false) From 0df3891980edf719f6ff2b3b6f806db84aa3812f Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Mon, 8 Dec 2025 10:53:08 -0500 Subject: [PATCH 029/753] Update default `dtype` description in `ragged_factory_ops.py` Clarified default `dtype` for `RaggedTensor` when `pylist` is empty. Fixes https://github.com/tensorflow/tensorflow/issues/105858. 
--- tensorflow/python/ops/ragged/ragged_factory_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 55505df533d447..a21d85eca16fb5 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -61,7 +61,8 @@ def constant( compatible with `dtype`. dtype: The type of elements for the returned `RaggedTensor`. If not specified, then a default is chosen based on the scalar values in - `pylist`. + `pylist`. If there are no scalar values in `pylist`, then the default + is `tf.float32`. ragged_rank: An integer specifying the ragged rank of the returned `RaggedTensor`. Must be nonnegative and less than `K`. Defaults to `max(0, K - 1)` if `inner_shape` is not specified. Defaults to From cc26e9a554f13dc82985f52ea5f41320389884ca Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Mon, 8 Dec 2025 07:55:34 -0800 Subject: [PATCH 030/753] [XLA:GPU] Keep explanation and location separate in FusionDecision. Current logic didn't work nicely for streaming operator (<<), because it was concatenating a new location at every call and resulted in unreadable error message. This change also adds a SourceLocationHolder to limit `#if defined(PLATFORM_GOOGLE)` usage. 
PiperOrigin-RevId: 841764270 --- third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/instruction_fusion.cc | 12 --- .../xla/xla/service/instruction_fusion.h | 90 +++++++++---------- 3 files changed, 44 insertions(+), 59 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index ddf008d75b3c80..b8b6a62ec06674 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1977,6 +1977,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "//xla/tsl/platform:macros", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index 1d9b568f88f9f0..f4e84b7e585969 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -56,18 +56,6 @@ limitations under the License. #include "xla/util.h" namespace xla { - -#if defined(PLATFORM_GOOGLE) -FusionDecision::FusionDecision(bool decision, - absl::SourceLocation source_location) { - if (!decision) { - explanation_ = - absl::StrCat("Not fusing: due to ", source_location.file_name(), ":", - source_location.line()); - } -} -#endif // PLATFORM_GOOGLE - namespace { // These nodes can always be duplicated into consumers, even if diff --git a/third_party/xla/xla/service/instruction_fusion.h b/third_party/xla/xla/service/instruction_fusion.h index 85ff4dde04035c..d5ad1b7c17e1a6 100644 --- a/third_party/xla/xla/service/instruction_fusion.h +++ b/third_party/xla/xla/service/instruction_fusion.h @@ -21,18 +21,14 @@ limitations under the License. 
#include #include #include -#include #include -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/service/hlo_module_config.h" -#include "tsl/platform/macros.h" // The source_location.h is not available in open source. #if defined(PLATFORM_GOOGLE) #include "absl/types/source_location.h" @@ -54,6 +50,29 @@ struct InPlaceFusionOptions { bool relax_multiple_non_elementwise_ops = false; }; +// A holder for the source location. absl::SourceLocation is not available in +// open source, so we have a stub implementation to limit +// #if define(PLATFORM_GOOGLE). +class SourceLocationHolder { + public: +#if defined(PLATFORM_GOOGLE) + explicit constexpr SourceLocationHolder( + absl::SourceLocation source_location = absl::SourceLocation::current()) + : source_location_(source_location) {} + + std::string ToString() const { + return absl::StrCat(" at: ", source_location_.file_name(), ":", + source_location_.line()); + } + + private: + absl::SourceLocation source_location_; +#else + SourceLocationHolder() = default; + std::string ToString() const { return ""; } +#endif // PLATFORM_GOOGLE +}; + // Propagating explanation of fusion decisions: if something could not be fused, // explain the reason. 
class FusionDecision { @@ -61,34 +80,29 @@ class FusionDecision { static FusionDecision Allow() { return FusionDecision(); } FusionDecision(const FusionDecision& decision) = default; -#if defined(PLATFORM_GOOGLE) - static std::string LocToString(absl::SourceLocation source_location) { - return absl::StrCat(" at: ", source_location.file_name(), ":", - source_location.line()); - } static FusionDecision Forbid( absl::string_view explanation, - absl::SourceLocation source_location = absl::SourceLocation::current()) { - return FusionDecision( - absl::StrCat(explanation, LocToString(source_location))); + SourceLocationHolder source_location = SourceLocationHolder()) { + return FusionDecision(false, explanation, source_location); } // If condition is `true` means that we CAN fuse. In that case, explanation is // discarded. FusionDecision( bool condition, absl::string_view explanation, - absl::SourceLocation source_location = absl::SourceLocation::current()) { + SourceLocationHolder source_location = SourceLocationHolder()) { if (!condition) { - explanation_ = absl::StrCat(explanation, LocToString(source_location)); + explanation_ = explanation; + source_location_ = source_location; } } explicit FusionDecision( absl::Status status, - absl::SourceLocation source_location = absl::SourceLocation::current()) { + SourceLocationHolder source_location = SourceLocationHolder()) { if (!status.ok()) { - explanation_ = - absl::StrCat(status.message(), LocToString(source_location)); + explanation_ = status.message(); + source_location_ = source_location; } } @@ -97,25 +111,8 @@ class FusionDecision { // provide explicit explanation. FusionDecision( // NOLINT bool decision, - absl::SourceLocation source_location = absl::SourceLocation::current()); -#else - // If condition is `true` means that we CAN fuse. In that case, explanation is - // discarded. 
- FusionDecision(bool condition, absl::string_view explanation) { - if (!condition) { - explanation_ = std::string(explanation); - } - } - static FusionDecision Forbid(absl::string_view explanation) { - return FusionDecision(explanation); - } - explicit FusionDecision(absl::Status status) { - if (!status.ok()) { - explanation_ = status.message(); - } - } - -#endif // PLATFORM_GOOGLE + SourceLocationHolder source_location = SourceLocationHolder()) + : FusionDecision(decision, "Not fusing", source_location) {} // Returns whether it can be fused. explicit operator bool() const { return CanFuse(); } @@ -130,8 +127,7 @@ class FusionDecision { if (CanFuse() || decision.CanFuse()) { return Allow(); } - return Forbid( - absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())); + return Forbid(absl::StrCat(Explain(), " ; ", decision.Explain())); } // Connects two fusion decision with a conjunction. Unlike disjunction, @@ -150,30 +146,30 @@ class FusionDecision { // Appends to explanation, or turns the decision negative. FusionDecision operator<<(absl::string_view explanation) const { - return Forbid(absl::StrCat(explanation_.value_or(""), explanation)); + return Forbid(absl::StrCat(explanation_.value_or(""), explanation), + source_location_); } // Appends to explanation, or turns the decision negative. FusionDecision operator<<(int64_t explanation) const { - return Forbid(absl::StrCat(explanation_.value_or(""), explanation)); + return Forbid(absl::StrCat(explanation_.value_or(""), explanation), + source_location_); } // Explains why the fusion could not be performed, or that it can be. std::string Explain() const { - return explanation_.value_or("Actually, we can fuse it."); + if (explanation_.has_value()) { + return absl::StrCat(explanation_.value(), source_location_.ToString()); + } + return "Actually, we can fuse it."; } private: // Empty IFF fusion is possible (explanation provided for negative cases). 
std::optional explanation_; + SourceLocationHolder source_location_; FusionDecision() = default; - - explicit FusionDecision(absl::string_view explanation) - : explanation_(explanation) {} - - explicit FusionDecision(const char* explanation) - : explanation_(explanation) {} }; #define RETURN_IF_NOT_FUSIBLE(...) \ From 83a265ddfe3edcf5fc8d5f84900543d7fca4c2e4 Mon Sep 17 00:00:00 2001 From: Yurii Topin Date: Mon, 8 Dec 2025 08:35:06 -0800 Subject: [PATCH 031/753] Reverts 31228b49f1c3af6f784556a1845782a3969358d6 PiperOrigin-RevId: 841777859 --- third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD b/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD index 83cca313adf4f5..3a079c87ab9dd6 100644 --- a/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD +++ b/third_party/xla/third_party/mkl_dnn/mkldnn_acl.BUILD @@ -156,13 +156,5 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@compute_library//:arm_compute", - ] + select({ - # When using MKL-DNN on the AArch64 architecture, OpenMP is required - # for parallelization. Because the Hermetic C++ build environment uses - # the -nodefaultlibs flag, simply passing -fopenmp is insufficient. - # OpenMP's dependencies must be explicitly linked to ensure correct - # inclusion, as automatic linking is disabled. 
- "@rules_ml_toolchain//common:is_hermetic_cc_enabled": ["@rules_ml_toolchain//cc/sysroots:openmp"], - "//conditions:default": [], - }), + ], ) From 79cd71f875f583563fcbfeb04ad0660ed5f997fa Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 8 Dec 2025 08:51:28 -0800 Subject: [PATCH 032/753] [xla:gpu] Switch CollectiveMetadataThunk to GpuCliqueRendezvous PiperOrigin-RevId: 841783861 --- .../gpu/collectives/gpu_clique_rendezvous.cc | 4 +- .../gpu/collectives/gpu_clique_rendezvous.h | 16 +-- .../collectives/gpu_clique_rendezvous_test.cc | 4 +- .../xla/xla/backends/gpu/runtime/BUILD | 1 + .../runtime/collective_kernel_thunk_test.cc | 2 +- .../gpu/runtime/collective_metadata_thunk.cc | 105 ++++-------------- 6 files changed, 37 insertions(+), 95 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.cc index 277cea6198ae80..fcb87afc21d28e 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.cc @@ -65,8 +65,8 @@ struct RankFormatter { } // namespace GpuCliqueRendezvous::GpuCliqueRendezvous( - GpuCliqueKey clique_key, absl::btree_map state) - : clique_key_(std::move(clique_key)), state_(std::move(state)) {} + GpuCliqueKey clique_key, absl::btree_map values) + : clique_key_(std::move(clique_key)), values_(std::move(values)) {} absl::StatusOr> GpuCliqueRendezvous::Join( const GpuCliqueKey& clique_key, RankId rank, std::any data) { diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.h b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.h index a3220996f214c5..623cd7d8513fd7 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.h +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous.h @@ -42,15 +42,15 @@ class GpuCliqueRendezvous { static absl::StatusOr> Join( const 
GpuCliqueKey& clique_key, RankId rank, std::any data); - // Returns the clique key associated with this data. + // Returns the clique key associated with this rendezvous object. const GpuCliqueKey& clique_key() const { return clique_key_; } - // Returns the state associated with the given rank. If state type is not - // the same as `T`, returns an error. + // Returns the value at the given rank. If value type is not the same as `T`, + // returns an error. template - absl::StatusOr> state(RankId rank) const { - auto it = state_.find(rank); - if (it == state_.end()) { + absl::StatusOr> at(RankId rank) const { + auto it = values_.find(rank); + if (it == values_.end()) { return NotFound("Data not found for rank %d", rank.value()); } @@ -64,10 +64,10 @@ class GpuCliqueRendezvous { private: GpuCliqueRendezvous(GpuCliqueKey clique_key, - absl::btree_map state); + absl::btree_map values); GpuCliqueKey clique_key_; - absl::btree_map state_; + absl::btree_map values_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous_test.cc b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous_test.cc index 58d0bd7b2c402a..ef8a2ce09383c0 100644 --- a/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous_test.cc +++ b/third_party/xla/xla/backends/gpu/collectives/gpu_clique_rendezvous_test.cc @@ -48,8 +48,8 @@ TEST(GpuCliqueRendezvousTest, TwoParticipants) { GpuCliqueRendezvous& data = **rendezvous; ASSERT_EQ(data.clique_key(), key); - ASSERT_EQ(*data.state(RankId(0)), 0); - ASSERT_EQ(*data.state(RankId(1)), 1); + ASSERT_EQ(*data.at(RankId(0)), 0); + ASSERT_EQ(*data.at(RankId(1)), 1); }; }; diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 2c79ce55668f95..c40f279f9212dd 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -2010,6 +2010,7 @@ cc_library( "//xla:status_macros", 
"//xla:util", "//xla/backends/gpu/collectives:gpu_clique_key", + "//xla/backends/gpu/collectives:gpu_clique_rendezvous", "//xla/core/collectives:rank_id", "//xla/hlo/ir:hlo", "//xla/runtime:device_id", diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc index 993b4d0cc06b0d..e65d7760e8981f 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_kernel_thunk_test.cc @@ -394,7 +394,7 @@ TEST(CollectiveKernelThunkTest, MultiprocessTest) { for (absl::StatusOr result : RunCollectiveKernelThunkOnDevices(metadata, /*emulate_multiprocess=*/true)) { - EXPECT_THAT(result, StatusIs(absl::StatusCode::kUnimplemented)); + EXPECT_THAT(result, StatusIs(absl::StatusCode::kInvalidArgument)); } } diff --git a/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc index ae4757dec337fa..a5e44c890f34fe 100644 --- a/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/collective_metadata_thunk.cc @@ -15,23 +15,20 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/collective_metadata_thunk.h" +#include #include #include #include #include -#include #include #include -#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "absl/types/span.h" #include "google/protobuf/repeated_ptr_field.h" #include "xla/backends/gpu/collectives/gpu_clique_key.h" +#include "xla/backends/gpu/collectives/gpu_clique_rendezvous.h" #include "xla/backends/gpu/runtime/collective_multimem.h" #include "xla/backends/gpu/runtime/collective_thunk.h" #include "xla/core/collectives/rank_id.h" @@ -39,7 +36,6 @@ limitations under the License. #include "xla/layout.h" #include "xla/runtime/device_id.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/rendezvous.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/gpu/collective_kernel_metadata.h" @@ -80,98 +76,43 @@ CollectiveConfig CollectiveMetadataThunk::GetCollectiveConfig( return config; } -struct DeviceParameters { - RankId rank; - std::vector parameters; - - bool operator<(const DeviceParameters& other) const { - return rank < other.rank; - } -}; - -absl::StatusOr> SyncLocalDeviceParameters( - const GpuCliqueKey& clique_key, RankId rank, - std::vector parameters) { - std::vector device_parameters; - auto rendezvous_fn = [](absl::Span values) { - std::vector values_copy; - for (const auto& value : values) { - values_copy.push_back(*value); - } - // Sort to make sure that values are in the same order as the - // devices are ordered in the communicator. 
- absl::c_sort(values_copy); - return values_copy; - }; - - std::string start_rendezvous_key = absl::StrFormat( - "[rank=%d] Initializing collective metadata for clique %s", rank.value(), - clique_key.ToString()); - - DeviceParameters params; - params.rank = rank; - params.parameters = std::move(parameters); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr> local_ranks_parameters, - Rendezvous>( - /*name=*/start_rendezvous_key, /*key=*/clique_key, - /*value=*/params, - /*num_threads=*/clique_key.num_local_participants(), rendezvous_fn)); - return std::vector(local_ranks_parameters->begin(), - local_ranks_parameters->end()); -} - -absl::StatusOr> SyncGlobalDeviceParameters( - const GpuCliqueKey& clique_key, RankId rank, - std::vector parameters) { - if (!clique_key.is_local()) { - return Unimplemented( - "[rank=%d] Multiprocess collective metadata is not supported yet in " - "clique %s", - rank.value(), clique_key.ToString()); - } - - TF_ASSIGN_OR_RETURN( - std::vector local_ranks_parameters, - SyncLocalDeviceParameters(clique_key, rank, std::move(parameters))); - - return local_ranks_parameters; -} - absl::Status CollectiveMetadataThunk::ConstructCollectiveMetadata( const GpuCliqueKey& clique_key, RankId rank, se::Stream* stream, std::vector parameters, std::shared_ptr multimem, se::DeviceAddressBase destination) { - CollectiveKernelMetadata metadata; - metadata.rank = rank.value(); - metadata.multicast_buffer_ptr = - multimem ? multimem->mapped_ptr(rank) : nullptr; + size_t num_parameters = parameters.size(); + + using DeviceParameters = std::vector; + + // Exchange device parameters with all ranks in the clique. 
TF_ASSIGN_OR_RETURN( - std::vector device_parameters, - SyncGlobalDeviceParameters(clique_key, rank, std::move(parameters))); - TF_RET_CHECK(!device_parameters.empty()) - << "Not enough devices in the clique."; - const size_t num_parameters = device_parameters[0].parameters.size(); - for (const auto& value : device_parameters) { - TF_RET_CHECK(value.parameters.size() == num_parameters); - } + auto device_parameters, + GpuCliqueRendezvous::Join(clique_key, rank, std::move(parameters))); + // Collect pointers to device buffers from all participating ranks. std::vector param_to_peers_ptrs; - param_to_peers_ptrs.reserve(device_parameters.size() * num_parameters); - for (int peer = 0; peer < device_parameters.size(); ++peer) { - for (int param = 0; param < num_parameters; ++param) { - param_to_peers_ptrs.push_back( - device_parameters[peer].parameters[param].opaque()); + for (auto peer = RankId(0); peer < RankId(clique_key.num_devices()); ++peer) { + TF_ASSIGN_OR_RETURN(const DeviceParameters& peer_parameters, + device_parameters->at(peer)); + for (se::DeviceAddressBase peer_parameter : peer_parameters) { + param_to_peers_ptrs.push_back(peer_parameter.opaque()); } } + // Check that all participants have the same number of parameters. + TF_RET_CHECK(param_to_peers_ptrs.size() == + num_parameters * clique_key.num_local_participants()); + const int64_t param_to_peers_ptrs_size = param_to_peers_ptrs.size() * sizeof(void*); se::DeviceAddressBase param_to_peers_ptrs_buffer = destination.GetByteSlice( sizeof(CollectiveKernelMetadata), param_to_peers_ptrs_size); + CollectiveKernelMetadata metadata; + metadata.rank = rank.value(); + metadata.multicast_buffer_ptr = + multimem ? 
multimem->mapped_ptr(rank) : nullptr; metadata.param_to_peers = reinterpret_cast(param_to_peers_ptrs_buffer.opaque()); From eed57fa340139e372413b2695de32db1ae75fe8a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 8 Dec 2025 08:53:20 -0800 Subject: [PATCH 033/753] [xla] Prepare to MaybeOwningDeviceAddress migration PiperOrigin-RevId: 841784507 --- third_party/xla/xla/service/BUILD | 27 ++++-- ...mory.cc => maybe_owning_device_address.cc} | 30 +++---- .../xla/service/maybe_owning_device_address.h | 88 +++++++++++++++++++ ...cc => maybe_owning_device_address_test.cc} | 15 ++-- .../xla/service/maybe_owning_device_memory.h | 72 ++------------- 5 files changed, 134 insertions(+), 98 deletions(-) rename third_party/xla/xla/service/{maybe_owning_device_memory.cc => maybe_owning_device_address.cc} (53%) create mode 100644 third_party/xla/xla/service/maybe_owning_device_address.h rename third_party/xla/xla/service/{maybe_owning_device_memory_test.cc => maybe_owning_device_address_test.cc} (77%) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b8b6a62ec06674..8973a3631eccdf 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4080,28 +4080,39 @@ xla_cc_test( ) cc_library( - name = "maybe_owning_device_memory", - srcs = ["maybe_owning_device_memory.cc"], - hdrs = ["maybe_owning_device_memory.h"], + name = "maybe_owning_device_address", + srcs = ["maybe_owning_device_address.cc"], + hdrs = ["maybe_owning_device_address.h"], deps = [ + "//xla:types", "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/base:core_headers", ], ) xla_cc_test( - name = "maybe_owning_device_memory_test", - srcs = ["maybe_owning_device_memory_test.cc"], + name = "maybe_owning_device_address_test", + srcs = ["maybe_owning_device_address_test.cc"], deps = [ - 
":maybe_owning_device_memory", + ":maybe_owning_device_address", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", ], ) +cc_library( + name = "maybe_owning_device_memory", + hdrs = ["maybe_owning_device_memory.h"], + deps = [ + ":maybe_owning_device_address", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/base:core_headers", + ], +) + cc_library( name = "float8_fnuz_ir_emitter", srcs = [ diff --git a/third_party/xla/xla/service/maybe_owning_device_memory.cc b/third_party/xla/xla/service/maybe_owning_device_address.cc similarity index 53% rename from third_party/xla/xla/service/maybe_owning_device_memory.cc rename to third_party/xla/xla/service/maybe_owning_device_address.cc index a7b3aa5e4b641c..6f8e252ebac99d 100644 --- a/third_party/xla/xla/service/maybe_owning_device_memory.cc +++ b/third_party/xla/xla/service/maybe_owning_device_address.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include #include @@ -25,33 +25,29 @@ limitations under the License. 
namespace xla { -stream_executor::DeviceAddressBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() - const { +se::DeviceAddressBase MaybeOwningDeviceAddress::AsDeviceAddress() const { if (HasOwnership()) { - return *std::get>(mem_); + return *std::get>(mem_); } - return std::get(mem_); + return std::get(mem_); } -bool MaybeOwningDeviceMemory::HasOwnership() const { - return std::holds_alternative>( - mem_); +bool MaybeOwningDeviceAddress::HasOwnership() const { + return std::holds_alternative>(mem_); } -std::optional> -MaybeOwningDeviceMemory::Release() { +std::optional> +MaybeOwningDeviceAddress::Release() { if (!HasOwnership()) { return {}; } - return std::move( - std::get>(mem_)); + return std::move(std::get>(mem_)); } -const stream_executor::ScopedDeviceAddress* -MaybeOwningDeviceMemory::AsOwningDeviceMemory() const { - return HasOwnership() - ? &std::get>(mem_) - : nullptr; +const se::ScopedDeviceAddress* +MaybeOwningDeviceAddress::AsScopedDeviceAddress() const { + return HasOwnership() ? &std::get>(mem_) + : nullptr; } } // namespace xla diff --git a/third_party/xla/xla/service/maybe_owning_device_address.h b/third_party/xla/xla/service/maybe_owning_device_address.h new file mode 100644 index 00000000000000..8a6f52e15adcaf --- /dev/null +++ b/third_party/xla/xla/service/maybe_owning_device_address.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_SERVICE_MAYBE_OWNING_DEVICE_ADDRESS_H_ +#define XLA_SERVICE_MAYBE_OWNING_DEVICE_ADDRESS_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" +#include "xla/types.h" // IWYU pragma: keep + +namespace xla { + +// MaybeOwningDeviceAddress represents either an owned or unowned device +// address. Like std::variant, DeviceMemory>. +// When the object goes output of scope, it will free the underlying device +// address if it owns it. +class MaybeOwningDeviceAddress { + public: + MaybeOwningDeviceAddress() = default; + MaybeOwningDeviceAddress(MaybeOwningDeviceAddress&&) = default; + MaybeOwningDeviceAddress& operator=(MaybeOwningDeviceAddress&&) = default; + + explicit MaybeOwningDeviceAddress(se::ScopedDeviceAddress owned) + : mem_(std::move(owned)) {} + + explicit MaybeOwningDeviceAddress(se::DeviceAddressBase unowned) + : mem_(unowned) {} + + MaybeOwningDeviceAddress& operator=(se::DeviceAddressBase unowned) { + mem_ = unowned; + return *this; + } + + MaybeOwningDeviceAddress& operator=(se::ScopedDeviceAddress owned) { + mem_ = std::move(owned); + return *this; + } + + // Fetches the underlying DeviceAddressBase. The caller of this function is + // *not* responsible for freeing the address. + se::DeviceAddressBase AsDeviceAddress() const; + + // Release the se::ScopedDeviceAddress without freeing + // it, and moves the ownership of the address from the object to the caller. + // + // A nullopt is returned if the HasOwnership() == false; + std::optional> Release(); + + // If the device address is owned, returns a pointer to the internal + // ScopedDeviceAddress, otherwise nullptr is returned. 
+ const se::ScopedDeviceAddress* AsScopedDeviceAddress() const; + + ABSL_DEPRECATE_AND_INLINE() + se::DeviceAddressBase AsDeviceMemoryBase() const { return AsDeviceAddress(); } + + ABSL_DEPRECATE_AND_INLINE() + const se::ScopedDeviceAddress* AsOwningDeviceMemory() const { + return AsScopedDeviceAddress(); + } + + // Returns true if has ownership over underlying address. + bool HasOwnership() const; + + private: + std::variant> mem_; +}; + +} // namespace xla + +#endif // XLA_SERVICE_MAYBE_OWNING_DEVICE_ADDRESS_H_ diff --git a/third_party/xla/xla/service/maybe_owning_device_memory_test.cc b/third_party/xla/xla/service/maybe_owning_device_address_test.cc similarity index 77% rename from third_party/xla/xla/service/maybe_owning_device_memory_test.cc rename to third_party/xla/xla/service/maybe_owning_device_address_test.cc index 2d3a5a8cf38708..d2dbcc46aad3ca 100644 --- a/third_party/xla/xla/service/maybe_owning_device_memory_test.cc +++ b/third_party/xla/xla/service/maybe_owning_device_address_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" @@ -21,13 +21,14 @@ limitations under the License. 
namespace xla { namespace { -using MaybeOwningDeviceMemoryTest = ::testing::Test; +using MaybeOwningDeviceAddressTest = ::testing::Test; -TEST(MaybeOwningDeviceMemoryTest, DefaultConstructed) { - MaybeOwningDeviceMemory memory; +TEST(MaybeOwningDeviceAddressTest, DefaultConstructed) { + MaybeOwningDeviceAddress memory; EXPECT_FALSE(memory.HasOwnership()); - EXPECT_EQ(memory.AsDeviceMemoryBase().opaque(), nullptr); - EXPECT_EQ(memory.AsDeviceMemoryBase().size(), 0); + + EXPECT_EQ(memory.AsDeviceAddress().opaque(), nullptr); + EXPECT_EQ(memory.AsDeviceAddress().size(), 0); } //===-----------------------------------------------------------------------===/ @@ -36,7 +37,7 @@ TEST(MaybeOwningDeviceMemoryTest, DefaultConstructed) { void BM_DefaultConstructed(benchmark::State& state) { for (auto s : state) { - MaybeOwningDeviceMemory memory; + MaybeOwningDeviceAddress memory; benchmark::DoNotOptimize(memory); } } diff --git a/third_party/xla/xla/service/maybe_owning_device_memory.h b/third_party/xla/xla/service/maybe_owning_device_memory.h index 8f7b33b4d2d66e..897003ffb17429 100644 --- a/third_party/xla/xla/service/maybe_owning_device_memory.h +++ b/third_party/xla/xla/service/maybe_owning_device_memory.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The OpenXLA Authors. +/* Copyright 2025 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,76 +16,16 @@ limitations under the License. 
#ifndef XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ #define XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ -#include -#include -#include -#include - -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "absl/base/macros.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/stream_executor/device_memory.h" // IWYU pragma: keep #include "xla/stream_executor/device_memory_allocator.h" // IWYU pragma: keep namespace xla { -// MaybeOwningDeviceMemory represents either an owned or unowned -// device memory. Like std::variant, -// DeviceMemory>. When the object goes output of scope, it will free the -// underlying memory if it owns it. -class MaybeOwningDeviceMemory { - public: - MaybeOwningDeviceMemory() = default; - ~MaybeOwningDeviceMemory() = default; - - explicit MaybeOwningDeviceMemory( - stream_executor::ScopedDeviceAddress owned) - : mem_(std::move(owned)) {} - - explicit MaybeOwningDeviceMemory(stream_executor::DeviceAddressBase unowned) - : mem_(unowned) {} - - MaybeOwningDeviceMemory(MaybeOwningDeviceMemory&&) = default; - - MaybeOwningDeviceMemory& operator=( - stream_executor::DeviceAddressBase unowned) { - mem_ = unowned; - return *this; - } - - MaybeOwningDeviceMemory& operator=( - stream_executor::ScopedDeviceAddress owned) { - mem_ = std::move(owned); - return *this; - } - - MaybeOwningDeviceMemory& operator=(MaybeOwningDeviceMemory&&) = default; - - // Fetches the underlying DeviceAddressBase from a - // MaybeOwningDeviceMemory. The caller of this function is *not* - // responsible for freeing the memory. - stream_executor::DeviceAddressBase AsDeviceMemoryBase() const; - - // Release the stream_executor::ScopedDeviceAddress without freeing - // it, and moves the ownership of the memory buffer from the object to the - // caller. 
- // - // A nullopt is returned if the HasOwnership() == false; - std::optional> Release(); - - // If the device memory is owned, returns a pointer to the internal - // OwningDeviceMemory, otherwise nullptr is returned. - const stream_executor::ScopedDeviceAddress* AsOwningDeviceMemory() - const; - - // Returns true if the device_memory has ownership over underlying memory. - bool HasOwnership() const; - - private: - std::variant> - mem_; -}; +using MaybeOwningDeviceMemory ABSL_DEPRECATE_AND_INLINE() = + MaybeOwningDeviceAddress; -} // namespace xla +} #endif // XLA_SERVICE_MAYBE_OWNING_DEVICE_MEMORY_H_ From bacbd6b4513f2e019a287a89cf5da5e521ff6fe0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 09:26:48 -0800 Subject: [PATCH 034/753] PR #105775: Bump urllib3 from 2.5.0 to 2.6.0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/105775 Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.5.0 to 2.6.0.
Release notes

Sourced from urllib3's releases.

2.6.0

🚀 urllib3 is fundraising for HTTP/2 support

urllib3 is raising ~$40,000 USD to release HTTP/2 support and ensure long-term sustainable maintenance of the project after a sharp decline in financial support. If your company or organization uses Python and would benefit from HTTP/2 support in Requests, pip, cloud SDKs, and thousands of other projects please consider contributing financially to ensure HTTP/2 support is developed sustainably and maintained for the long-haul.

Thank you for your support.

Security

  • Fixed a security issue where streaming API could improperly handle highly compressed HTTP content ("decompression bombs") leading to excessive resource consumption even when a small amount of data was requested. Reading small chunks of compressed data is safer and much more efficient now. (CVE-2025-66471 reported by @​Cycloctane, 8.9 High, GHSA-2xpw-w6gg-jr37)
  • Fixed a security issue where an attacker could compose an HTTP response with virtually unlimited links in the Content-Encoding header, potentially leading to a denial of service (DoS) attack by exhausting system resources during decoding. The number of allowed chained encodings is now limited to 5. (CVE-2025-66418 reported by @​illia-v, 8.9 High, GHSA-gm62-xv2j-4w53)

[!IMPORTANT]

  • If urllib3 is not installed with the optional urllib3[brotli] extra, but your environment contains a Brotli/brotlicffi/brotlipy package anyway, make sure to upgrade it to at least Brotli 1.2.0 or brotlicffi 1.2.0.0 to benefit from the security fixes and avoid warnings. Prefer using urllib3[brotli] to install a compatible Brotli package automatically.
  • If you use custom decompressors, please make sure to update them to respect the changed API of urllib3.response.ContentDecoder.

Features

  • Enabled retrieval, deletion, and membership testing in HTTPHeaderDict using bytes keys. (#3653)
  • Added host and port information to string representations of HTTPConnection. (#3666)
  • Added support for Python 3.14 free-threading builds explicitly. (#3696)

Removals

  • Removed the HTTPResponse.getheaders() method in favor of HTTPResponse.headers. Removed the HTTPResponse.getheader(name, default) method in favor of HTTPResponse.headers.get(name, default). (#3622)

Bugfixes

  • Fixed redirect handling in urllib3.PoolManager when an integer is passed for the retries parameter. (#3649)
  • Fixed HTTPConnectionPool when used in Emscripten with no explicit port. (#3664)
  • Fixed handling of SSLKEYLOGFILE with expandable variables. (#3700)

Misc

  • Changed the zstd extra to install backports.zstd instead of zstandard on Python 3.13 and before. (#3693)
  • Improved the performance of content decoding by optimizing BytesQueueBuffer class. (#3710)
  • Allowed building the urllib3 package with newer setuptools-scm v9.x. (#3652)
  • Ensured successful urllib3 builds by setting Hatchling requirement to ≥ 1.27.0. (#3638)
Changelog

Sourced from urllib3's changelog.

2.6.0 (2025-12-05)

Security

  • Fixed a security issue where streaming API could improperly handle highly compressed HTTP content ("decompression bombs") leading to excessive resource consumption even when a small amount of data was requested. Reading small chunks of compressed data is safer and much more efficient now. (GHSA-2xpw-w6gg-jr37 <https://github.com/urllib3/urllib3/security/advisories/GHSA-2xpw-w6gg-jr37>__)
  • Fixed a security issue where an attacker could compose an HTTP response with virtually unlimited links in the Content-Encoding header, potentially leading to a denial of service (DoS) attack by exhausting system resources during decoding. The number of allowed chained encodings is now limited to 5. (GHSA-gm62-xv2j-4w53 <https://github.com/urllib3/urllib3/security/advisories/GHSA-gm62-xv2j-4w53>__)

.. caution::

  • If urllib3 is not installed with the optional urllib3[brotli] extra, but your environment contains a Brotli/brotlicffi/brotlipy package anyway, make sure to upgrade it to at least Brotli 1.2.0 or brotlicffi 1.2.0.0 to benefit from the security fixes and avoid warnings. Prefer using urllib3[brotli] to install a compatible Brotli package automatically.

  • If you use custom decompressors, please make sure to update them to respect the changed API of urllib3.response.ContentDecoder.

Features

  • Enabled retrieval, deletion, and membership testing in HTTPHeaderDict using bytes keys. ([#3653](https://github.com/urllib3/urllib3/issues/3653) <https://github.com/urllib3/urllib3/issues/3653>__)
  • Added host and port information to string representations of HTTPConnection. ([#3666](https://github.com/urllib3/urllib3/issues/3666) <https://github.com/urllib3/urllib3/issues/3666>__)
  • Added support for Python 3.14 free-threading builds explicitly. ([#3696](https://github.com/urllib3/urllib3/issues/3696) <https://github.com/urllib3/urllib3/issues/3696>__)

Removals

  • Removed the HTTPResponse.getheaders() method in favor of HTTPResponse.headers. Removed the HTTPResponse.getheader(name, default) method in favor of HTTPResponse.headers.get(name, default). ([#3622](https://github.com/urllib3/urllib3/issues/3622) <https://github.com/urllib3/urllib3/issues/3622>__)

Bugfixes

  • Fixed redirect handling in urllib3.PoolManager when an integer is passed for the retries parameter. ([#3649](https://github.com/urllib3/urllib3/issues/3649) <https://github.com/urllib3/urllib3/issues/3649>__)
  • Fixed HTTPConnectionPool when used in Emscripten with no explicit port. ([#3664](https://github.com/urllib3/urllib3/issues/3664) <https://github.com/urllib3/urllib3/issues/3664>__)
  • Fixed handling of SSLKEYLOGFILE with expandable variables. ([#3700](https://github.com/urllib3/urllib3/issues/3700) <https://github.com/urllib3/urllib3/issues/3700>__)

... (truncated)

Commits
  • 720f484 Release 2.6.0
  • 24d7b67 Merge commit from fork
  • c19571d Merge commit from fork
  • 816fcf0 Bump actions/setup-python from 6.0.0 to 6.1.0 (#3725)
  • 18af0a1 Improve speed of BytesQueueBuffer.get() by using memoryview (#3711)
  • 1f6abac Bump versions of pre-commit hooks (#3716)
  • 1c8fbf7 Bump actions/checkout from 5.0.0 to 6.0.0 (#3722)
  • 7784b9e Add Python 3.15 to CI (#3717)
  • 0241c9e Updated docs to reflect change in optional zstd dependency from zstandard t...
  • 7afcabb Expand environment variable of SSLKEYLOGFILE (#3705)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=urllib3&package-manager=pip&previous-version=2.5.0&new-version=2.6.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/tensorflow/tensorflow/network/alerts).
Copybara import of the project: -- 8e0e52510295ad7244fb8fa46e001f6544d94122 by dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>: Bump urllib3 from 2.5.0 to 2.6.0 Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.5.0 to 2.6.0. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.5.0...2.6.0) --- updated-dependencies: - dependency-name: urllib3 dependency-version: 2.6.0 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Merging this change closes #105775 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/105775 from tensorflow:dependabot/pip/urllib3-2.6.0 8e0e52510295ad7244fb8fa46e001f6544d94122 PiperOrigin-RevId: 841796827 --- requirements_lock_3_10.txt | 6 +++--- requirements_lock_3_11.txt | 6 +++--- requirements_lock_3_12.txt | 6 +++--- requirements_lock_3_13.txt | 6 +++--- requirements_lock_3_9.txt | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 36a6e6b78b5604..486c66c2fdb52f 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -748,9 +748,9 @@ typing-extensions==4.14.1 \ # -r ci/official/requirements_updater/requirements.in # optree # rich -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 6238e70c957632..80a1a2e834b3c2 100644 --- 
a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -747,9 +747,9 @@ typing-extensions==4.14.1 \ # via # -r ci/official/requirements_updater/requirements.in # optree -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 2d655921b2f9d8..ac1fb6ff141e7d 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -747,9 +747,9 @@ typing-extensions==4.14.1 \ # via # -r ci/official/requirements_updater/requirements.in # optree -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt index 45461447246243..4e0988a88aff92 100644 --- a/requirements_lock_3_13.txt +++ b/requirements_lock_3_13.txt @@ -729,9 +729,9 @@ typing-extensions==4.14.1 \ # via # -r ci/official/requirements_updater/requirements.in # optree -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + 
--hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 26e2d0ae19171b..6a52f5d70bdcae 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -734,9 +734,9 @@ typing-extensions==4.14.1 \ # -r ci/official/requirements_updater/requirements.in # optree # rich -urllib3==2.5.0 \ - --hash=sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760 \ - --hash=sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc +urllib3==2.6.0 \ + --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ + --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests werkzeug==3.1.3 \ --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ From 7785057a2ac37f900181eb7720bcfe68a1d44b1c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Dec 2025 17:43:36 +0000 Subject: [PATCH 035/753] Bump werkzeug from 3.1.3 to 3.1.4 Bumps [werkzeug](https://github.com/pallets/werkzeug) from 3.1.3 to 3.1.4. - [Release notes](https://github.com/pallets/werkzeug/releases) - [Changelog](https://github.com/pallets/werkzeug/blob/main/CHANGES.rst) - [Commits](https://github.com/pallets/werkzeug/compare/3.1.3...3.1.4) --- updated-dependencies: - dependency-name: werkzeug dependency-version: 3.1.4 dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] --- requirements_lock_3_10.txt | 6 +++--- requirements_lock_3_11.txt | 6 +++--- requirements_lock_3_12.txt | 6 +++--- requirements_lock_3_13.txt | 6 +++--- requirements_lock_3_9.txt | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index 486c66c2fdb52f..a2645ee5ddbdb4 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -752,9 +752,9 @@ urllib3==2.6.0 \ --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tb-nightly wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index 80a1a2e834b3c2..cd51c5e0c0c338 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -751,9 +751,9 @@ urllib3==2.6.0 \ --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tb-nightly wheel==0.41.3 \ 
--hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index ac1fb6ff141e7d..1b8d63c9d75147 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -751,9 +751,9 @@ urllib3==2.6.0 \ --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tb-nightly wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt index 4e0988a88aff92..ded80d5230a8c9 100644 --- a/requirements_lock_3_13.txt +++ b/requirements_lock_3_13.txt @@ -733,9 +733,9 @@ urllib3==2.6.0 \ --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tb-nightly wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 6a52f5d70bdcae..6e68ddf6f79595 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt 
@@ -738,9 +738,9 @@ urllib3==2.6.0 \ --hash=sha256:c90f7a39f716c572c4e3e58509581ebd83f9b59cced005b7db7ad2d22b0db99f \ --hash=sha256:cb9bcef5a4b345d5da5d145dc3e30834f58e8018828cbc724d30b4cb7d4d49f1 # via requests -werkzeug==3.1.3 \ - --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ - --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 +werkzeug==3.1.4 \ + --hash=sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905 \ + --hash=sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e # via tb-nightly wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ From 1f0f883a4c3eec8c320189a212c43444817cb298 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 8 Dec 2025 09:30:20 -0800 Subject: [PATCH 036/753] Gracefully handle missing parameter/output shardings from single-device atom executables `PjRtExecutable::GetParameterShardings()` and `PjRtExecutable::GetOutputShardings()` may return `std::nullopt` for single-device executables, but their shardings are trivial and we can infer them from the device count. 
PiperOrigin-RevId: 841798186 --- .../xla/xla/python/ifrt/ir/tests/ifrt-opt.cc | 11 ++++- .../ifrt_compile_and_propagate_shardings.mlir | 30 ++++++++++++ ...rt_compile_and_propagate_shardings_pass.cc | 29 +++++++---- ...y_bound_external_loaded_executable_pass.cc | 48 +++++++++---------- 4 files changed, 84 insertions(+), 34 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc b/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc index 596767a9dc3a1d..097136f5b1c631 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc @@ -67,11 +67,18 @@ class TestChildExecutableCompiler : public AtomProgramCompiler { "invalidated some method string_views."; auto mock_executable = std::make_unique>(); + int num_devices; + if (options.executable_build_options.has_device_assignment()) { + num_devices = + options.executable_build_options.device_assignment().num_elements(); + } else { + num_devices = 1; + } int num_parameters_to_propagate = options.executable_build_options .allow_spmd_sharding_propagation_to_parameters() .size(); - if (num_parameters_to_propagate > 0) { + if (num_devices > 1 && num_parameters_to_propagate > 0) { xla::OpSharding op_sharding; op_sharding.set_type(xla::OpSharding::REPLICATED); std::vector parameter_shardings( @@ -83,7 +90,7 @@ class TestChildExecutableCompiler : public AtomProgramCompiler { options.executable_build_options .allow_spmd_sharding_propagation_to_output() .size(); - if (num_outputs_to_propagate > 0) { + if (num_devices > 1 && num_outputs_to_propagate > 0) { // Always infer output shardings to be replicated for the lit tests. 
xla::OpSharding op_sharding; op_sharding.set_type(xla::OpSharding::REPLICATED); diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir index 4021496168cb8c..e8c49c453b6853 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir @@ -286,6 +286,36 @@ module @propagate_to_inputs { // ----- +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0]> +// CHECK-LABEL: @propagate_single_device +module @propagate_single_device { + func.func @main(%arg0: !array_unspecified) + -> !array_unspecified attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [0]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [0]> + %0, %ctrl_0 = ifrt.Call @add_one_0::@main(%arg0) on devices [0] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + return %0 : !array_unspecified + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE]] + // CHECK-SAME: on devices [0] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [0]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [0]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + +} + +// ----- + !array = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> !array_unspecified = !ifrt.array, diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc 
b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc index c48128130c59f7..a828664243a673 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc @@ -62,6 +62,7 @@ limitations under the License. #include "xla/python/ifrt/ir/transforms/utils.h" #include "xla/python/ifrt/support/sharding_conversions.h" #include "xla/service/hlo.pb.h" +#include "xla/xla_data.pb.h" namespace xla { namespace ifrt { @@ -388,10 +389,16 @@ IfrtCompileAndPropagateShardingsPass::GetInputShardingParams( if (llvm::isa( in_array_type.getShardingAttr())) { if (!in_shardings.has_value()) { - in_shardings = compile_result.executable->GetParameterShardings(); - if (!in_shardings.has_value()) { - return call_op.emitError() - << "executable does not have input shardings"; + if (call_op.getDevices().size() == 1) { + // Use replicated sharding for single-device inputs without calling + // `GetParameterShardings` since it may return `std::nullopt`. + in_shardings.emplace(call_op.getOutputs().size()); + } else { + in_shardings = compile_result.executable->GetParameterShardings(); + if (!in_shardings.has_value()) { + return call_op.emitError() + << "executable does not have input shardings"; + } } if (in_shardings->size() != call_op.getOutputs().size()) { return call_op.emitError() @@ -443,10 +450,16 @@ IfrtCompileAndPropagateShardingsPass::GetOutputShardingParams( if (llvm::isa( out_array_type.getShardingAttr())) { if (!out_shardings.has_value()) { - out_shardings = compile_result.executable->GetOutputShardings(); - if (!out_shardings.has_value()) { - return call_op.emitError() - << "executable does not have output shardings"; + if (call_op.getDevices().size() == 1) { + // Use replicated sharding for single-device outputs without calling + // `GetOutputShardings` since it may return `std::nullopt`.
+ out_shardings.emplace(call_op.getOutputs().size()); + } else { + out_shardings = compile_result.executable->GetOutputShardings(); + if (!out_shardings.has_value()) { + return call_op.emitError() + << "executable does not have output shardings"; + } } if (out_shardings->size() != call_op.getOutputs().size()) { return call_op.emitError() diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc index c62dc0cc98b897..fe0f68ccbea178 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -130,50 +131,49 @@ void IfrtVerifyBoundExternalLoadedExecutablePass::runOnOperation() { } auto func_type = loaded_exec_op.getFunctionType(); - if (!exec_it->second->GetParameterShardings().has_value()) { + std::optional> parameter_shardings; + if (loaded_exec_op.getDevices().size() == 1) { + parameter_shardings.emplace(func_type.getNumInputs()); + } else { + parameter_shardings = exec_it->second->GetParameterShardings(); + } + if (!parameter_shardings.has_value()) { return loaded_exec_op.emitOpError() << "cannot be bound to an executable without parameter " "shardings"; } - if (!exec_it->second->GetOutputShardings().has_value()) { + std::optional> output_shardings; + if (loaded_exec_op.getDevices().size() == 1) { + output_shardings.emplace(func_type.getNumResults()); + } else { + output_shardings = exec_it->second->GetOutputShardings(); + } + if (!output_shardings.has_value()) { return loaded_exec_op.emitOpError() - << "cannot be bound to an executable without output shardings"; + << "cannot be bound to a 
multi-device executable without output " + "shardings"; } - if (func_type.getNumInputs() != - exec_it->second->GetParameterShardings()->size()) { + if (func_type.getNumInputs() != parameter_shardings->size()) { return loaded_exec_op.emitOpError() << "expects an executable with " << func_type.getNumInputs() << " inputs, but was bound to an executable with " - << exec_it->second->GetParameterShardings()->size() << " inputs"; + << parameter_shardings->size() << " inputs"; } - if (func_type.getNumResults() != - exec_it->second->GetOutputShardings()->size()) { + if (func_type.getNumResults() != output_shardings->size()) { return loaded_exec_op.emitOpError() << "expects an executable with " << func_type.getNumResults() << " results, but was bound to an executable with " - << exec_it->second->GetOutputShardings()->size() << " results"; + << output_shardings->size() << " results"; } // Verify that the input and output shardings of the LoadedExecutableOp // are the same as the shardings of the bound executable. 
- if (!exec_it->second->GetParameterShardings().has_value()) { - return loaded_exec_op.emitOpError() - << "cannot be bound to an executable without parameter " - "shardings"; - } - if (!exec_it->second->GetOutputShardings().has_value()) { - return loaded_exec_op.emitOpError() - << "cannot be bound to an executable without output " - "shardings"; - } auto sharding_equal_status = VerifyShardingsEqual( - func_type.getInputs(), *exec_it->second->GetParameterShardings(), - "input"); + func_type.getInputs(), *parameter_shardings, "input"); if (!sharding_equal_status.ok()) { return loaded_exec_op.emitOpError() << sharding_equal_status.message(); } - sharding_equal_status = VerifyShardingsEqual( - func_type.getResults(), *exec_it->second->GetOutputShardings(), - "output"); + sharding_equal_status = VerifyShardingsEqual(func_type.getResults(), + *output_shardings, "output"); if (!sharding_equal_status.ok()) { return loaded_exec_op.emitOpError() << sharding_equal_status.message(); } From 5d9394b3fca77e1a85361b7b6177f855b9b96e36 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 09:53:58 -0800 Subject: [PATCH 037/753] [XLA:TPU] SPMD Partitioner should not change layout if no partition changes done PiperOrigin-RevId: 841806676 --- .../xla/xla/service/spmd/spmd_partitioner.cc | 115 +++++++++++------- 1 file changed, 71 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index 3efed6ad73375a..052fe73912d8ba 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -5488,6 +5488,26 @@ int64_t SpmdPartitioner::CommunicationCostInBytes(HloInstruction* hlo) { module->set_spmd_output_sharding(entry_root->sharding()); } +namespace { + +// Returns true if the old and the new entry layout shapes differ. 
+// NOTE: that we explicitly ignore the layout, since it is either defined +// beforehand or during layout assignment. +bool ShapeChangesBetween(const ComputationLayout& old_entry_layout, + const ProgramShape& new_program_shape) { + for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) { + if (!Shape::Equal().IgnoreLayout()(old_entry_layout.parameter_shape(i), + new_program_shape.parameters(i))) { + return true; + } + } + + return !Shape::Equal().IgnoreLayout()(old_entry_layout.result_shape(), + new_program_shape.result()); +} + +} // namespace + absl::StatusOr SpmdPartitioner::RunImpl( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -5582,57 +5602,64 @@ absl::StatusOr SpmdPartitioner::RunImpl( })); // For the entry computation, make sure that the root instruction and the - // parameters preserve their signatures. + // parameters preserve their signatures if there are any partitioning changes. auto new_program_shape = module->entry_computation()->ComputeProgramShape(); - if (!options_.allow_module_signature_change) { - if (!Shape::Equal()(program_shape.result(), new_program_shape.result())) { - return absl::InvalidArgumentError( - "Result shape changed for the entry computation from: " + - program_shape.result().ToString() + - " to: " + new_program_shape.result().ToString()); - } - if (program_shape.parameters_size() != - new_program_shape.parameters_size()) { - return absl::InvalidArgumentError( - "Parameter count changed for the entry computation from: " + - std::to_string(program_shape.parameters_size()) + - " to: " + std::to_string(new_program_shape.parameters_size())); - } - for (int64_t i = 0; i < program_shape.parameters_size(); ++i) { - if (!Shape::Equal()(program_shape.parameters(i), - new_program_shape.parameters(i))) { + const ComputationLayout& old_entry_layout = + module->entry_computation_layout(); + if (ShapeChangesBetween(old_entry_layout, new_program_shape)) { + if (!options_.allow_module_signature_change) { + if 
(!Shape::Equal()(program_shape.result(), new_program_shape.result())) { return absl::InvalidArgumentError( - "Parameter shape changed for the entry computation parameter " + - std::to_string(i) + - " from: " + program_shape.parameters(i).ToString() + - " to: " + new_program_shape.parameters(i).ToString()); + "Result shape changed for the entry computation from: " + + program_shape.result().ToString() + + " to: " + new_program_shape.result().ToString()); } - } - } else { - // Fix up some bad tiling in entry computation layout. - auto update_shape = [this](Shape* subshape, const xla::ShapeIndex& index) { - if (subshape->IsArray() && subshape->has_layout()) { - UpdateLayout(subshape); + if (program_shape.parameters_size() != + new_program_shape.parameters_size()) { + return absl::InvalidArgumentError( + "Parameter count changed for the entry computation from: " + + std::to_string(program_shape.parameters_size()) + + " to: " + std::to_string(new_program_shape.parameters_size())); } - }; - const auto& old_entry_layout = module->entry_computation_layout(); - // Shapes can change but the layout should still remain the same. - for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) { + for (int64_t i = 0; i < program_shape.parameters_size(); ++i) { + if (!Shape::Equal()(program_shape.parameters(i), + new_program_shape.parameters(i))) { + return absl::InvalidArgumentError( + "Parameter shape changed for the entry computation parameter " + + std::to_string(i) + + " from: " + program_shape.parameters(i).ToString() + + " to: " + new_program_shape.parameters(i).ToString()); + } + } + } else { + // For the cases where we update the shape, also fix up some bad tiling in + // entry computation layout. + auto update_shape = [this](Shape* subshape, + const xla::ShapeIndex& index) { + if (subshape->IsArray() && subshape->has_layout()) { + UpdateLayout(subshape); + } + }; + // Shapes can change but the layout should still remain the same. 
+ // If the shapes do not change, we shouldn't change the layout if pre-set. + for (int64_t i = 0; i < new_program_shape.parameters_size(); ++i) { + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + old_entry_layout.parameter_shape(i), + new_program_shape.mutable_parameters(i))); + ShapeUtil::ForEachMutableSubshape( + new_program_shape.mutable_parameters(i), update_shape); + } + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - old_entry_layout.parameter_shape(i), - new_program_shape.mutable_parameters(i))); - ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_parameters(i), + old_entry_layout.result_shape(), new_program_shape.mutable_result())); + ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_result(), update_shape); + + HloModuleConfig config = module->config(); + *config.mutable_entry_computation_layout() = + ComputationLayout(new_program_shape, /*ignore_layouts=*/false); + module->set_config(config); } - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - old_entry_layout.result_shape(), new_program_shape.mutable_result())); - ShapeUtil::ForEachMutableSubshape(new_program_shape.mutable_result(), - update_shape); - - HloModuleConfig config = module->config(); - *config.mutable_entry_computation_layout() = - ComputationLayout(new_program_shape, /*ignore_layouts=*/false); - module->set_config(config); } XLA_VLOG_LINES(1, SpmdLogger::ReportAfterPartition( From 3764d7a4a80b4f8d4a97c4beeaa11883465acc35 Mon Sep 17 00:00:00 2001 From: Yin Zhang Date: Mon, 8 Dec 2025 09:55:37 -0800 Subject: [PATCH 038/753] Change OpSourceInfo::source_file from absl::string_view to std::string to prevent dangling references, as we will switch to prioritize parsing the file_name from stack_frame. (Since stack_frame is a string, so parsed file_name will be a temp string.) 
PiperOrigin-RevId: 841807283 --- third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h b/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h index 601a43b564a3c6..bc074d9bbbd3a2 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h +++ b/third_party/xla/xla/tsl/profiler/convert/xla_op_utils.h @@ -173,7 +173,7 @@ inline bool IsOffDutyOp(absl::string_view category) { // to in a user's program; e.g. it could be the file and line of user code that // generated the op. struct OpSourceInfo { - absl::string_view source_file; + std::string source_file; int32_t source_line = -1; std::string stack_frame; From bf24158520d842ad494969edf868a3671f7f6257 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 8 Dec 2025 10:56:46 -0800 Subject: [PATCH 039/753] Implement a new two-stage `ProgramInterpreter` design for better efficiency and safety This CL implements a new design for `ProgramInterpreter` with a goal of separating work that can be done once per program vs. work that needs to be done per program invocation. The current `ProgramInterpreter` design iterates over the compiled IFRT IR program and invokes IFRT APIs on demand. In this design, the interpreter needs to convert MLIR types into IFRT types, perform validation, etc. during every execution, which is wasteful since such information does not change. The new design avoids the aforementioned problem by splitting the program interpreter into two stages. First, `ProgramInterpreter::BuildExecuteFn()` now traverses the program and *returns a function that can be invoked to run the program*. The execute function is built only once during compilation and can perform any work that only needs static information, e.g., building `xla::ifrt::RemapPlan` from a `RemapArraysOp` MLIR op. 
Once that is complete, each program invocation just needs to call the execute function produced by the program interpreter. `CompiledIfrtIRProgram` now carries this "compiled" execute functions so that this can be invoked by executables. This makes `CompiledIfrtIRProgram::program` optional since we no longer need to carry the MLIR module to execute an IFRT IR program. This can save host memory if there are a large number of IFRT IR programs. PiperOrigin-RevId: 841834380 --- third_party/xla/xla/python/ifrt/ir/BUILD | 6 +- .../ifrt/ir/compiled_ifrt_ir_program.cc | 14 +- .../python/ifrt/ir/compiled_ifrt_ir_program.h | 14 + .../xla/python/ifrt/ir/program_interpreter.cc | 927 +++++++++++------- .../xla/python/ifrt/ir/program_interpreter.h | 85 +- 5 files changed, 633 insertions(+), 413 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index c48becc712bcb4..8f35b34468e5b8 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -507,6 +507,7 @@ cc_library( ":atom_program_compiler", ":ifrt_ir_program", ":ir", + ":program_interpreter", "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_layout", @@ -520,6 +521,7 @@ cc_library( "//xla/tsl/platform:statusor", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -542,7 +544,7 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:users"], deps = [ - ":compiled_ifrt_ir_program", + ":atom_program_compiler", ":ir", "//xla:status_macros", "//xla/python/ifrt", @@ -555,6 +557,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + 
"@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc b/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc index b3ec80c3d6f3e9..67c912b86cafc2 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc +++ b/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.cc @@ -55,6 +55,7 @@ limitations under the License. #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/ifrt_ir_program.h" #include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/program_interpreter.h" #include "xla/python/ifrt/ir/transforms/debug.h" #include "xla/python/ifrt/ir/transforms/passes.h" #include "xla/python/ifrt/ir/transforms/utils.h" @@ -327,6 +328,8 @@ absl::StatusOr CompiledIfrtIrProgram::Create( mlir::MLIRContext* context = mlir_module.getContext(); xla::ifrt::support::RegisterMlirDialects(*context); + std::string program_name = mlir_module.getName().value_or("unknown").str(); + // Add the bounded executables to the atom program executable map so that // they can be used by the interpreter std::shared_ptr atom_executable_map = @@ -434,8 +437,16 @@ absl::StatusOr CompiledIfrtIrProgram::Create( } } + TF_ASSIGN_OR_RETURN(DeviceListRef device_list, + client->MakeDeviceList(devices)); + TF_ASSIGN_OR_RETURN( + auto interpreter, + ProgramInterpreter::Create(client, program_name, mlir_module, + atom_executable_map, std::move(device_list))); + TF_ASSIGN_OR_RETURN(auto execute_fn, interpreter->BuildExecuteFn()); + return CompiledIfrtIrProgram{ - /*program_name=*/mlir_module.getName().value_or("unknown").str(), + /*program_name=*/std::move(program_name), /*atom_program_executables=*/std::move(atom_executable_map), /*in_specs=*/std::move(in_specs), /*out_specs=*/std::move(out_specs), @@ -444,6 +455,7 @@ 
absl::StatusOr CompiledIfrtIrProgram::Create( /*program=*/std::move(ifrt_ir_program), /*device_assignments=*/std::move(device_assignments), /*compile_options=*/compile_options, + /*execute_fn=*/std::move(execute_fn), }; } diff --git a/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.h b/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.h index 509750627489e1..c9baf2b35b5b8c 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.h +++ b/third_party/xla/xla/python/ifrt/ir/compiled_ifrt_ir_program.h @@ -16,14 +16,20 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_IR_COMPILED_IFRT_IR_PROGRAM_H_ #define XLA_PYTHON_IFRT_IR_COMPILED_IFRT_IR_PROGRAM_H_ #include +#include #include #include +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/array.h" #include "xla/python/ifrt/array_spec.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/ir/atom_program_compiler.h" #include "xla/python/ifrt/ir/ifrt_ir_program.h" @@ -63,6 +69,14 @@ struct CompiledIfrtIrProgram { // The compile options used to compile the program. std::shared_ptr compile_options; + // Precompiled execute function that interprets the IFRT IR program. The + // signature matches that of `xla::ifrt::LoadedExecutable::Execute()`. + absl::AnyInvocable( + absl::Span arrays, + const xla::ifrt::LoadedExecutable::ExecuteOptions& options, + std::optional devices)> + execute_fn; + // Compiles an IFRT IR program. 
static absl::StatusOr Create( std::unique_ptr ifrt_ir_program, diff --git a/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc b/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc index 7e8612f830303f..deddd328048a7d 100644 --- a/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc +++ b/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/python/ifrt/ir/program_interpreter.h" +#include #include #include #include @@ -24,18 +25,21 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/functional/bind_front.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/Support/DebugStringHelper.h" @@ -46,7 +50,7 @@ limitations under the License. #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/ifrt/ir/compiled_ifrt_ir_program.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" #include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/ifrt_ops.h" @@ -74,12 +78,20 @@ using ExecuteResult = ::xla::ifrt::LoadedExecutable::ExecuteResult; namespace { +// Opaque handle that represents an array. Zero is reserved for null. 
+using ArrayHandle = uintptr_t; + // Array with additional metadata (e.g., if it can be donated). struct ArrayState { ArrayRef array; bool can_be_donated; }; +// Assigns a unique handle to the given MLIR value. +ArrayHandle ToArrayHandle(mlir::Value value) { + return reinterpret_cast(value.getAsOpaquePointer()); +} + // Returns an xla::ifrt::Sharding for the given IFRT array type. absl::StatusOr GetSharding( xla::ifrt::IfrtArrayType array_type, xla::ifrt::Client* client, @@ -110,65 +122,23 @@ std::string PrettyPrintGeneric(mlir::Operation* op) { GetPrettyLocation(op->getLoc())); } -// Populates the cache storing a Sharding for each IfrtArrayType. -// -// This cache exists to avoid traversing and creating large device lists at -// execution time. -// -// Note that the cache is only populated for array types returned by CopyArrays -// and RemapArrays ops because they are the only ops that need shardings. -absl::StatusOr> -PopulateShardingCache(mlir::func::FuncOp main_func, xla::ifrt::Client* client, - const xla::ifrt::DeviceListRef& devices) { - llvm::DenseMap - array_type_to_sharding; - for (const mlir::Operation& op : main_func.getOps()) { - if (auto copy_arrays_op = llvm::dyn_cast(&op); - copy_arrays_op != nullptr) { - for (const auto [idx, output] : - llvm::enumerate(copy_arrays_op.getOutputs())) { - const auto array_type = - llvm::cast(output.getType()); - TF_RET_CHECK(array_type != nullptr) - << "Output array #" << idx << " is not of type `IfrtArrayType`. 
" - << PrettyPrintGeneric(copy_arrays_op); - if (array_type_to_sharding.find(array_type) == - array_type_to_sharding.end()) { - TF_ASSIGN_OR_RETURN(auto sharding, - GetSharding(array_type, client, devices)); - array_type_to_sharding[array_type] = std::move(sharding); - } - } - } else if (auto remap_op = llvm::dyn_cast(&op); - remap_op != nullptr) { - for (const auto [idx, output] : llvm::enumerate(remap_op.getOutputs())) { - const auto array_type = - llvm::cast(output.getType()); - TF_RET_CHECK(array_type != nullptr) - << "Output array #" << idx << " is not of type `IfrtArrayType`. " - << PrettyPrintGeneric(remap_op); - if (array_type_to_sharding.find(array_type) == - array_type_to_sharding.end()) { - TF_ASSIGN_OR_RETURN(auto sharding, - GetSharding(array_type, client, devices)); - array_type_to_sharding[array_type] = std::move(sharding); - } - } - } - } - return array_type_to_sharding; -} - } // namespace struct Environment { - // Associates array with an MLIR value. - void AssociateArray(mlir::Value value, ArrayState array) { - CHECK(value_to_array.try_emplace(value, array).second); + // Associates array with an opaque handle. + void AssociateArray(ArrayHandle handle, ArrayState array) { + CHECK(handle_to_array.try_emplace(handle, array).second); } - // Map from MLIR value to IFRT array corresponding to the value. - llvm::DenseMap value_to_array; + // IFRT client for execution. + xla::ifrt::Client* client; + // Name of the program. + std::string program_name; + // Set of donated program arguments, which can be deleted after their last + // use. Entries are removed upon deletion or if they are aliased. + absl::flat_hash_set deletable_program_arguments; + // Map from an opaque handle to IFRT array corresponding to the value. + absl::flat_hash_map handle_to_array; // Outputs of the program. std::vector outputs; // `ExecuteOptions.fill_status` passed to Execute(). 
@@ -179,213 +149,401 @@ struct Environment { }; absl::StatusOr> ProgramInterpreter::Create( - xla::ifrt::Client* client, std::shared_ptr program, + xla::ifrt::Client* client, absl::string_view program_name, + mlir::ModuleOp mlir_module, + std::shared_ptr atom_program_executables, xla::ifrt::DeviceListRef devices) { - mlir::func::FuncOp main_func = - xla::ifrt::GetMainFunction(program->program->mlir_module); + mlir::func::FuncOp main_func = xla::ifrt::GetMainFunction(mlir_module); if (!main_func->hasAttr(xla::ifrt::kIfrtFunctionAttrName)) { - return absl::InvalidArgumentError(absl::StrCat( - "`main` function of IFRT IR program: ", program->program_name, - " is not an IFRT function.")); + return absl::InvalidArgumentError( + absl::StrCat("`main` function of IFRT IR program: ", program_name, + " is not an IFRT function.")); } - TF_ASSIGN_OR_RETURN(auto array_type_to_sharding, - PopulateShardingCache(main_func, client, devices)); return std::unique_ptr(new ProgramInterpreter( - client, std::move(program), std::move(devices), mlir::Liveness(main_func), - std::move(array_type_to_sharding))); + client, program_name, mlir_module, std::move(atom_program_executables), + std::move(devices), mlir::Liveness(main_func))); } -absl::StatusOr ProgramInterpreter::Execute( - absl::Span arrays, const ExecuteOptions& options, - std::optional devices) { - TraceMe traceme([&]() { - return TraceMeEncode("DispatchProgram", - { - {"ifrt_ir_program", program_->program_name}, - }); - }); - VLOG(2) << "Started interpreting program: " << program_->program_name; - mlir::func::FuncOp main_func = - xla::ifrt::GetMainFunction(program_->program->mlir_module); - if (arrays.size() != main_func.getNumArguments()) { - return absl::InvalidArgumentError(absl::StrCat( - "`main` function of IFRT IR program: ", program_->program_name, - " invoked with ", arrays.size(), " arguments, but it expects ", - main_func.getNumArguments(), " arguments.")); - } +namespace { + +struct ProgramInterpreterState { + 
xla::ifrt::Client* client; + std::string program_name; - for (const auto& [idx, array] : llvm::enumerate(arrays)) { - if (array->IsDeleted()) { + std::vector input_handles; + absl::flat_hash_set donated_input_indices; + + std::vector> op_fns; + + absl::StatusOr Run( + absl::Span arrays, + const xla::ifrt::LoadedExecutable::ExecuteOptions& options, + std::optional devices) const { + TraceMe traceme([&]() { + return TraceMeEncode("DispatchProgram", + {{"ifrt_ir_program", program_name}}); + }); + VLOG(2) << "Started interpreting program: " << program_name; + + if (arrays.size() != input_handles.size()) { return absl::InvalidArgumentError(absl::StrCat( - "Input array #", idx, " of program ", program_->program_name, - " has already been deleted or donated.")); + "`main` function of IFRT IR program: ", program_name, + " invoked with ", arrays.size(), " arguments, but it expects ", + input_handles.size(), " arguments.")); } - } - Environment env; - env.fill_status = options.fill_status; + for (int idx = 0; idx < arrays.size(); ++idx) { + const xla::ifrt::ArrayRef& array = arrays[idx]; + if (array->IsDeleted()) { + return absl::InvalidArgumentError( + absl::StrCat("Input array #", idx, " of program ", program_name, + " has already been deleted or donated.")); + } + } + + Environment env; + env.client = client; + env.fill_status = options.fill_status; + for (int idx = 0; idx < input_handles.size(); ++idx) { + // Add to the environment the arrays that are used. + bool is_donated = donated_input_indices.contains(idx) && + !options.non_donatable_input_indices.contains(idx); + const ArrayHandle handle = input_handles[idx]; + if (handle != 0) { + env.AssociateArray(handle, ArrayState{ + /*array=*/arrays[idx], + /*can_be_donated=*/is_donated, + }); + if (is_donated) { + env.deletable_program_arguments.insert(handle); + } + } else if (is_donated) { + // If the argument is donated but not used, it can be deleted. 
+ arrays[idx]->Delete(); + } + } + + for (const auto& op_fn : op_fns) { + TF_RETURN_IF_ERROR(op_fn(env)); + } + + VLOG(2) << "Finished interpreting program: " << program_name; + ExecuteResult result; + if (env.fill_status) { + result.status = + tsl::JoinFutures(absl::MakeSpan(env.leaf_call_op_futures)); + } + result.outputs = std::move(env.outputs); + return result; + }; +}; + +} // namespace + +absl::StatusOr +ProgramInterpreter::BuildExecuteFn() { + ProgramInterpreterState state; + state.client = client_; + state.program_name = program_name_; + + mlir::func::FuncOp main_func = xla::ifrt::GetMainFunction(mlir_module_); + for (const auto [idx, arg] : llvm::enumerate(main_func.getArguments())) { // Add to the environment the arrays that are used. - bool is_donated = main_func.getArgAttr( - idx, xla::ifrt::kIfrtDonatedArgAttrName) != nullptr && - !options.non_donatable_input_indices.contains(idx); - if (!arg.use_empty()) { - env.AssociateArray(arg, ArrayState{/*array=*/arrays[idx], - /*can_be_donated=*/is_donated}); - if (is_donated) { - deletable_program_arguments_.insert(arg); - } - } else if (is_donated) { - // If the argument is donated but not used, it can be deleted. - arrays[idx]->Delete(); + const ArrayHandle handle = arg.use_empty() ? 0 : ToArrayHandle(arg); + state.input_handles.push_back(handle); + if (main_func.getArgAttr(idx, xla::ifrt::kIfrtDonatedArgAttrName) != + nullptr) { + state.donated_input_indices.insert(idx); } } - // Walk ops one-by-one in program order, and dispatch atom program and - // copy arrays. + // Walk ops one-by-one in program order and create functions that execute each + // op on a given environment. 
for (mlir::Operation& op : main_func.getOps()) { - auto exec_op_status = - llvm::TypeSwitch(op) + auto op_fn = + llvm::TypeSwitch>(op) .Case( - [&](const auto& op) { return ExecuteOp(op, env); }) - .Default([&](const auto& op) { + [this](const auto& op) { return HandleOp(op); }) + .Default([](const mlir::Operation& op) { return absl::InvalidArgumentError(absl::StrCat( "Interpreter found unexpected op: ", mlir::debugString(op))); }); - if (!exec_op_status.ok()) { - tsl::errors::AppendToMessage(&exec_op_status, PrettyPrint(&op)); - return exec_op_status; + if (!op_fn.ok()) { + absl::Status status = op_fn.status(); + tsl::errors::AppendToMessage(&status, PrettyPrint(&op)); + return status; } + state.op_fns.push_back( + [op_fn = *std::move(op_fn), + pretty_print = PrettyPrint(&op)](Environment& env) -> absl::Status { + absl::Status status = op_fn(env); + tsl::errors::AppendToMessage(&status, pretty_print); + return status; + }); } - VLOG(2) << "Finished interpreting program: " << program_->program_name; - ExecuteResult result; - if (env.fill_status) { - result.status = tsl::JoinFutures(absl::MakeSpan(env.leaf_call_op_futures)); - } - result.outputs = std::move(env.outputs); - return result; + return absl::bind_front(&ProgramInterpreterState::Run, std::move(state)); } -absl::Status ProgramInterpreter::ExecuteOp( - xla::ifrt::CallLoadedExecutableOp call_loaded_op, Environment& env) { +namespace { + +struct CallLoadedExecutableOpState { + std::string pretty_print; + std::string atom_program_name; + + std::vector input_handles; + absl::flat_hash_set donated_arg_idxs; + absl::flat_hash_set dead_inputs; + + xla::ifrt::LoadedExecutable::ExecuteOptions execute_options; + std::shared_ptr executable; + + std::vector output_handles; + bool is_leaf_op; + + absl::Status Run(Environment& env) const { + TraceMe traceme([&]() { + return TraceMeEncode("DispatchLoadedExecutableOp", + { + {"ifrt_ir_program", env.program_name}, + {"atom_program", atom_program_name}, + }); + }); + VLOG(3) 
<< pretty_print; + + xla::ifrt::LoadedExecutable::ExecuteOptions options = execute_options; + options.fill_status = env.fill_status; + + // Get the inputs of the loaded executable. + std::vector inputs; + std::vector arrays_to_remove; + for (int idx = 0; idx < input_handles.size(); ++idx) { + const ArrayHandle handle = input_handles[idx]; + + auto array_it = env.handle_to_array.find(handle); + TF_RET_CHECK(array_it != env.handle_to_array.end()) + << "Input array #" << idx << " not found. " << pretty_print; + if (array_it->second.array->IsDeleted()) { + // We explicitly check here for deletion in order to provide a more + // informative error message. + return absl::InvalidArgumentError(absl::StrCat( + "Input array #", idx, "` has already been deleted or donated. ", + pretty_print)); + } + inputs.push_back(array_it->second.array); + + bool is_donated = donated_arg_idxs.contains(idx); + if (is_donated && !array_it->second.can_be_donated) { + VLOG(2) << "Atom program donates input #" << idx + << ", but it has not been donated to the IFRT IR program. " + "Input will not be donated. \n" + << pretty_print; + is_donated = false; + } + if (is_donated || dead_inputs.contains(handle)) { + arrays_to_remove.push_back(handle); + } + if (!is_donated) { + options.non_donatable_input_indices.insert(idx); + } + } + + TF_ASSIGN_OR_RETURN(xla::ifrt::LoadedExecutable::ExecuteResult result, + executable->Execute(absl::MakeSpan(inputs), options, + /*devices=*/std::nullopt)); + TF_RET_CHECK(result.outputs.size() == output_handles.size()) + << "Got " << result.outputs.size() << " results, but atom program has " + << output_handles.size() << ". " << pretty_print; + + // Remove the arrays from the environment after the inputs vector is + // created. This is because in situations such as `ifrt.Call(%0, %0)` the + // liveness analysis will return that %0 is dead, but it's used for the + // second argument. 
+ for (const auto handle : arrays_to_remove) { + if (env.deletable_program_arguments.erase(handle)) { + // Explicitly delete donated program arguments that are not used later. + env.handle_to_array[handle].array->Delete(); + } + env.handle_to_array.erase(handle); + } + + for (int i = 0; i < output_handles.size(); ++i) { + const ArrayHandle handle = output_handles[i]; + if (handle != 0) { + // The output array is kept only if it used later. This can happen if an + // executable has multiple output arrays, but only some of them are + // used. + env.AssociateArray(handle, ArrayState{ + /*array=*/std::move(result.outputs[i]), + /*can_be_donated=*/true, + }); + } + } + if (is_leaf_op && env.fill_status) { + env.leaf_call_op_futures.push_back(std::move(result.status)); + } + return absl::OkStatus(); + } +}; + +} // namespace + +absl::StatusOr ProgramInterpreter::HandleOp( + xla::ifrt::CallLoadedExecutableOp call_loaded_op) { + CallLoadedExecutableOpState state; + state.pretty_print = PrettyPrint(call_loaded_op); + xla::ifrt::LoadedExecutableOp loaded_exec_op = call_loaded_op.getCalleeOp(symbol_table_); - std::string atom_program_name = loaded_exec_op.getSymName().str(); - TraceMe traceme([&]() { - return TraceMeEncode("DispatchLoadedExecutableOp", - { - {"ifrt_ir_program", program_->program_name}, - {"atom_program", atom_program_name}, - }); - }); - std::string op_name = call_loaded_op->getName().getStringRef().str(); - VLOG(3) << PrettyPrint(call_loaded_op); + state.atom_program_name = loaded_exec_op.getSymName().str(); + // Get the loaded executable for the atom program. - auto exec_it = program_->atom_program_executables->find(atom_program_name); - TF_RET_CHECK(exec_it != program_->atom_program_executables->end()) - << "Could not find executable. " << PrettyPrint(call_loaded_op); + auto exec_it = atom_program_executables_->find(state.atom_program_name); + TF_RET_CHECK(exec_it != atom_program_executables_->end()) + << "Could not find executable. 
" << state.pretty_print; + state.executable = exec_it->second; - absl::flat_hash_set donated_arg_idxs( - call_loaded_op.getDonatedInputIndices().begin(), - call_loaded_op.getDonatedInputIndices().end()); + state.donated_arg_idxs.insert(call_loaded_op.getDonatedInputIndices().begin(), + call_loaded_op.getDonatedInputIndices().end()); for (const auto& io_alias : call_loaded_op.getIoAliases().getAsRange()) { // Insert the aliased input to the set. - donated_arg_idxs.insert(io_alias.asArrayRef()[0]); + state.donated_arg_idxs.insert(io_alias.asArrayRef()[0]); } - // Get the inputs of the loaded executable. - std::vector inputs; - xla::ifrt::LoadedExecutable::ExecuteOptions execute_options; - execute_options.fill_status = env.fill_status; - llvm::DenseSet array_values_to_gc_from_env; - for (const auto [idx, input] : llvm::enumerate(call_loaded_op.getInputs())) { - auto array_it = env.value_to_array.find(input); - TF_RET_CHECK(array_it != env.value_to_array.end()) - << "Input array #" << idx << " not found. " - << PrettyPrint(call_loaded_op); - if (array_it->second.array->IsDeleted()) { - // We explicitly check here for deletion in order to provide a more - // informative error message. - return absl::InvalidArgumentError(absl::StrCat( - "Input array #", idx, "` has already been deleted or donated. ", - PrettyPrint(call_loaded_op))); - } - inputs.push_back(array_it->second.array); - - bool is_donated = donated_arg_idxs.contains(idx); - if (is_donated && !array_it->second.can_be_donated) { - VLOG(2) << "Atom program donates input #" << idx - << ", but it has not been donated to the IFRT IR program. " - "Input will not be donated. 
\n" - << PrettyPrint(call_loaded_op); - is_donated = false; - } - if (is_donated || liveness_.isDeadAfter(input, call_loaded_op)) { - array_values_to_gc_from_env.insert(input); - } - if (!is_donated) { - execute_options.non_donatable_input_indices.insert(idx); + for (const auto input : call_loaded_op.getInputs()) { + state.input_handles.push_back(ToArrayHandle(input)); + if (liveness_.isDeadAfter(input, call_loaded_op)) { + state.dead_inputs.insert(ToArrayHandle(input)); } } - TF_ASSIGN_OR_RETURN( - xla::ifrt::LoadedExecutable::ExecuteResult result, - exec_it->second->Execute(absl::MakeSpan(inputs), execute_options, - /*devices=*/std::nullopt)); - TF_RET_CHECK(result.outputs.size() == call_loaded_op.getOutputs().size()) - << "Got " << result.outputs.size() << " results, but atom program has " - << call_loaded_op.getOutputs().size() << ". " - << PrettyPrint(call_loaded_op); - - // Remove the arrays from the environment after the inputs vector is created. - // This is because in situations such as `ifrt.Call(%0, %0)` the liveness - // analysis will return that %0 is dead, but it's used for the second - // argument. - for (const auto& array_value : array_values_to_gc_from_env) { - if (deletable_program_arguments_.erase(array_value)) { - // Explicitly delete donated program arguments that are not used later. - env.value_to_array[array_value].array->Delete(); - } - env.value_to_array.erase(array_value); - } + state.is_leaf_op = true; + for (const auto output : call_loaded_op.getOutputs()) { + const ArrayHandle handle = output.use_empty() ? 0 : ToArrayHandle(output); + state.output_handles.push_back(handle); - bool is_leaf_op = true; - for (const auto [output_array, output] : - llvm::zip(result.outputs, call_loaded_op.getOutputs())) { - if (!output.use_empty()) { - // The output array is kept only if it used later. This can happen if - // an executable has multiple output arrays, but only some of them are - // used. 
- env.AssociateArray(output, ArrayState{/*array=*/std::move(output_array), - /*can_be_donated=*/true}); - } - if (is_leaf_op) { + if (state.is_leaf_op) { for (mlir::OpOperand& use : output.getUses()) { // An ifrt.CallOp is not a leaf if any of its outputs are not returned. if (llvm::dyn_cast(use.getOwner()) == nullptr) { - is_leaf_op = false; + state.is_leaf_op = false; break; } } } } - if (is_leaf_op && env.fill_status) { - env.leaf_call_op_futures.push_back(std::move(result.status)); - } - return absl::OkStatus(); + return absl::bind_front(&CallLoadedExecutableOpState::Run, std::move(state)); } -absl::Status ProgramInterpreter::ExecuteOp(xla::ifrt::RemapArraysOp remap_op, - Environment& env) { - TraceMe traceme([&]() { - return TraceMeEncode("DispatchRemapArraysOp", - {{"ifrt_ir_program", program_->program_name}}); - }); - std::string op_name = remap_op->getName().getStringRef().str(); - VLOG(3) << PrettyPrint(remap_op); +namespace { + +struct RemapArraysOpState { + std::string pretty_print; + + xla::ifrt::RemapPlan remap_plan; + std::vector input_handles; + absl::flat_hash_set dead_inputs; + bool remap_is_donated; + + std::vector output_handles; + + absl::Status Run(Environment& env) const { + TraceMe traceme([&]() { + return TraceMeEncode("DispatchRemapArraysOp", + {{"ifrt_ir_program", env.program_name}}); + }); + VLOG(3) << pretty_print; + + std::vector inputs; + inputs.reserve(remap_plan.input_specs.size()); + + std::optional is_donated; + std::vector arrays_to_remove; + + for (int idx = 0; idx < input_handles.size(); ++idx) { + const ArrayHandle handle = input_handles[idx]; + + auto array_it = env.handle_to_array.find(handle); + TF_RET_CHECK(array_it != env.handle_to_array.end()) + << "Input array #" << idx << " not found. " << pretty_print; + if (array_it->second.array->IsDeleted()) { + // We explicitly check here for deletion in order to provide a more + // informative error message. 
+ return absl::InvalidArgumentError(absl::StrCat( + "Input array #", idx, "` has already been deleted or donated. ", + pretty_print)); + } + inputs.push_back(array_it->second.array); + + // The default buffer donation semantic is finalized at compilation time. + // Users can override the donation semantic at runtime. In the meantime, + // the IFRT client RemapArrays API requires all input arrays have the same + // donation semantic. + if (!is_donated.has_value()) { + is_donated = remap_is_donated && array_it->second.can_be_donated; + } + if (*is_donated && !array_it->second.can_be_donated) { + return absl::InvalidArgumentError(absl::StrCat( + "Donation semantic must be consistent across all input arrays of " + "RemapArraysOp. Input array #", + idx, + " cannot be donated, but previous input arrays can be donated. " + "It's likely due to a MPMD program argument is marked as " + "non-donatable. ", + pretty_print)); + } + if (*is_donated || dead_inputs.contains(handle)) { + arrays_to_remove.push_back(handle); + } + } + TF_RET_CHECK(is_donated.has_value()) + << "Unable to determine the donation semantic of the remap op. The " + "remap op has no inputs. " + << pretty_print; + + // Apply the remap arrays operation. + xla::ifrt::ArrayCopySemantics copy_semantics = + *is_donated ? xla::ifrt::ArrayCopySemantics::kDonateInput + : xla::ifrt::ArrayCopySemantics::kReuseInput; + TF_ASSIGN_OR_RETURN(auto out_arrays, env.client->RemapArrays( + remap_plan, absl::MakeSpan(inputs), + copy_semantics)); + + for (const auto handle : arrays_to_remove) { + // Donated remapped arrays are pro-actively deleted, and aliased arrays + // cannot be deleted later. Thus, remove the arrays from the deletable + // program arguments set. + env.deletable_program_arguments.erase(handle); + env.handle_to_array.erase(handle); + } + + // Store the result arrays in the environment. 
+ TF_RET_CHECK(out_arrays.size() == remap_plan.output_specs.size()) + << "Got " << out_arrays.size() << " results, but op has " + << remap_plan.output_specs.size() << ". " << pretty_print; + for (int i = 0; i < output_handles.size(); ++i) { + const ArrayHandle handle = output_handles[i]; + if (handle != 0) { + env.AssociateArray(handle, ArrayState{ + /*array=*/std::move(out_arrays[i]), + /*can_be_donated=*/true, + }); + } + } + + return absl::OkStatus(); + } +}; + +} // namespace + +absl::StatusOr ProgramInterpreter::HandleOp( + xla::ifrt::RemapArraysOp remap_op) { + RemapArraysOpState state; + state.pretty_print = PrettyPrint(remap_op); // Construct the mappings of the remap plan. auto mappings = @@ -410,54 +568,28 @@ absl::Status ProgramInterpreter::ExecuteOp(xla::ifrt::RemapArraysOp remap_op, } }; - std::vector inputs; - std::vector input_specs; - inputs.reserve(remap_op.getInputs().size()); - input_specs.reserve(remap_op.getInputs().size()); // Get the input specs of the remap plan and the input arrays. - llvm::DenseSet array_values_to_gc_from_env; - std::optional is_donated; + std::vector input_specs; + input_specs.reserve(remap_op.getOutputs().size()); for (const auto [idx, input] : llvm::enumerate(remap_op.getInputs())) { - auto array_it = env.value_to_array.find(input); - TF_RET_CHECK(array_it != env.value_to_array.end()) - << "Input array #" << idx << " not found. " << PrettyPrint(remap_op); - if (array_it->second.array->IsDeleted()) { - // We explicitly check here for deletion in order to provide a more - // informative error message. - return absl::InvalidArgumentError(absl::StrCat( - "Input array #", idx, " has already been deleted or donated. 
", - PrettyPrint(remap_op))); - } - inputs.push_back(array_it->second.array); + state.input_handles.push_back(ToArrayHandle(input)); + + const auto array_type = + llvm::cast(input.getType()); + TF_ASSIGN_OR_RETURN( + xla::ifrt::DType dtype, + xla::ifrt::ToIfrtDType(array_type.getShape().getElementType())); + TF_ASSIGN_OR_RETURN(xla::ifrt::ShardingRef sharding, + GetSharding(array_type, client_, devices_)); input_specs.push_back(xla::ifrt::ArraySpec{ - /*dtype=*/array_it->second.array->dtype(), - /*shape=*/array_it->second.array->shape(), - /*sharding=*/array_it->second.array->shared_ptr_sharding()}); - - // The default buffer donation semantic is finalized at compilation time. - // Users can override the donation semantic at runtime. In the meantime, the - // IFRT client RemapArrays API requires all input arrays have the same - // donation semantic. - if (!is_donated.has_value()) { - is_donated = remap_op.getDonated() && array_it->second.can_be_donated; - } - if (*is_donated && !array_it->second.can_be_donated) { - return absl::InvalidArgumentError(absl::StrCat( - "Donation semantic must be consistent across all input arrays of " - "RemapArraysOp. Input array #", - idx, - " cannot be donated, but previous input arrays can be donated. It's " - "likely due to a MPMD program argument is marked as non-donatable. ", - PrettyPrint(remap_op))); - } - if (*is_donated || liveness_.isDeadAfter(input, remap_op)) { - array_values_to_gc_from_env.insert(input); + /*dtype=*/dtype, + /*shape=*/xla::ifrt::Shape(array_type.getShape().getShape()), + /*sharding=*/std::move(sharding)}); + + if (liveness_.isDeadAfter(input, remap_op)) { + state.dead_inputs.insert(ToArrayHandle(input)); } } - TF_RET_CHECK(is_donated.has_value()) - << "Unable to determine the donation semantic of the remap op. The remap " - "op has no inputs. " - << PrettyPrint(remap_op); // Get the output specs of the remap plan. 
std::vector output_specs; @@ -468,153 +600,196 @@ absl::Status ProgramInterpreter::ExecuteOp(xla::ifrt::RemapArraysOp remap_op, TF_ASSIGN_OR_RETURN( xla::ifrt::DType dtype, xla::ifrt::ToIfrtDType(array_type.getShape().getElementType())); + TF_ASSIGN_OR_RETURN(xla::ifrt::ShardingRef sharding, + GetSharding(array_type, client_, devices_)); output_specs.push_back(xla::ifrt::ArraySpec{ /*dtype=*/dtype, /*shape=*/xla::ifrt::Shape(array_type.getShape().getShape()), - /*sharding=*/array_type_to_sharding_.at(array_type)}); + /*sharding=*/std::move(sharding)}); } - // Apply the remap arrays operation. - xla::ifrt::ArrayCopySemantics copy_semantics = - *is_donated ? xla::ifrt::ArrayCopySemantics::kDonateInput - : xla::ifrt::ArrayCopySemantics::kReuseInput; - TF_ASSIGN_OR_RETURN( - auto out_arrays, - client_->RemapArrays({ - /*input_specs=*/std::move(input_specs), - /*output_specs=*/std::move(output_specs), - /*mappings=*/std::move(mappings), - }, - absl::MakeSpan(inputs), copy_semantics)); - - for (const auto& array_value : array_values_to_gc_from_env) { - // Donated remapped arrays are pro-actively deleted, and aliased arrays - // cannot be deleted later. Thus, remove the arrays from the deletable - // program arguments set. - deletable_program_arguments_.erase(array_value); - env.value_to_array.erase(array_value); - } + state.remap_plan = xla::ifrt::RemapPlan{ + /*input_specs=*/std::move(input_specs), + /*output_specs=*/std::move(output_specs), + /*mappings=*/std::move(mappings), + }; + state.remap_is_donated = remap_op.getDonated(); - // Store the result arrays in the environment. - TF_RET_CHECK(out_arrays.size() == remap_op.getOutputs().size()) - << "Got " << out_arrays.size() << " results, but op has " - << remap_op.getOutputs().size() << ". 
" << PrettyPrint(remap_op); - for (const auto [output_array, output] : - llvm::zip(out_arrays, remap_op.getOutputs())) { - if (!output.use_empty()) { - env.AssociateArray(output, ArrayState{/*array=*/std::move(output_array), - /*can_be_donated=*/true}); - } + for (const auto output : remap_op.getOutputs()) { + const ArrayHandle handle = output.use_empty() ? 0 : ToArrayHandle(output); + state.output_handles.push_back(handle); } - return absl::OkStatus(); + + return absl::bind_front(&RemapArraysOpState::Run, std::move(state)); } -absl::Status ProgramInterpreter::ExecuteOp( - xla::ifrt::CopyArraysOp copy_arrays_op, Environment& env) { - TraceMe traceme([&]() { - return TraceMeEncode("DispatchCopyArraysOp", - {{"ifrt_ir_program", program_->program_name}}); - }); - std::string op_name = copy_arrays_op->getName().getStringRef().str(); - VLOG(3) << PrettyPrint(copy_arrays_op); - - std::vector inputs; - inputs.reserve(copy_arrays_op.getInputs().size()); - llvm::DenseSet array_values_to_gc_from_env; - std::optional is_donated; - for (const auto [idx, input] : llvm::enumerate(copy_arrays_op.getInputs())) { - auto array_it = env.value_to_array.find(input); - TF_RET_CHECK(array_it != env.value_to_array.end()) - << "Input array #" << idx << " not found. " - << PrettyPrint(copy_arrays_op); - if (array_it->second.array->IsDeleted()) { - // We explicitly check here for deletion in order to provide a more - // informative error message. - return absl::InvalidArgumentError(absl::StrCat( - "Input array #", idx, " has already been deleted or donated. 
", - PrettyPrint(copy_arrays_op))); +namespace { + +struct CopyArraysOpState { + std::string pretty_print; + + std::vector input_handles; + absl::flat_hash_set dead_inputs; + bool copy_is_donated; + + std::vector output_handles; + xla::ifrt::ShardingRef new_sharding; + + absl::Status Run(Environment& env) const { + TraceMe traceme([&]() { + return TraceMeEncode("DispatchCopyArraysOp", + {{"ifrt_ir_program", env.program_name}}); + }); + VLOG(3) << pretty_print; + + std::vector inputs; + inputs.reserve(input_handles.size()); + + std::optional is_donated; + std::vector arrays_to_remove; + + for (int idx = 0; idx < input_handles.size(); ++idx) { + const ArrayHandle handle = input_handles[idx]; + + auto array_it = env.handle_to_array.find(handle); + TF_RET_CHECK(array_it != env.handle_to_array.end()) + << "Input array #" << idx << " not found. " << pretty_print; + if (array_it->second.array->IsDeleted()) { + // We explicitly check here for deletion in order to provide a more + // informative error message. + return absl::InvalidArgumentError(absl::StrCat( + "Input array #", idx, " has already been deleted or donated. ", + pretty_print)); + } + inputs.push_back(array_it->second.array); + + // The default buffer donation semantic is finalized at compilation time. + // Users can override the donation semantic at runtime. In the meantime, + // the IFRT client CopyArrays API requires all input arrays have the same + // donation semantic. + if (!is_donated.has_value()) { + is_donated = copy_is_donated && array_it->second.can_be_donated; + } + if (*is_donated && !array_it->second.can_be_donated) { + return absl::InvalidArgumentError(absl::StrCat( + "Donation semantic must be consistent across all input arrays of " + "CopyArraysOp. Input array #", + idx, + " cannot be donated, but previous input arrays can be donated. " + "It's likely due to a MPMD program argument is marked as " + "non-donatable. 
", + pretty_print)); + } + if (*is_donated || dead_inputs.contains(handle)) { + arrays_to_remove.push_back(handle); + } } - inputs.push_back(array_it->second.array); - - // The default buffer donation semantic is finalized at compilation time. - // Users can override the donation semantic at runtime. In the meantime, the - // IFRT client CopyArrays API requires all input arrays have the same - // donation semantic. - if (!is_donated.has_value()) { - is_donated = - copy_arrays_op.getDonated() && array_it->second.can_be_donated; + TF_RET_CHECK(is_donated.has_value()) + << "Unable to determine the donation semantic of the copy arrays op. " + "The copy arrays op has no inputs. " + << pretty_print; + + auto array_copy_semantics = + *is_donated ? xla::ifrt::ArrayCopySemantics::kDonateInput + : xla::ifrt::ArrayCopySemantics::kAlwaysCopy; + // It is safe to get the devices and memory kind from the first output + // because all outputs use the same devices and have the same memory kind. + TF_ASSIGN_OR_RETURN(auto copied_arrays, + env.client->CopyArrays( + absl::MakeSpan(inputs), new_sharding->devices(), + new_sharding->memory_kind(), array_copy_semantics)); + + for (const auto handle : arrays_to_remove) { + if (env.deletable_program_arguments.erase(handle)) { + // Explicitly delete donated program arguments that are not used later. + env.handle_to_array[handle].array->Delete(); + } + env.handle_to_array.erase(handle); } - if (*is_donated && !array_it->second.can_be_donated) { - return absl::InvalidArgumentError(absl::StrCat( - "Donation semantic must be consistent across all input arrays of " - "CopyArraysOp. Input array #", - idx, - " cannot be donated, but previous input arrays can be donated. It's " - "likely due to a MPMD program argument is marked as non-donatable. ", - PrettyPrint(copy_arrays_op))); + + TF_RET_CHECK(copied_arrays.size() == inputs.size()) + << "Got " << copied_arrays.size() << " results, but op has " + << inputs.size() << ". 
" << pretty_print; + for (int i = 0; i < output_handles.size(); ++i) { + const ArrayHandle handle = output_handles[i]; + if (handle != 0) { + env.AssociateArray(handle, ArrayState{ + /*array=*/std::move(copied_arrays[i]), + /*can_be_donated=*/true, + }); + } } - if (*is_donated || liveness_.isDeadAfter(input, copy_arrays_op)) { - array_values_to_gc_from_env.insert(input); + + return absl::OkStatus(); + } +}; + +} // namespace + +absl::StatusOr ProgramInterpreter::HandleOp( + xla::ifrt::CopyArraysOp copy_arrays_op) { + CopyArraysOpState state; + state.pretty_print = PrettyPrint(copy_arrays_op); + + for (const auto [idx, input] : llvm::enumerate(copy_arrays_op.getInputs())) { + state.input_handles.push_back(ToArrayHandle(input)); + if (liveness_.isDeadAfter(input, copy_arrays_op)) { + state.dead_inputs.insert(ToArrayHandle(input)); } } - TF_RET_CHECK(is_donated.has_value()) - << "Unable to determine the donation semantic of the copy arrays op. The " - "copy arrays op has no inputs. " - << PrettyPrint(copy_arrays_op); + state.copy_is_donated = copy_arrays_op.getDonated(); const auto out_array_type = llvm::cast( copy_arrays_op.getOutputs().front().getType()); TF_RET_CHECK(out_array_type != nullptr) << "Output array #0 is not of type `IfrtArrayType`. " - << PrettyPrint(copy_arrays_op); - auto new_sharding = array_type_to_sharding_.at(out_array_type); - auto array_copy_semantics = *is_donated - ? xla::ifrt::ArrayCopySemantics::kDonateInput - : xla::ifrt::ArrayCopySemantics::kAlwaysCopy; - // It is safe to get the devices and memory kind from the first output - // because all outputs use the same devices and have the same memory kind. 
- TF_ASSIGN_OR_RETURN( - auto copied_arrays, - client_->CopyArrays(absl::MakeSpan(inputs), new_sharding->devices(), - new_sharding->memory_kind(), array_copy_semantics)); - - for (const auto& array_value : array_values_to_gc_from_env) { - if (deletable_program_arguments_.erase(array_value)) { - // Explicitly delete donated program arguments that are not used later. - env.value_to_array[array_value].array->Delete(); - } - env.value_to_array.erase(array_value); + << state.pretty_print; + TF_ASSIGN_OR_RETURN(state.new_sharding, + GetSharding(out_array_type, client_, devices_)); + + for (const auto output : copy_arrays_op.getOutputs()) { + const ArrayHandle handle = output.use_empty() ? 0 : ToArrayHandle(output); + state.output_handles.push_back(handle); } - // Store the result arrays in the environment. - TF_RET_CHECK(copied_arrays.size() == copy_arrays_op.getOutputs().size()) - << "Got " << copied_arrays.size() << " results, but op has " - << copy_arrays_op.getOutputs().size() << ". " - << PrettyPrint(copy_arrays_op); - for (const auto [output_array, output] : - llvm::zip(copied_arrays, copy_arrays_op.getOutputs())) { - if (!output.use_empty()) { - env.AssociateArray(output, ArrayState{/*array=*/std::move(output_array), - /*can_be_donated=*/true}); + return absl::bind_front(&CopyArraysOpState::Run, std::move(state)); +} + +namespace { + +struct ReturnOpState { + std::string pretty_print; + std::vector output_handles; + + absl::Status Run(Environment& env) const { + VLOG(3) << "func.return of `main` function"; + env.outputs.reserve(output_handles.size()); + for (int idx = 0; idx < output_handles.size(); ++idx) { + auto array_it = env.handle_to_array.find(output_handles[idx]); + TF_RET_CHECK(array_it != env.handle_to_array.end()) + << "Input array #" << idx << " not found. 
" << pretty_print; + env.outputs.push_back(std::move(array_it->second.array)); } + env.handle_to_array.clear(); + return absl::OkStatus(); } - return absl::OkStatus(); -} +}; + +} // namespace + +absl::StatusOr ProgramInterpreter::HandleOp( + mlir::func::ReturnOp return_op) { + ReturnOpState state; + state.pretty_print = PrettyPrint(return_op); -absl::Status ProgramInterpreter::ExecuteOp(mlir::func::ReturnOp return_op, - Environment& env) { auto func_op = return_op->getParentOfType(); CHECK_EQ(func_op.getSymName().str(), "main"); - VLOG(3) << return_op->getName().getStringRef().str() << " of `main` function"; - env.outputs.reserve(return_op->getNumOperands()); + state.output_handles.reserve(return_op->getNumOperands()); for (const auto& [idx, result] : llvm::enumerate(return_op.getOperands())) { - auto array_it = env.value_to_array.find(result); - TF_RET_CHECK(array_it != env.value_to_array.end()) - << "Input array #" << idx << " not found. " << PrettyPrint(return_op); - env.outputs.push_back(std::move(array_it->second.array)); + state.output_handles.push_back(ToArrayHandle(result)); } - env.value_to_array.clear(); - return absl::OkStatus(); + + return absl::bind_front(&ReturnOpState::Run, std::move(state)); } std::string ProgramInterpreter::PrettyPrint(mlir::Operation* op) { diff --git a/third_party/xla/xla/python/ifrt/ir/program_interpreter.h b/third_party/xla/xla/python/ifrt/ir/program_interpreter.h index 3f8e8075404185..35158ac1305124 100644 --- a/third_party/xla/xla/python/ifrt/ir/program_interpreter.h +++ b/third_party/xla/xla/python/ifrt/ir/program_interpreter.h @@ -21,23 +21,22 @@ limitations under the License. 
#include #include +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/SymbolTable.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/executable.h" -#include "xla/python/ifrt/ir/compiled_ifrt_ir_program.h" -#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" #include "xla/python/ifrt/ir/ifrt_ops.h" -#include "xla/python/ifrt/sharding.h" namespace xla { namespace ifrt { @@ -46,59 +45,75 @@ namespace ifrt { struct Environment; // Interpreter for an IFRT IR program. +// +// The program interpreter is responsible for executing an IFRT IR program. The +// interpreter works in two stages. First, when `BuildExecuteFn` is called, it +// traverses the program and builds a function that can be invoked to execute +// the program, which happens only once during compilation. Second, the returned +// execute function can be called multiple times to interpret the IFRT IR +// program. +// +// This two-stage design has two primary purposes: +// +// 1. It allows us to leverage the static information available in the program +// as much as possible. For example, `RemapArraysOp` builds its remap plan +// during the first stage and the plan is reused for all executions. +// +// 2. It avoids running any LLVM/MLIR code during execution. This is +// particularly useful in environments where the use of LLVM/MLIR +// synchronization primitives may cause deadlocks, e.g., cooperatively +// scheduled fibers. 
class ProgramInterpreter { public: + using ExecuteFn = absl::AnyInvocable< + absl::StatusOr( + absl::Span arrays, + const xla::ifrt::LoadedExecutable::ExecuteOptions& options, + std::optional devices)>; + static absl::StatusOr> Create( - xla::ifrt::Client* client, std::shared_ptr program, + xla::ifrt::Client* client, absl::string_view program_name, + mlir::ModuleOp mlir_module, + std::shared_ptr atom_program_executables, xla::ifrt::DeviceListRef devices); - // Executes the IFRT IR program. - absl::StatusOr Execute( - absl::Span arrays, - const xla::ifrt::LoadedExecutable::ExecuteOptions& options, - std::optional devices); + absl::StatusOr BuildExecuteFn(); private: + using OpFn = absl::AnyInvocable; + ProgramInterpreter( - xla::ifrt::Client* client, std::shared_ptr program, - xla::ifrt::DeviceListRef devices, mlir::Liveness liveness, - llvm::DenseMap - array_type_to_sharding) + xla::ifrt::Client* client, absl::string_view program_name, + mlir::ModuleOp mlir_module, + std::shared_ptr atom_program_executables, + xla::ifrt::DeviceListRef devices, mlir::Liveness liveness) : client_(client), - program_(std::move(program)), + program_name_(program_name), + mlir_module_(mlir_module), + atom_program_executables_(std::move(atom_program_executables)), devices_(std::move(devices)), - liveness_(std::move(liveness)), - array_type_to_sharding_(std::move(array_type_to_sharding)) {} + liveness_(std::move(liveness)) {} - absl::Status ExecuteOp(xla::ifrt::CallLoadedExecutableOp call_loaded_op, - Environment& env); - absl::Status ExecuteOp(xla::ifrt::RemapArraysOp remap_op, Environment& env); - absl::Status ExecuteOp(xla::ifrt::CopyArraysOp copy_arrays_op, - Environment& env); - absl::Status ExecuteOp(mlir::func::ReturnOp return_op, Environment& env); + absl::StatusOr HandleOp( + xla::ifrt::CallLoadedExecutableOp call_loaded_op); + absl::StatusOr HandleOp(xla::ifrt::RemapArraysOp remap_op); + absl::StatusOr HandleOp(xla::ifrt::CopyArraysOp copy_arrays_op); + absl::StatusOr 
HandleOp(mlir::func::ReturnOp return_op); // Returns a pretty string representation of the op. std::string PrettyPrint(mlir::Operation* op); xla::ifrt::Client* client_; mlir::SymbolTableCollection symbol_table_; - std::shared_ptr program_; + std::string program_name_; + mlir::ModuleOp mlir_module_; + std::shared_ptr atom_program_executables_; // All the devices the program uses. xla::ifrt::DeviceListRef devices_; // Cached liveness analysis of the IFRT IR program. mlir::Liveness liveness_; - - // Mapping between IfrtArrayType and Sharding. This map is used to cache - // the Shardings at IFRT IR program compilation time in order to avoid - // overheads at execution time. - llvm::DenseMap - array_type_to_sharding_; - - // Set of donated program arguments, which can be deleted after their last - // use. Entries are removed upon deletion or if they are aliased. - llvm::DenseSet deletable_program_arguments_; }; } // namespace ifrt From 0de1b0f2ad88d1c14bb825102aa0d10f909a319b Mon Sep 17 00:00:00 2001 From: Maxim Ermilov Date: Mon, 8 Dec 2025 11:24:53 -0800 Subject: [PATCH 040/753] Only show warning about nvml symbol when relevant PiperOrigin-RevId: 841846078 --- third_party/xla/xla/stream_executor/cuda/cuda_executor.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 488e0f465f594a..cc487db3345103 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -1837,8 +1837,10 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) { info.cluster_uuid = fabric_info->cluster_uuid; info.clique_id = fabric_info->clique_id; } else { - LOG(WARNING) << "GPU interconnect information not available: " - << fabric_info.status(); + if (cc.IsAtLeastHopper() && p2p_link_count.ok() && *p2p_link_count) { + LOG(WARNING) << "GPU interconnect information not 
available: " + << fabric_info.status(); + } } desc.set_device_interconnect_info(info); } From d70ec822d2566d485e52c8f52cb0a4b75f75c775 Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Mon, 8 Dec 2025 11:35:17 -0800 Subject: [PATCH 041/753] Add `DCNTopology` and `EndpointAddresses` PiperOrigin-RevId: 841850899 --- third_party/xla/opensource_only.files | 1 + third_party/xla/xla/megascale/BUILD | 28 ++++++++ third_party/xla/xla/megascale/addresses.proto | 28 ++++++++ .../xla/xla/megascale/dcn_topology.proto | 64 +++++++++++++++++++ .../xla/xla/megascale/package_groups.bzl | 7 ++ 5 files changed, 128 insertions(+) create mode 100644 third_party/xla/xla/megascale/BUILD create mode 100644 third_party/xla/xla/megascale/addresses.proto create mode 100644 third_party/xla/xla/megascale/dcn_topology.proto create mode 100644 third_party/xla/xla/megascale/package_groups.bzl diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 4a78380bc9dd7d..888a0978aac8f0 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -1,6 +1,7 @@ tensorflow/compiler/xla/backends/cpu/nanort/package_groups.bzl: tensorflow/compiler/xla/backends/cpu/package_groups.bzl: tensorflow/compiler/xla/internal/package_groups.bzl: +tensorflow/compiler/xla/megascale/package_groups.bzl: tensorflow/compiler/xla/mlir_hlo/WORKSPACE: tensorflow/compiler/xla/package_groups.bzl: tensorflow/compiler/xla/pjrt/cpu/package_groups.bzl: diff --git a/third_party/xla/xla/megascale/BUILD b/third_party/xla/xla/megascale/BUILD new file mode 100644 index 00000000000000..4aa6b4e8f5498e --- /dev/null +++ b/third_party/xla/xla/megascale/BUILD @@ -0,0 +1,28 @@ +load("//xla/megascale:package_groups.bzl", "megascale_package_groups") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load( + "//xla/tsl/platform:build_config.bzl", + "tf_proto_library", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + 
default_visibility = internal_visibility([":internal"]), + licenses = ["notice"], +) + +megascale_package_groups() + +tf_proto_library( + name = "dcn_topology_proto", + srcs = ["dcn_topology.proto"], + create_grpc_library = True, + make_default_target_header_only = True, +) + +tf_proto_library( + name = "addresses_proto", + srcs = ["addresses.proto"], + create_grpc_library = True, + make_default_target_header_only = True, +) diff --git a/third_party/xla/xla/megascale/addresses.proto b/third_party/xla/xla/megascale/addresses.proto new file mode 100644 index 00000000000000..ad6611335bbc3c --- /dev/null +++ b/third_party/xla/xla/megascale/addresses.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +package xla.megascale.runtime; + +option java_multiple_files = true; +option java_outer_classname = "Runtime"; + +message HostNetworkAddress { + string address = 1; + string interface_name = 2; + // The host name used for debugging only, and is supplied by pathways or MXLA + // coordinator. Do not use this for creating connection to other peers, use + // the address above. + string host_name_for_debugging = 3; +} + +// NetworkAddressMapping provides mapping between a unique endpoint (slice_id, +// host_id) and the network address it is reachable at. +message NetworkAddressMapping { + int32 slice_id = 1; + int32 host_id = 2; + repeated HostNetworkAddress addresses = 3; +} + +// Holds the network address mapping of all endpoints (slice_id, host_id). 
+message EndpointAddresses { + repeated NetworkAddressMapping address_mappings = 1; +} diff --git a/third_party/xla/xla/megascale/dcn_topology.proto b/third_party/xla/xla/megascale/dcn_topology.proto new file mode 100644 index 00000000000000..de87573195e36b --- /dev/null +++ b/third_party/xla/xla/megascale/dcn_topology.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; + +package xla.megascale.runtime; + +option java_multiple_files = true; +option java_outer_classname = "Runtime"; + +message DCNTopology { + // SymmetricTree represents a simple network topology with symmetric + // splitting at each level. + message SymmetricTree { + // The length of branching_per_layer is the depth (number of distinct + // layers) of the network topology. The values give the branching factor at + // each layer. Index 0 holds the uppermost level in the topology. For + // example: a 24 slice topology, in three groups of two subgroups of four + // slices would be represented as: branching_per_layer = + // [3, 2, 4] slice_ids are not explicitly specified and are assumed to be + // contiguously assigned. i.e. slice_id = branching_per_layer[0] * 8 + + // branching_per_layer[1] * 4 + branching_per_layer[2] + repeated int32 branching_per_layer = 1; + } + + // Node recursively defines a fully specified tree. The tree is expected to + // be balanced but allowed to be asymmetric. + message TreeNode { + // Contiguous range of slices in half-open interval [slice_id_start, + // slice_id_end). The contiguous nature has no special signficance beyond + // compactly represent large number of slices. e.g. SliceRange{0, 10} and + // SliceRange{20, 30} all have the same connectivity between them. + message SliceRange { + int32 slice_id_start = 1; + // Ignored when slice_id_end <= slice_id_start. + int32 slice_id_end = 2; + } + + // Optional label for readability. + optional string label = 1; + + // We expect the Topology to be a balanced asymmetric tree. 
This implies + // that at any level we should either have nodes OR slice_ranges. + repeated TreeNode nodes = 2; + repeated SliceRange slice_ranges = 3; + + // Specifies the degree to which egress from this node to higher layers in + // topology is constrained. Valid range [0.0, 1.0]. 0.0 -> no + // constraint, 1.0 -> never use. When egress_constraint for a node is higher + // than other nodes with which it performs a reduction, it is assigned + // shards for reduction with less probability. This will result in fewer + // transfers out of these nodes to higher layers in topology. + optional float egress_constraint = 4; + + // Whether to perform the ring algorithm instead of the shuffle algorithm + // between the children. The ring order is the order of the children. + bool ring_transfers = 5; + } + + oneof representation { + // Simple representation of a symmetric hierarchical network. + SymmetricTree symmetric_tree = 1; + // Fully specified tree with no assumptions on symmetry and slice id + // mappings. + TreeNode tree = 2; + } +} diff --git a/third_party/xla/xla/megascale/package_groups.bzl b/third_party/xla/xla/megascale/package_groups.bzl new file mode 100644 index 00000000000000..9d3f8d701a735b --- /dev/null +++ b/third_party/xla/xla/megascale/package_groups.bzl @@ -0,0 +1,7 @@ +"""Megascale package_group definitions""" + +def megascale_package_groups(name = "megascale_package_groups"): + native.package_group( + name = "internal", + packages = ["//..."], + ) From cc327345d12c6e82a0e0c8345b66cfa30844ea16 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Mon, 8 Dec 2025 11:50:09 -0800 Subject: [PATCH 042/753] Optimize remap ops in IFRT IR by performing more work at compile time With two-stage program interpreter, we can now calculate `input_devices_for_output_map` and verify the remap plan at compile time without having to worry about their overheads. The former may improve the performance for runtime implementations that leverage the additional information.
PiperOrigin-RevId: 841856704 --- third_party/xla/xla/python/ifrt/ir/program_interpreter.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc b/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc index deddd328048a7d..313977f13d1030 100644 --- a/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc +++ b/third_party/xla/xla/python/ifrt/ir/program_interpreter.cc @@ -615,6 +615,9 @@ absl::StatusOr ProgramInterpreter::HandleOp( }; state.remap_is_donated = remap_op.getDonated(); + TF_RETURN_IF_ERROR(state.remap_plan.ComputeInputDevicesForOutputMap(client_)); + TF_RETURN_IF_ERROR(state.remap_plan.Validate()); + for (const auto output : remap_op.getOutputs()) { const ArrayHandle handle = output.use_empty() ? 0 : ToArrayHandle(output); state.output_handles.push_back(handle); From a3f49eccd20479d5f910b481e15c076c741737db Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 8 Dec 2025 12:04:35 -0800 Subject: [PATCH 043/753] [xla] Migrate XLA to MaybeOwningDeviceMemory PiperOrigin-RevId: 841862688 --- tensorflow/compiler/jit/xla_launch_util.h | 1 + .../xla/xla/backends/cpu/autotuner/BUILD | 2 +- .../backends/cpu/autotuner/cpu_profiler.cc | 4 +-- .../xla/backends/cpu/autotuner/cpu_profiler.h | 6 ++-- .../xla/xla/backends/cpu/runtime/BUILD | 2 +- .../backends/cpu/runtime/buffer_allocations.h | 7 ++-- .../xla/xla/backends/gpu/autotuner/BUILD | 2 +- .../backends/gpu/autotuner/gpu_profiler.cc | 4 +-- .../xla/xla/backends/interpreter/BUILD | 2 +- .../backends/interpreter/executable_base.cc | 8 ++--- third_party/xla/xla/client/BUILD | 2 +- third_party/xla/xla/client/local_client.cc | 4 +-- third_party/xla/xla/client/local_client.h | 2 +- third_party/xla/xla/pjrt/BUILD | 2 +- third_party/xla/xla/pjrt/cpu/BUILD | 2 +- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 6 ++-- third_party/xla/xla/pjrt/gpu/tfrt/BUILD | 2 +- .../xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc | 14 ++++---- 
.../pjrt/gpu/tfrt/tracked_gpu_device_buffer.h | 2 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 6 ++-- .../xla/xla/pjrt/tracked_device_buffer.h | 10 +++--- third_party/xla/xla/service/BUILD | 6 ++-- third_party/xla/xla/service/cpu/BUILD | 2 +- .../xla/xla/service/cpu/cpu_executable.cc | 32 +++++++++---------- .../xla/xla/service/cpu/cpu_executable.h | 11 ++++--- third_party/xla/xla/service/executable.cc | 10 +++--- third_party/xla/xla/service/executable.h | 28 ++++++++-------- third_party/xla/xla/service/gpu/BUILD | 2 +- .../xla/xla/service/gpu/autotuning/BUILD | 2 +- .../gpu/autotuning/autotuner_compile_util.cc | 4 +-- .../xla/xla/service/gpu/gpu_executable.cc | 6 ++-- third_party/xla/xla/service/hlo_runner.cc | 6 ++-- .../xla/xla/service/transfer_manager.cc | 5 +-- .../xla/xla/service/transfer_manager.h | 4 +-- third_party/xla/xla/stream_executor/tpu/BUILD | 4 +-- .../stream_executor/tpu/c_api_conversions.cc | 16 +++++----- .../stream_executor/tpu/c_api_conversions.h | 7 ++-- .../xla/xla/stream_executor/tpu/c_api_decl.h | 6 ++-- .../tpu/tpu_executable_interface.cc | 6 ++-- third_party/xla/xla/tests/BUILD | 2 +- .../xla/xla/tests/buffer_donation_test.cc | 7 ++-- 41 files changed, 132 insertions(+), 124 deletions(-) diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 2876b3a7b96373..401f15587fcf39 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/local_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/service/maybe_owning_device_memory.h" #include "xla/service/shaped_buffer.h" #include "xla/stream_executor/device_memory_allocator.h" #include "tensorflow/core/framework/allocation_description.pb.h" diff --git a/third_party/xla/xla/backends/cpu/autotuner/BUILD b/third_party/xla/xla/backends/cpu/autotuner/BUILD index 16640e22a3a8f5..81247efe7e118a 100644 --- a/third_party/xla/xla/backends/cpu/autotuner/BUILD +++ b/third_party/xla/xla/backends/cpu/autotuner/BUILD @@ -46,7 +46,7 @@ cc_library( "//xla/backends/autotuner:profiler", "//xla/service:buffer_assignment", "//xla/service:executable", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service/cpu:cpu_executable", "//xla/tsl/platform:errors", diff --git a/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.cc b/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.cc index 8d29f0afdb15ef..6e841511d97c18 100644 --- a/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.cc +++ b/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "xla/service/buffer_assignment.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/executable.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/shape_util.h" #include "xla/tsl/platform/errors.h" #include "xla/xla_data.pb.h" @@ -94,7 +94,7 @@ absl::StatusOr CpuProfiler::Profile( } absl::Status CpuProfiler::Execute( - Executable* executable, absl::Span buffers, + Executable* executable, absl::Span buffers, ExecutionProfile* profile) { ExecutableRunOptions run_options; run_options.set_execution_profile(profile); diff --git a/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.h b/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.h index cb62437957c187..5d5f32c780cd20 100644 --- a/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.h +++ b/third_party/xla/xla/backends/cpu/autotuner/cpu_profiler.h @@ -25,7 +25,7 @@ limitations under the License. #include "xla/backends/autotuner/profiler.h" #include "xla/literal.h" #include "xla/service/executable.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/xla_data.pb.h" @@ -33,7 +33,7 @@ namespace xla::cpu { struct LiteralBackedCpuBuffers : public InputBuffers { std::vector backing_literals; - std::vector buffers; + std::vector buffers; }; class CpuProfiler : public Profiler { @@ -60,7 +60,7 @@ class CpuProfiler : public Profiler { explicit CpuProfiler(ProfileOptions options) : options_(options) {} absl::Status Execute(Executable* executable, - absl::Span buffers, + absl::Span buffers, ExecutionProfile* profile); private: diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index e1e85a5cd4675a..026a0476d92786 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -56,7 +56,7 @@ cc_library( 
deps = [ "//xla:util", "//xla/service:buffer_assignment", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/stream_executor:device_address", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", diff --git a/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h index b1f7e8142d2939..d91f41dcec389a 100644 --- a/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h +++ b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/stream_executor/device_address.h" #include "xla/util.h" @@ -40,7 +40,8 @@ class BufferAllocations { explicit BufferAllocations(Buffers buffers); explicit BufferAllocations(absl::Span buffers); - explicit BufferAllocations(absl::Span buffers); + explicit BufferAllocations( + absl::Span buffers); // Returns the device address of buffer at the given index. Returns an error // if the index is out of range. 
@@ -80,7 +81,7 @@ inline BufferAllocations::BufferAllocations( num_buffers_(buffers_.size()) {} inline BufferAllocations::BufferAllocations( - absl::Span buffers) + absl::Span buffers) : buffers_(buffers.size()), buffers_data_(buffers_.data()), num_buffers_(buffers_.size()) { diff --git a/third_party/xla/xla/backends/gpu/autotuner/BUILD b/third_party/xla/xla/backends/gpu/autotuner/BUILD index 3d1f9c93001508..662e3f7e03ddf5 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/BUILD +++ b/third_party/xla/xla/backends/gpu/autotuner/BUILD @@ -507,7 +507,7 @@ cc_library( "//xla/backends/gpu/runtime:buffer_comparator", "//xla/hlo/ir:hlo", "//xla/service:executable", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu/autotuning:redzone_buffers", diff --git a/third_party/xla/xla/backends/gpu/autotuner/gpu_profiler.cc b/third_party/xla/xla/backends/gpu/autotuner/gpu_profiler.cc index 81b5135600507c..82c8405af97e3d 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/gpu_profiler.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/gpu_profiler.cc @@ -33,7 +33,7 @@ limitations under the License. #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/redzone_buffers.h" #include "xla/service/gpu/gpu_executable_run_options.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -62,7 +62,7 @@ std::vector CreateExecutionInputsFromBuffers( // Our executable doesn't have input-output aliasing, so we can pass // unowned input buffers. 
inputs.back().SetUnownedBuffer( - /*index=*/{}, MaybeOwningDeviceMemory(/*unowned=*/buffers.at(i))); + /*index=*/{}, MaybeOwningDeviceAddress(/*unowned=*/buffers.at(i))); } return inputs; } diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index 6bf5957323a49c..b8af7e32ada24e 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -84,7 +84,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:dynamic_dimension_inference", "//xla/service:executable", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor:device_address", diff --git a/third_party/xla/xla/backends/interpreter/executable_base.cc b/third_party/xla/xla/backends/interpreter/executable_base.cc index 7ba92f41d87701..eb7fa5d4c07832 100644 --- a/third_party/xla/xla/backends/interpreter/executable_base.cc +++ b/third_party/xla/xla/backends/interpreter/executable_base.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/executable.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" @@ -73,7 +73,7 @@ absl::StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( device_ordinal = 0; } for (auto& argument : arguments) { - const ShapeTree& buffers = argument.Buffers(); + const ShapeTree& buffers = argument.Buffers(); argument_buffers.push_back(ShapedBuffer(buffers.shape(), /*device_ordinal=*/device_ordinal)); auto in_it = buffers.begin(); @@ -179,7 +179,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( -> absl::Status { if (alias && alias->must_alias()) { VLOG(1) << alias->ToString(); - const MaybeOwningDeviceMemory& original_input = + const MaybeOwningDeviceAddress& original_input = (*arguments)[alias->parameter_number].Buffers().element( alias->parameter_index); if (!original_input.HasOwnership()) { @@ -214,7 +214,7 @@ InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse( if (alias) { TF_RET_CHECK(alias->parameter_number < arguments->size()); ExecutionInput& input = (*arguments)[alias->parameter_number]; - MaybeOwningDeviceMemory* device_memory = + MaybeOwningDeviceAddress* device_memory = input.MutableBuffer(alias->parameter_index); if (auto owning = device_memory->Release()) { se::DeviceAddressBase device_memory_base = owning->Release(); diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 8fa2963bc27550..c2801fa3fa8410 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -124,7 +124,7 @@ cc_library( "//xla/service:dump", "//xla/service:executable", "//xla/service:local_service", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", 
"//xla/service:source_map_util", "//xla/service:stream_pool", diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index 3c865b508f5700..cc383a9aa81b34 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -37,7 +37,7 @@ limitations under the License. #include "xla/service/computation_layout.h" #include "xla/service/dump.h" #include "xla/service/executable.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/source_map_util.h" @@ -319,7 +319,7 @@ absl::StatusOr LocalExecutable::RunAsync( } static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer( - const ShapeTree& tree, int device_ordinal) { + const ShapeTree& tree, int device_ordinal) { ShapedBuffer result(tree.shape(), device_ordinal); auto it = tree.begin(); auto out_it = result.buffers().begin(); diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 3f06595a88500d..3ccda5d43f6794 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -34,7 +34,7 @@ limitations under the License. 
#include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/local_service.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/stream_pool.h" diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 11fa3d24c990c8..6b382142dbb42f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -710,7 +710,7 @@ cc_library( "//xla/service:generic_transfer_manager", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_proto_cc", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 59bcdbffdc173f..6e7d2fad54a6dc 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -202,7 +202,7 @@ cc_library( "//xla/service:hlo_module_util", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_executable", "//xla/service/cpu:cpu_executable_run_options", diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 9a2a4dcb42d319..5e2f7aa65df9ef 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -114,7 +114,7 @@ limitations under the License. 
#include "xla/service/hlo_module_util.h" #include "xla/service/hlo_value.h" #include "xla/service/llvm_ir/llvm_command_line_options.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" @@ -1620,7 +1620,7 @@ absl::StatusOr PjRtCpuExecutable::ExecuteHelper( if (cpu_executable->has_thunks()) { // Call interpreted thunk sequence implementing XLA executable. - absl::InlinedVector buffer_device_mem; + absl::InlinedVector buffer_device_mem; buffer_device_mem.reserve(buffer_table.size()); for (const auto& buffer_info : buffer_table) { buffer_device_mem.emplace_back( @@ -1764,7 +1764,7 @@ absl::StatusOr PjRtCpuExecutable::ExecuteHelper( absl::Status status; if (cpu_executable->has_thunks()) { // Call interpreted thunk sequence implementing XLA executable. - absl::InlinedVector buffer_device_mem; + absl::InlinedVector buffer_device_mem; buffer_device_mem.reserve(buffer_table.size()); for (const auto& buffer_info : buffer_table) { buffer_device_mem.emplace_back( diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD index 92f3be08e16c1a..b6d28c5e744e7f 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD +++ b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD @@ -100,7 +100,7 @@ cc_library( "//xla/service:generic_transfer_manager", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_proto_cc", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc index 7dacb707060297..1c97ab898cbd21 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc @@ 
-69,7 +69,7 @@ limitations under the License. #include "xla/service/executable.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/hlo.pb.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" @@ -902,19 +902,19 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( std::vector inputs; if (parameter_is_tupled_arguments) { inputs.emplace_back( - ShapeTree(&parameter_shapes->front())); + ShapeTree(&parameter_shapes->front())); ExecutionInput& input = inputs.back(); for (int i = 0; i < tracked_buffers.size(); ++i) { VLOG(4) << "tupled input[" << i << "]: " << tracked_buffers[i]->buffer()->buffer().opaque(); if (buffer_is_donated[i]) { input.SetUnownedBuffer( - {i}, MaybeOwningDeviceMemory(se::OwningDeviceMemory( + {i}, MaybeOwningDeviceAddress(se::OwningDeviceMemory( tracked_buffers[i]->buffer()->buffer(), device->local_hardware_id().value(), client->allocator()))); } else { - input.SetBuffer({i}, MaybeOwningDeviceMemory( + input.SetBuffer({i}, MaybeOwningDeviceAddress( tracked_buffers[i]->buffer()->buffer())); } } @@ -924,16 +924,16 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( VLOG(4) << "untupled input[" << i << "]: " << tracked_buffers[i]->buffer()->buffer().opaque(); inputs.emplace_back( - ShapeTree(&(*parameter_shapes)[i])); + ShapeTree(&(*parameter_shapes)[i])); ExecutionInput& input = inputs.back(); if (buffer_is_donated[i]) { input.SetUnownedBuffer( - {}, MaybeOwningDeviceMemory(se::OwningDeviceMemory( + {}, MaybeOwningDeviceAddress(se::OwningDeviceMemory( tracked_buffers[i]->buffer()->buffer(), device->local_hardware_id().value(), client->allocator()))); } else { - input.SetBuffer({}, MaybeOwningDeviceMemory( + input.SetBuffer({}, MaybeOwningDeviceAddress( tracked_buffers[i]->buffer()->buffer())); } } diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h
b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h index 3a1b1bc186f1e9..19c949075f320d 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h @@ -35,7 +35,7 @@ limitations under the License. #include "xla/tsl/concurrency/async_value_ref.h" namespace xla { -// TODO(b/400541410): Refactor and Merge this with MaybeOwningDeviceMemory. +// TODO(b/400541410): Refactor and Merge this with MaybeOwningDeviceAddress. // GpuDeviceMemory represents either an owned or unowned GPU memory. It // owns GPU memory if an allocator is provided. When the object goes output of diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 69635421f80399..4c175b7390e14c 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -141,7 +141,7 @@ limitations under the License. 
#include "xla/service/executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" @@ -1673,12 +1673,12 @@ PjRtStreamExecutorClient::RunAsync( auto it = tmp.MutableBuffers()->begin(); for (auto& v : input) { if (v.second.is_donated) { - it->second = MaybeOwningDeviceMemory(se::OwningDeviceMemory( + it->second = MaybeOwningDeviceAddress(se::OwningDeviceMemory( v.second.buf->mem(), device->local_device_id().value(), run_options.allocator())); tmp.SetUnownedIndex(it->first); } else { - it->second = MaybeOwningDeviceMemory(v.second.buf->mem()); + it->second = MaybeOwningDeviceAddress(v.second.buf->mem()); } ++it; } diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index ecc4a64dc73c45..62b36de4923881 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -39,7 +39,7 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/service/executable.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_tree.h" @@ -109,8 +109,8 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { // on_device_shape matches that of the TrackedDeviceBuffer. 'end' is used to // check that 'iterator' doesn't run out of bounds. void AddToInputAsImmutable( - ShapeTree::iterator* iterator, - const ShapeTree::iterator& end) const; + ShapeTree::iterator* iterator, + const ShapeTree::iterator& end) const; // Adds the owned device buffers in order to 'iterator', marking them as // available to be donated. 
If donation succeeds, i.e., execution_input is @@ -121,8 +121,8 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { // that of the TrackedDeviceBuffer. 'end' is used to check that 'iterator' // doesn't run out of bounds. void AddToInputAsDonated( - ShapeTree::iterator* iterator, - const ShapeTree::iterator& end, + ShapeTree::iterator* iterator, + const ShapeTree::iterator& end, ExecutionInput* execution_input, se::DeviceMemoryAllocator* allocator) const; diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 8973a3631eccdf..b87bc885903e10 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -1528,7 +1528,7 @@ cc_library( ":computation_layout", ":hlo_module_config", ":hlo_proto_cc", - ":maybe_owning_device_memory", + ":maybe_owning_device_address", ":shaped_buffer", ":stream_pool", "//xla:executable_run_options", @@ -1666,7 +1666,7 @@ cc_library( hdrs = ["transfer_manager.h"], deps = [ ":compiler", - ":maybe_owning_device_memory", + ":maybe_owning_device_address", ":shaped_buffer", "//xla:literal", "//xla:shape_tree", @@ -4460,7 +4460,7 @@ cc_library( ":executable", ":hlo_module_util", ":hlo_runner_interface", - ":maybe_owning_device_memory", + ":maybe_owning_device_address", ":shaped_buffer", ":transfer_manager", "//xla:executable_run_options", diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 986fe761476f80..f9d93965489130 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -550,7 +550,7 @@ cc_library( "//xla/service:hlo_execution_profile", "//xla/service:hlo_profile_printer_data_cc", "//xla/service:hlo_value", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service:xla_debug_info_manager", "//xla/stream_executor:device_address", diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc 
b/third_party/xla/xla/service/cpu/cpu_executable.cc index 6bb3a695e9523e..c0c1e6446220fa 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -58,7 +58,7 @@ limitations under the License. #include "xla/service/hlo_execution_profile.h" #include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/hlo_value.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/xla_debug_info_manager.h" @@ -156,7 +156,7 @@ CpuExecutable::~CpuExecutable() { } } -static absl::StatusOr MemoryForAllocation( +static absl::StatusOr MemoryForAllocation( const BufferAllocation& allocation, absl::Span arguments, absl::Span constants, @@ -170,17 +170,17 @@ static absl::StatusOr MemoryForAllocation( << "Size mismatch on param " << allocation.parameter_number() << " at shape index " << allocation.param_shape_index().ToString(); VLOG(3) << "allocation is a parameter"; - return MaybeOwningDeviceMemory{out}; + return MaybeOwningDeviceAddress{out}; } else if (allocation.is_constant()) { VLOG(3) << "allocation is a constant"; if (allocation.index() < constants.size()) { - return MaybeOwningDeviceMemory( + return MaybeOwningDeviceAddress( constants[allocation.index()].AsDeviceMemoryBase()); } - return MaybeOwningDeviceMemory{se::DeviceAddressBase{}}; + return MaybeOwningDeviceAddress{se::DeviceAddressBase{}}; } else if (allocation.is_thread_local()) { VLOG(3) << "buffer is thread-local"; - return MaybeOwningDeviceMemory{se::DeviceAddressBase{}}; + return MaybeOwningDeviceAddress{se::DeviceAddressBase{}}; } int64_t buffer_size = allocation.size(); @@ -194,14 +194,14 @@ static absl::StatusOr MemoryForAllocation( // initialized. Mark them initialized so that memory sanitizer doesn't flag // loads from these buffers. 
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(out->opaque(), buffer_size); - return MaybeOwningDeviceMemory{std::move(out)}; + return MaybeOwningDeviceAddress{std::move(out)}; } -absl::StatusOr> +absl::StatusOr> CpuExecutable::CreateBufferTable(se::DeviceAddressAllocator* memory_allocator, int device_ordinal, absl::Span arguments) { - std::vector buffers( + std::vector buffers( assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -233,7 +233,7 @@ static int32_t GetDeviceOrdinal(const ExecutableRunOptions* run_options) { absl::Status CpuExecutable::ExecuteThunks( const ExecutableRunOptions* run_options, - absl::Span buffers) { + absl::Span buffers) { uint64_t start_ns = tsl::Env::Default()->NowNanos(); size_t profile_counters_size = 0; @@ -244,7 +244,7 @@ absl::Status CpuExecutable::ExecuteThunks( VLOG(3) << "Executing XLA:CPU thunks:"; VLOG(3) << absl::StrFormat(" Number of buffer allocations: %u", buffers.size()); - auto mem_printer = [](std::string* out, const MaybeOwningDeviceMemory& mem) { + auto mem_printer = [](std::string* out, const MaybeOwningDeviceAddress& mem) { absl::StrAppend(out, absl::StrFormat("%p", mem.AsDeviceMemoryBase().opaque())); }; @@ -308,7 +308,7 @@ absl::Status CpuExecutable::ExecuteThunks( absl::StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - absl::Span buffers, + absl::Span buffers, absl::Span arguments) { se::Stream* stream = run_options->stream(); ExecutionOutput result(/*on_device_shape=*/result_shape(), @@ -345,7 +345,7 @@ absl::StatusOr CpuExecutable::CreateResultShapedBuffer( if (alias) { CHECK_LT(alias->parameter_number, arguments.size()); ExecutionInput& input = arguments[alias->parameter_number]; - MaybeOwningDeviceMemory* maybe_owning_memory = + MaybeOwningDeviceAddress* maybe_owning_memory = input.MutableBuffer(alias->parameter_index); if (alias->must_alias() && 
!maybe_owning_memory->HasOwnership()) { return InvalidArgument( @@ -381,7 +381,7 @@ absl::StatusOr CpuExecutable::CreateResultShapedBuffer( run_options->allocator()->Allocate( stream->parent()->device_ordinal(), allocation_size)); result_buffer = allocated_buffer.Release(); - MaybeOwningDeviceMemory& registered_buffer = buffers[buffer_index]; + MaybeOwningDeviceAddress& registered_buffer = buffers[buffer_index]; CHECK_EQ(result_buffer.size(), registered_buffer.AsDeviceMemoryBase().size()); std::memcpy(/*dest=*/result_buffer.opaque(), @@ -392,7 +392,7 @@ absl::StatusOr CpuExecutable::CreateResultShapedBuffer( } if (result_buffer.is_null()) { - MaybeOwningDeviceMemory& buffer = buffers[buffer_index]; + MaybeOwningDeviceAddress& buffer = buffers[buffer_index]; if (std::optional> owned_buffer = buffer.Release()) { result_buffer = owned_buffer->Release(); @@ -437,7 +437,7 @@ absl::StatusOr CpuExecutable::ExecuteAsyncOnStream( se::Stream* stream = run_options->stream(); se::DeviceAddressAllocator* memory_allocator = run_options->allocator(); TF_ASSIGN_OR_RETURN( - std::vector buffers, + std::vector buffers, CreateBufferTable(memory_allocator, stream->parent()->device_ordinal(), arguments)); diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index ebb97baf217e47..3db37885900445 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -43,7 +43,7 @@ limitations under the License. 
#include "xla/service/hlo_execution_profile.h" #include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/hlo_value.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/device_address_allocator.h" @@ -72,8 +72,9 @@ class CpuExecutable : public Executable { // Calls emitted thunk sequence with the given arguments using the supplied // buffers. - absl::Status ExecuteThunks(const ExecutableRunOptions* run_options, - absl::Span buffers); + absl::Status ExecuteThunks( + const ExecutableRunOptions* run_options, + absl::Span buffers); absl::Span obj_files() const { return obj_files_; } @@ -172,7 +173,7 @@ class CpuExecutable : public Executable { // // - buffers_to_free: buffers whose ownership was donated by the caller that // are to be freed by the caller. - absl::StatusOr> CreateBufferTable( + absl::StatusOr> CreateBufferTable( se::DeviceAddressAllocator* memory_allocator, int device_ordinal, absl::Span arguments); @@ -182,7 +183,7 @@ class CpuExecutable : public Executable { // assignment. absl::StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - absl::Span buffers, + absl::Span buffers, absl::Span arguments); // Returns the instruction value set of the root instruction of the entry diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc index b52166c243dea1..a9f8da25d12d1c 100644 --- a/third_party/xla/xla/service/executable.cc +++ b/third_party/xla/xla/service/executable.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -70,7 +70,7 @@ absl::Status ExecutionInput::SetDynamicShape(Shape dynamic_shape) { } void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index, - MaybeOwningDeviceMemory buffer) { + MaybeOwningDeviceAddress buffer) { *buffers_.mutable_element(index) = std::move(buffer); unowned_indices_.insert(index); } @@ -86,12 +86,12 @@ absl::StatusOr Executable::ExecuteOnStream( return result; } -static ExecutionInput MakeMaybeOwningDeviceMemoryTree( +static ExecutionInput MakeMaybeOwningDeviceAddressTree( const ShapedBuffer& shaped_buffer) { ExecutionInput result(shaped_buffer.on_device_shape()); shaped_buffer.buffers().ForEachElement( [&](const ShapeIndex& index, const se::DeviceAddressBase& mem) { - result.SetBuffer(index, MaybeOwningDeviceMemory(mem)); + result.SetBuffer(index, MaybeOwningDeviceAddress(mem)); }); return result; } @@ -102,7 +102,7 @@ absl::StatusOr Executable::ExecuteAsyncOnStream( std::vector args; args.reserve(arguments.size()); for (const ShapedBuffer* arg : arguments) { - args.emplace_back(MakeMaybeOwningDeviceMemoryTree(*arg)); + args.emplace_back(MakeMaybeOwningDeviceAddressTree(*arg)); } TF_ASSIGN_OR_RETURN(ExecutionOutput out, ExecuteAsyncOnStream(run_options, std::move(args))); diff --git a/third_party/xla/xla/service/executable.h b/third_party/xla/xla/service/executable.h index e59ac39a932d44..e76038f8a95f9a 100644 --- a/third_party/xla/xla/service/executable.h +++ b/third_party/xla/xla/service/executable.h @@ -37,7 +37,7 @@ limitations under the License. 
#include "xla/service/computation_layout.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -61,11 +61,11 @@ namespace xla { // 3) Donated by the caller and freed on error. // // Case (1) buffers are stored as -// MaybeOwningDeviceMemory(DeviceAddressBase). Case (2) buffers are -// stored as MaybeOwningDeviceMemory(ScopedDeviceAddress), +// MaybeOwningDeviceAddress(DeviceAddressBase). Case (2) buffers are +// stored as MaybeOwningDeviceAddress(ScopedDeviceAddress), // with their indices present in unowned_indices_. // Case (3) buffers are stored as -// MaybeOwningDeviceMemory(ScopedDeviceAddress), +// MaybeOwningDeviceAddress(ScopedDeviceAddress), // with their indices absent from unowned_indices_. class ExecutionInput { public: @@ -88,14 +88,14 @@ class ExecutionInput { } } - explicit ExecutionInput(ShapeTree buffers) + explicit ExecutionInput(ShapeTree buffers) : buffers_(std::move(buffers)) { if (!ShapeUtil::DeviceShapeIsHostShape(buffers_.shape())) { SetHostShape(ShapeUtil::DeviceShapeToHostShape(buffers_.shape())); } } // TODO(b/170310047): remove this overload. 
- ExecutionInput(ShapeTree buffers, + ExecutionInput(ShapeTree buffers, xla::Shape host_shape) : buffers_(std::move(buffers)) { if (!ShapeUtil::DeviceShapeIsHostShape(buffers_.shape())) { @@ -119,12 +119,12 @@ class ExecutionInput { absl::Status SetDynamicShape(Shape dynamic_shape); - void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) { + void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceAddress buffer) { *buffers_.mutable_element(index) = std::move(buffer); } void SetUnownedBuffer(const ShapeIndex& index, - MaybeOwningDeviceMemory buffer); + MaybeOwningDeviceAddress buffer); void SetUnownedIndex(const ShapeIndex& index) { unowned_indices_.insert(index); @@ -138,15 +138,17 @@ class ExecutionInput { return unowned_indices_; } - const ShapeTree& Buffers() const { return buffers_; } + const ShapeTree& Buffers() const { + return buffers_; + } - ShapeTree* MutableBuffers() { return &buffers_; } + ShapeTree* MutableBuffers() { return &buffers_; } - MaybeOwningDeviceMemory* MutableBuffer(const ShapeIndex& index) { + MaybeOwningDeviceAddress* MutableBuffer(const ShapeIndex& index) { return buffers_.mutable_element(index); } - const MaybeOwningDeviceMemory& Buffer(const ShapeIndex& index) const { + const MaybeOwningDeviceAddress& Buffer(const ShapeIndex& index) const { return buffers_.element(index); } @@ -157,7 +159,7 @@ class ExecutionInput { } } - ShapeTree buffers_; + ShapeTree buffers_; // Set of indices of buffers that should be returned to the caller if an error // occurs when enqueuing the computation. 
diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 98e1675e56ab9a..b9a929cdf282fe 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -712,7 +712,7 @@ cc_library( "//xla/service:dump", "//xla/service:executable", "//xla/service:hlo_value", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:rendezvous", "//xla/service:shaped_buffer", "//xla/service:stream_pool", diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index f619be71cd9d86..56d8c4cb2f64ee 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -496,7 +496,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:compiler", "//xla/service:executable", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu:ir_emission_utils", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc index 14ab52352ad047..290187c503292d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc @@ -36,7 +36,7 @@ limitations under the License. 
#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -63,7 +63,7 @@ std::vector ExecutionInputsFromBuffers( // Our executable doesn't have input-output aliasing, so we can pass // unowned input buffers. inputs.back().SetUnownedBuffer( - /*index=*/{}, MaybeOwningDeviceMemory(/*unowned=*/buffers.at(i))); + /*index=*/{}, MaybeOwningDeviceAddress(/*unowned=*/buffers.at(i))); } return inputs; } diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index df3767982ce5f9..905cac3061925c 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -72,7 +72,7 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo_value.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/rendezvous.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" @@ -937,8 +937,8 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( << " @ index: " << index.ToString(); if (output_info.alias_config) { - MaybeOwningDeviceMemory* maybe_owning_memory = - [&]() -> xla::MaybeOwningDeviceMemory* { + MaybeOwningDeviceAddress* maybe_owning_memory = + [&]() -> xla::MaybeOwningDeviceAddress* { // ScopedBuffer is never an owned buffer. 
if (std::holds_alternative>( arguments)) { diff --git a/third_party/xla/xla/service/hlo_runner.cc b/third_party/xla/xla/service/hlo_runner.cc index d0e58a65b97d98..077c92bf517de3 100644 --- a/third_party/xla/xla/service/hlo_runner.cc +++ b/third_party/xla/xla/service/hlo_runner.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_runner_interface.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" @@ -294,7 +294,7 @@ static std::vector ExecutionInputsFromScopedShapedBuffers( for (int param_num = 0; param_num < inputs.size(); param_num++) { const ScopedShapedBuffer& input_buffer = inputs[param_num]; - ShapeTree buffer_tree( + ShapeTree buffer_tree( input_buffer.on_device_shape()); input_buffer.buffers().ForEachElement( @@ -329,7 +329,7 @@ static void ExecutionInputsFromMovedScopedShapedBuffers( for (int param_num = 0; param_num < inputs.size(); param_num++) { ShapedBuffer input_buffer = inputs[param_num].release(); - ShapeTree buffer_tree( + ShapeTree buffer_tree( input_buffer.on_device_shape()); input_buffer.buffers().ForEachElement( diff --git a/third_party/xla/xla/service/transfer_manager.cc b/third_party/xla/xla/service/transfer_manager.cc index da4264b9cb302e..4fbdcdc58ce116 100644 --- a/third_party/xla/xla/service/transfer_manager.cc +++ b/third_party/xla/xla/service/transfer_manager.cc @@ -32,7 +32,7 @@ limitations under the License. 
#include "absl/synchronization/notification.h" #include "xla/literal.h" #include "xla/service/compiler.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -288,7 +288,8 @@ absl::Status TransferManager::WriteRootTupleIndexTable( } absl::Status TransferManager::WriteRootTupleIndexTable( - se::Stream* stream, const ShapeTree& buffer_tree) { + se::Stream* stream, + const ShapeTree& buffer_tree) { TF_RET_CHECK(buffer_tree.shape().IsTuple()); if (ShapeUtil::TupleElementCount(buffer_tree.shape()) == 0) { return absl::OkStatus(); diff --git a/third_party/xla/xla/service/transfer_manager.h b/third_party/xla/xla/service/transfer_manager.h index 978bcfc523bfc1..811138ba23f905 100644 --- a/third_party/xla/xla/service/transfer_manager.h +++ b/third_party/xla/xla/service/transfer_manager.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/literal.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_tree.h" @@ -215,7 +215,7 @@ class TransferManager { const ShapedBuffer& device_buffer); absl::Status WriteRootTupleIndexTable( se::Stream* stream, - const ShapeTree& buffer_tree); + const ShapeTree& buffer_tree); // Determines the byte size requirement for the given shape on the underlying // architecture. 
This will be used to allocate an appropriately sized memory diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD index c0e4af83f39ddc..1bc9dc62a099ed 100644 --- a/third_party/xla/xla/stream_executor/tpu/BUILD +++ b/third_party/xla/xla/stream_executor/tpu/BUILD @@ -72,7 +72,7 @@ cc_library( "//xla/service:computation_placer_hdr", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", @@ -608,7 +608,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:compiler", "//xla/service:executable", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/stream_executor:device_address", diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc index 4f06ee508fd8fb..58eb6c2c3033f9 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_layout.h" @@ -158,14 +158,14 @@ xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) { return xla_shaped_buffer; } -SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, +SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceAddress& mem, bool aliased) { SE_MaybeOwningDeviceMemory se_mem; se_mem.owned = mem.HasOwnership(); - se_mem.memory = ApiConverter::ToC(mem.AsDeviceMemoryBase()); + se_mem.memory = ApiConverter::ToC(mem.AsDeviceAddress()); if (mem.HasOwnership()) { - const stream_executor::OwningDeviceAddress* owned = - mem.AsOwningDeviceMemory(); + const stream_executor::ScopedDeviceAddress* owned = + mem.AsScopedDeviceAddress(); se_mem.device_ordinal = owned->device_ordinal(); se_mem.allocator = ApiConverter::ToC(owned->allocator()); if (!aliased) { @@ -180,15 +180,15 @@ SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, return se_mem; } -xla::MaybeOwningDeviceMemory FromC( +xla::MaybeOwningDeviceAddress FromC( SE_MaybeOwningDeviceMemory* se_mem, stream_executor::DeviceAddressAllocator* allocator) { if (se_mem->owned) { - return xla::MaybeOwningDeviceMemory(stream_executor::OwningDeviceAddress( + return xla::MaybeOwningDeviceAddress(stream_executor::OwningDeviceAddress( ApiConverter::FromC(se_mem->memory), se_mem->device_ordinal, allocator)); } else { - return xla::MaybeOwningDeviceMemory(ApiConverter::FromC(se_mem->memory)); + return xla::MaybeOwningDeviceAddress(ApiConverter::FromC(se_mem->memory)); } } diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h index a3b7c716996b34..da3db36c17a1d2 100644 --- 
a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h @@ -28,7 +28,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/literal.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -117,7 +117,7 @@ struct TpuEmbeddingEngineParametersData { std::unique_ptr Create(int num_tables); -xla::MaybeOwningDeviceMemory FromC( +xla::MaybeOwningDeviceAddress FromC( SE_MaybeOwningDeviceMemory* se_mem, stream_executor::DeviceAddressAllocator* allocator); @@ -132,7 +132,8 @@ SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceAddress* mem); // mem.HasOwnership() may be true if the buffer is aliased and shouldn't be // released. 'aliased' should be true in this case. 'aliased' has no effect if // 'mem' is unowned. 
-SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased); +SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceAddress& mem, + bool aliased); // HloModule XLA_HloModule ToC(const xla::HloModule& module); diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 096f265acaec79..834a3da9f4ed0d 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -277,10 +277,10 @@ typedef struct XLA_Literal { XLA_Shape shape; } XLA_Literal; -typedef struct XLA_MaybeOwningDeviceMemoryShapeTree { +typedef struct XLA_MaybeOwningDeviceAddressShapeTree { XLA_Shape shape; SE_MaybeOwningDeviceMemory* buffers; -} XLA_MaybeOwningDeviceMemoryShapeTree; +} XLA_MaybeOwningDeviceAddressShapeTree; typedef struct XLA_ShapeIndex { int64_t indices[8]; @@ -288,7 +288,7 @@ typedef struct XLA_ShapeIndex { } XLA_ShapeIndex; typedef struct SE_ExecutionInput { - XLA_MaybeOwningDeviceMemoryShapeTree shape_tree; + XLA_MaybeOwningDeviceAddressShapeTree shape_tree; XLA_ShapeIndex* unowned_indices; int unowned_indices_size; XLA_Shape dynamic_shape; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc index 0b4c4db98728d2..ab8616ddc8ecc4 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "xla/layout_util.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" @@ -106,7 +106,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( -> absl::Status { if (alias && alias->must_alias()) { VLOG(1) << alias->ToString(); - const MaybeOwningDeviceMemory& original_input = + const MaybeOwningDeviceAddress& original_input = (*arguments)[alias->parameter_number].Buffers().element( alias->parameter_index); if (!original_input.HasOwnership()) { @@ -152,7 +152,7 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse( if (alias) { TF_RET_CHECK(alias->parameter_number < arguments->size()); ExecutionInput& input = (*arguments)[alias->parameter_number]; - MaybeOwningDeviceMemory* device_memory = + MaybeOwningDeviceAddress* device_memory = input.MutableBuffer(alias->parameter_index); if (auto owning = device_memory->Release()) { // If the caller passes the ownership of the device memory, reuse it diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index aff7b7e1abfcdd..9f617478a6ea7b 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -520,7 +520,7 @@ xla_test( "//xla/service:backend", "//xla/service:executable", "//xla/service:hlo_module_config", - "//xla/service:maybe_owning_device_memory", + "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", diff --git a/third_party/xla/xla/tests/buffer_donation_test.cc b/third_party/xla/xla/tests/buffer_donation_test.cc index 150e6c769ace79..324917cbd57df6 100644 --- a/third_party/xla/xla/tests/buffer_donation_test.cc +++ b/third_party/xla/xla/tests/buffer_donation_test.cc @@ -39,7 
+39,7 @@ limitations under the License. #include "xla/service/backend.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/maybe_owning_device_address.h" #include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -132,10 +132,11 @@ class BufferDonationTest : public HloTestBase { stream.get(), argument_literal, shaped_buffer)); ShapeTree input_buffers = shaped_buffer.buffers(); inputs_buffers.push_back(input_buffers); - ShapeTree owned_buffers( + ShapeTree owned_buffers( argument_literal.shape()); owned_buffers.ForEachMutableElement( - [&](const ShapeIndex& index, MaybeOwningDeviceMemory* device_memory) { + [&](const ShapeIndex& index, + MaybeOwningDeviceAddress* device_memory) { if (donate_argument) { *device_memory = se::OwningDeviceMemory( input_buffers.element(index), executor_->device_ordinal(), From 232d54ce017655b3a085fccb98fcf195b5fcf739 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 8 Dec 2025 12:56:24 -0800 Subject: [PATCH 044/753] Reverts a243fdc6b63024c7d71a4cef1b841003a8f408c2 PiperOrigin-RevId: 841881595 --- .../simplifiers/algebraic_simplifier.cc | 79 ------------------- .../simplifiers/algebraic_simplifier_test.cc | 60 -------------- 2 files changed, 139 deletions(-) diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index f2368abca9f8a0..413520c0f2ab48 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -6899,85 +6899,6 @@ absl::Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } } - // Simplify: - // Txx[...,1] slice(Txx[..., K] reshape(Txx[...,N*K])) // N > 1 - // To: - // Txx[...,1] reshape(Txx[..., N] slice(Txx[...,N*K], stride(-1)=K) - // - // Maintaining data-parallelism to improve throughput on some architectures. - HloInstruction* reshape; - if (Match(slice, m::Slice(m::Reshape(&reshape, m::Op())))) { - HloInstruction* input = reshape->mutable_operand(0); - const Shape& input_shape = input->shape(); - const Shape& reshape_shape = reshape->shape(); - - const int64_t input_rank = input_shape.dimensions().size(); - const int64_t reshape_rank = reshape_shape.dimensions().size(); - const int64_t slice_rank = slice->shape().dimensions().size(); - - // Reshape must have at least 2 dimensions and same number of - // dimensions as slice. - if (reshape_rank >= 2 && reshape_rank == slice_rank) { - bool is_valid_reshape_slice = true; - for (int64_t i = 0; i < slice_rank; ++i) { - if (i == slice_rank - 1) { - // Continue if we are slicing exactly one element from the last - // dimension. - if (slice->slice_limits(i) - slice->slice_starts(i) == 1) { - continue; - } - } else { - // Continue if we are not slicing any other dimension. 
- if (slice->slice_starts(i) == 0 && - slice->slice_limits(i) == reshape_shape.dimensions(i) && - slice->slice_strides(i) == 1) { - continue; - } - } - // If the rules above are not met, prevent a match. - is_valid_reshape_slice = false; - break; - } - - // Check if slice is selecting a single element from the last dimension. - if (is_valid_reshape_slice) { - int64_t slice_index = slice->slice_starts()[slice_rank - 1]; - int64_t K = reshape_shape.dimensions(reshape_rank - 1); - - // Check if input shape can be viewed as [..., N*K], where N is two or - // more, e.g. Input [1, 2024, 4, 128], Reshape [518144, 2]. - // Last dim of input 128 is multiple of 2. - if (!input_shape.dimensions().empty()) { - int64_t last_dim = input_shape.dimensions(input_rank - 1); - if (last_dim % K == 0 && last_dim / K > 1) { - // It matches! - DimensionVector starts(input_rank, 0); - DimensionVector limits(input_shape.dimensions().begin(), - input_shape.dimensions().end()); - DimensionVector strides(input_rank, 1); - - starts[input_rank - 1] = slice_index; - limits[input_rank - 1] = last_dim; - strides[input_rank - 1] = K; - - Shape new_slice_shape = input_shape; - new_slice_shape.set_dimensions( - input_rank - 1, input_shape.dimensions(input_rank - 1) / K); - simplifier_->UpdateLayout(&new_slice_shape); - - HloInstruction* new_slice = - slice->parent()->AddInstruction(HloInstruction::CreateSlice( - new_slice_shape, input, starts, limits, strides)); - HloInstruction* new_reshape = slice->parent()->AddInstruction( - HloInstruction::CreateReshape(slice->shape(), new_slice)); - - return ReplaceInstruction(slice, new_reshape); - } - } - } - } - } - if (slice->operand(0)->opcode() == HloOpcode::kSlice && hlo_instruction_utils::IsUnstridedSlice(slice) && hlo_instruction_utils::IsUnstridedSlice(slice->operand(0))) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 
cecf7762332bbd..34bea124e05379 100644 --- a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -5299,66 +5299,6 @@ TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) { EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4); } -TEST_F(AlgebraicSimplifierTest, SliceWithReshape) { - const absl::string_view hlo_string = R"hlo( - HloModule SliceWithReshape - - ENTRY main { - %arg = f32[1,2024,4,128]{3,2,1,0} parameter(0) - %reshape.1 = f32[2,259072,2]{2,1,0} reshape(%arg) - %slice = f32[2,259072,1]{2,1,0} slice(%reshape.1), slice={[0:2], [0:259072], [1:2]} - ROOT %reshape.2 = f32[518144]{0} reshape(%slice) - } -)hlo"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(module.get()).value()); - - auto* root = module->entry_computation()->root_instruction(); - VLOG(2) << module->ToString(); - - // Expected: Reshape(Slice(Arg)) - // AlgebraicSimplifier merges the two reshapes. 
- const HloInstruction* slice; - EXPECT_THAT(root, GmockMatch(m::Reshape( - m::Slice(&slice, m::Parameter(0))))); - - EXPECT_EQ(slice->slice_strides(3), 2); - EXPECT_EQ(slice->slice_starts(3), 1); - EXPECT_EQ(slice->slice_limits(3), 128); - EXPECT_EQ(slice->shape().dimensions(3), 64); -} - -TEST_F(AlgebraicSimplifierTest, SmallSliceWithReshape) { - const absl::string_view hlo_string = R"hlo( - HloModule SliceWithReshape - - ENTRY main { - %arg = f32[2]{0} parameter(0) - %reshape.1 = f32[2,1]{1,0} reshape(%arg) - %slice = f32[1,1]{1,0} slice(%reshape.1), slice={[0:1], [0:1]} - ROOT %reshape.2 = f32[1]{0} reshape(%slice) - } -)hlo"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(module.get()).value()); - - auto* root = module->entry_computation()->root_instruction(); - LOG(INFO) << module->ToString(); - - // Expected: Reshape(Slice(Arg)) - // AlgebraicSimplifier merges the two reshapes. - const HloInstruction* slice; - EXPECT_THAT(root, GmockMatch(m::Reshape( - m::Slice(&slice, m::Parameter(0))))); - - EXPECT_EQ(slice->slice_strides(0), 1); - EXPECT_EQ(slice->slice_starts(0), 0); - EXPECT_EQ(slice->slice_limits(0), 1); - EXPECT_EQ(slice->shape().dimensions(0), 1); -} - TEST_F(AlgebraicSimplifierTest, SliceOfBroadcastToBroadcast) { HloComputation::Builder builder(TestName()); const int64_t dim0 = 11; From b198f87cb214c8e52c39a57b374e5ba320c9804c Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 8 Dec 2025 13:19:09 -0800 Subject: [PATCH 045/753] [XLA:CPU/GPU] Emit Arith::NegFOp in tiled emitter. This gives better numerical stability on CPU, which does support the instruction; we simply rewrite it back to its original form to ensure the triton lowering works.
PiperOrigin-RevId: 841890088 --- .../codegen/triton/compilation_pipeline.cc | 1 + .../gpu/codegen/triton/emitter_helpers.cc | 6 +- .../triton/fusion_emitter_device_test.cc | 8 +- .../backends/gpu/codegen/triton/support.cc | 3 +- .../gpu/codegen/triton/support_test.cc | 3 +- .../gpu/codegen/triton/transforms/BUILD | 2 + .../gpu/codegen/triton/transforms/passes.h | 1 + .../gpu/codegen/triton/transforms/passes.td | 8 ++ ...nsupported_elementwise_to_triton_pass.mlir | 20 +++++ .../unsupported_elementwise_to_triton_pass.cc | 83 +++++++++++++++++++ 10 files changed, 127 insertions(+), 8 deletions(-) create mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/unsupported_elementwise_to_triton_pass.mlir create mode 100644 third_party/xla/xla/backends/gpu/codegen/triton/transforms/unsupported_elementwise_to_triton_pass.cc diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline.cc b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline.cc index b23a7dfe498e48..67c1eb9b94d7a7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/compilation_pipeline.cc @@ -40,6 +40,7 @@ void CreateTritonXlaPipeline( pm->addPass(mlir::triton::xla::CreateStableHLOLowerToTritonPass()); pm->addPass(emitters::CreateSafeIntegerArithmeticPass()); + pm->addPass(mlir::triton::xla::CreateUnsupportedElementwiseToTritonPass()); auto* cuda_cc = gpu_cc.cuda_compute_capability(); bool is_at_least_hopper = cuda_cc != nullptr && cuda_cc->IsAtLeastHopper(); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index f73c72bcf7873a..19c2bed34cf55c 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -439,8 +439,10 @@ absl::StatusOr 
EmitElementwise(mlir::ImplicitLocOpBuilder& b, case HloOpcode::kNot: return ma::XOrIOp::create(b, inputs[0], OnesLike(b, inputs[0].getType())); case HloOpcode::kNegate: - // NegFOp is not supported by Triton. - return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); + if (is_integer) { + return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); + } + return ma::NegFOp::create(b, inputs[0]); case HloOpcode::kConvert: { TF_ASSIGN_OR_RETURN( Type dst_ty, PrimitiveTypeToMlirType(b, hlo.shape().element_type())); diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc index 99112cd3bf51b3..d93e3016930055 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc @@ -3109,10 +3109,10 @@ ENTRY entry_computation { CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( CHECK: xtile.extract {{.*}} -> tensor CHECK: tt.extern_elementwise {{.*}} (f32) -> f32 -CHECK: arith.subf {{.*}} f32 +CHECK: arith.negf {{.*}} f32 CHECK: xtile.extract {{.*}} -> tensor CHECK: tt.extern_elementwise {{.*}} (f32) -> f32 -CHECK: arith.subf {{.*}} f32 +CHECK: arith.negf {{.*}} f32 CHECK: arith.addf {{.*}} f32 CHECK: arith.mulf {{.*}} f32 CHECK: arith.divf {{.*}} f32 @@ -3622,7 +3622,7 @@ CHECK: {{.*}} = scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] CHECK-SAME: iter_args({{.*}}) -> (tensor<16x64xf32>) { CHECK-DAG: xtile.extract %[[ARG0]] CHECK-DAG: xtile.extract %[[ARG1]] -CHECK-DAG: arith.subf {{.*}} : tensor<16x32xf32> +CHECK-DAG: arith.negf {{.*}} : tensor<16x32xf32> CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32> CHECK: stablehlo.dot_general {{.*}} (tensor<16x32xf32>, tensor<32x64xf32>) -> tensor<16x64xf32> CHECK: arith.addf {{.*}} @@ -3643,7 +3643,7 @@ CHECK: {{.*}} = scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] CHECK-SAME: iter_args({{.*}}) 
-> (tensor<16x64xf32>) { CHECK-DAG: xtile.extract %[[ARG0]] CHECK-DAG: xtile.extract %[[ARG1]] -CHECK-DAG: arith.subf {{.*}} : tensor<16x32xf32> +CHECK-DAG: arith.negf {{.*}} : tensor<16x32xf32> CHECK-DAG: math.absf {{.*}} : tensor<32x64xf32> CHECK: tt.dot {{.*}} tensor<16x32xf32> * tensor<32x64xf32> -> tensor<16x64xf32> CHECK: scf.yield {{.*}} : tensor<16x64xf32> diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc index de63f5dd669128..2f044f7f98afa7 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support.cc @@ -92,7 +92,8 @@ absl::flat_hash_set TritonSupportedUnaryElementwiseOps( absl::flat_hash_set ret{HloOpcode::kAbs, HloOpcode::kCopy}; if (element_type != PrimitiveType::F8E5M2 && - element_type != PrimitiveType::F8E4M3FN) { + element_type != PrimitiveType::F8E4M3FN && + element_type != PrimitiveType::F8E8M0FNU) { ret.insert(HloOpcode::kNegate); } diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc index bbb7ac9a3e931a..7cda77ed4e673a 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/support_test.cc @@ -123,9 +123,10 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { case HloOpcode::kDivide: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kNegate: case HloOpcode::kIota: return type != PRED; + case HloOpcode::kNegate: + return type != PRED && type != F8E8M0FNU; case HloOpcode::kRng: return !pu::IsComplexType(type); case HloOpcode::kComplex: diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD index 5cc3e4f9c12505..9a722714a8357f 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD 
+++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/BUILD @@ -49,6 +49,7 @@ cc_library( "triton_xla_math_to_libdevice.cc", "triton_xla_squeeze_dims_pass.cc", "triton_xla_unswitch_loops_pass.cc", + "unsupported_elementwise_to_triton_pass.cc", "xtile_lower_to_triton.cc", ], hdrs = ["passes.h"], @@ -82,6 +83,7 @@ cc_library( "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.h b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.h index 2ab433be265408..75007131ffdb13 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.h +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.h @@ -54,6 +54,7 @@ std::unique_ptr CreateTritonXLAMathToLibdevicePass( absl::string_view libdevice_path, absl::string_view triple); std::unique_ptr CreateXTileLowerToTritonPass(); std::unique_ptr CreateArithFP8ConversionToTritonPass(); +std::unique_ptr CreateUnsupportedElementwiseToTritonPass(); // Returns true if the `op` contains an operation in it's regions that satisfies // the `fn`. 
diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td index 1c4d71feb98e1b..d8779d2ba0f4ce 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/passes.td @@ -264,5 +264,13 @@ def ArithFP8ConversionToTritonPass "::mlir::triton::TritonDialect", ]; } +def UnsupportedElementwiseToTritonPass + : Pass<"unsupported-elementwise-to-triton"> { + let summary = + "Converts unsupported elementwise operations to their Triton equivalent."; + let dependentDialects = [ + "::mlir::arith::ArithDialect", + ]; +} #endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_PASSES_TD_ diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/unsupported_elementwise_to_triton_pass.mlir b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/unsupported_elementwise_to_triton_pass.mlir new file mode 100644 index 00000000000000..8313bbc324ae06 --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/unsupported_elementwise_to_triton_pass.mlir @@ -0,0 +1,20 @@ +// RUN: xla-opt %s -split-input-file -unsupported-elementwise-to-triton \ +// RUN: | FileCheck %s + +func.func @converts_tensor_negf_to_subf(%arg0: tensor<10xf32>) -> tensor<10xf32> { + // CHECK: %[[ZERO:.*]] = arith.constant dense<0.000000e+00> : tensor<10xf32> + // CHECK: %[[SUB:.*]] = arith.subf %[[ZERO]], %arg0 : tensor<10xf32> + %0 = arith.negf %arg0 : tensor<10xf32> + // CHECK: return %[[SUB]] : tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +//----- + +func.func @converts_scalar_negf_to_subf(%arg0: f32) -> f32 { + // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[SUB:.*]] = arith.subf %[[ZERO]], %arg0 : f32 + %0 = arith.negf %arg0 : f32 + // CHECK: return %[[SUB]] : f32 + func.return %0 : f32 +} diff --git 
a/third_party/xla/xla/backends/gpu/codegen/triton/transforms/unsupported_elementwise_to_triton_pass.cc b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/unsupported_elementwise_to_triton_pass.cc new file mode 100644 index 00000000000000..fabfc8d9815c3b --- /dev/null +++ b/third_party/xla/xla/backends/gpu/codegen/triton/transforms/unsupported_elementwise_to_triton_pass.cc @@ -0,0 +1,83 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/backends/gpu/codegen/triton/transforms/passes.h" + +namespace mlir::triton::xla { + +#define GEN_PASS_DEF_UNSUPPORTEDELEMENTWISETOTRITONPASS +#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc" + +namespace { + +class RewriteNegFToSubtract : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::arith::NegFOp op, + PatternRewriter& rewriter) const override { + 
mlir::Type element_type = getElementTypeOrSelf(op.getType()); + auto type = mlir::dyn_cast(element_type); + + if (!type) { + return rewriter.notifyMatchFailure(op, "expected float type"); + } + + const llvm::fltSemantics& semantics = type.getFloatSemantics(); + mlir::Value zero_value = + mlir::createScalarOrSplatConstant(rewriter, op->getLoc(), op.getType(), + mlir::APFloat::getZero(semantics)); + + rewriter.replaceOpWithNewOp(op, zero_value, + op.getOperand()); + return success(); + } +}; + +struct UnsupportedElementwiseToTritonPass + : public impl::UnsupportedElementwiseToTritonPassBase< + UnsupportedElementwiseToTritonPass> { + void runOnOperation() override { + auto module = getOperation(); + mlir::RewritePatternSet patterns( + &getContext(), std::make_unique(&getContext())); + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr CreateUnsupportedElementwiseToTritonPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::xla From 5380ac8c2e7b8f20046d1115eee73d0cc4e464e3 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Mon, 8 Dec 2025 13:42:05 -0800 Subject: [PATCH 046/753] Use absl::StrAppend for string building. This change - Replaces string concatenation with absl::StrAppend for efficiency - Updates usages of se::DeviceMemoryBase to se::DeviceAddressBase, reflecting a type renaming. 
PiperOrigin-RevId: 841899098 --- .../gpu/runtime/dynamic_slice_thunk.cc | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc index 5628f682b31981..23227934cd3775 100644 --- a/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/dynamic_slice_thunk.cc @@ -118,15 +118,17 @@ std::string DynamicSliceThunk::SliceDef::ToString() const { // embedded_thunk_argument if (embedded_thunk_argument.has_value()) { - result += "embedded_thunk_argument:" + embedded_thunk_argument->ToString(); + absl::StrAppend(&result, "embedded_thunk_argument:", + embedded_thunk_argument->ToString()); } else { - result += "embedded_thunk_argument:null"; + absl::StrAppend(&result, "embedded_thunk_argument:null"); } // offsets if (offsets.has_value()) { - result += ", offsets:["; - result += + absl::StrAppend(&result, ", offsets:["); + absl::StrAppend( + &result, absl::StrJoin(*offsets, ", ", [](std::string* out, const auto& offset) { std::visit( [out](const auto& value) { @@ -141,34 +143,34 @@ std::string DynamicSliceThunk::SliceDef::ToString() const { } }, offset); - }); - result += "]"; + })); + absl::StrAppend(&result, "]"); } else { - result += ", offsets:null"; + absl::StrAppend(&result, ", offsets:null"); } // orig_shape if (orig_shape.has_value()) { - result += ", orig_shape:" + orig_shape->ToString(); + absl::StrAppend(&result, ", orig_shape:", orig_shape->ToString()); } else { - result += ", orig_shape:null"; + absl::StrAppend(&result, ", orig_shape:null"); } // sliced_shape if (sliced_shape.has_value()) { - result += ", sliced_shape:" + sliced_shape->ToString(); + absl::StrAppend(&result, ", sliced_shape:", sliced_shape->ToString()); } else { - result += ", sliced_shape:null"; + absl::StrAppend(&result, ", sliced_shape:null"); } // offset_byte_size if 
(offset_byte_size.has_value()) { - result += ", offset_byte_size:" + absl::StrCat(*offset_byte_size); + absl::StrAppend(&result, ", offset_byte_size:", *offset_byte_size); } else { - result += ", offset_byte_size:null"; + absl::StrAppend(&result, ", offset_byte_size:null"); } - result += "}"; + absl::StrAppend(&result, "}"); return result; } @@ -243,7 +245,7 @@ absl::Status DynamicSliceThunk::Prepare(const PrepareParams& params) { HloEvaluator() .Evaluate( /*module=*/*offset_as_function_of_indvar_metadata_->indvar_init, - /*arg_literals=*/{}) + /*args=*/{}) .value(); VLOG(2) << "Indvar init module: " << offset_as_function_of_indvar_metadata_->indvar_init->ToString(); From 72b1ad96d11b5cb8278667a57d5e8255ff82cfe5 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 8 Dec 2025 13:45:56 -0800 Subject: [PATCH 047/753] [XLA:CPU][XTile] Wrap copy operation if tiling enabled. This is significantly quicker than the existing implementation when tiled emitter is enabled (due to it being multithreaded) PiperOrigin-RevId: 841900416 --- third_party/xla/xla/backends/cpu/codegen/tiled/BUILD | 5 ++++- .../xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.cc | 4 ++-- .../xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.h | 2 ++ .../cpu/codegen/tiled/tiled_fusion_emitter_stub.cc | 2 ++ .../xla/backends/cpu/testlib/kernel_runner_extension.cc | 3 ++- third_party/xla/xla/service/cpu/BUILD | 1 + third_party/xla/xla/service/cpu/cpu_compiler.cc | 4 +++- third_party/xla/xla/service/cpu/fusion_wrapper.cc | 7 +++++++ third_party/xla/xla/service/cpu/fusion_wrapper.h | 6 ++++-- third_party/xla/xla/service/cpu/fusion_wrapper_test.cc | 2 +- 10 files changed, 28 insertions(+), 8 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/BUILD b/third_party/xla/xla/backends/cpu/codegen/tiled/BUILD index 0e0243f5b6dbd5..1ae35c6a56c63d 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tiled/BUILD +++ b/third_party/xla/xla/backends/cpu/codegen/tiled/BUILD @@ -34,7 +34,10 @@ 
cc_library( ["tiled_fusion_emitter_stub.cc"], ), hdrs = ["tiled_fusion_emitter.h"], - visibility = ["//xla/backends/cpu/codegen:__pkg__"], + visibility = [ + "//xla/backends/cpu/codegen:__pkg__", + "//xla/service/cpu:__pkg__", + ], deps = [ "//xla:shape_util", "//xla:util", diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.cc b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.cc index ca968b3df04a1d..74f951b2ef9d3b 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.cc +++ b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.cc @@ -113,7 +113,7 @@ absl::StatusOr> GetTiling( } // We don't currently support sub-byte types in the tiled CPU emitter. -static bool IsSupportedType(PrimitiveType type) { +bool IsSupportedTilingType(PrimitiveType type) { if (type == PRED) { return true; } @@ -144,7 +144,7 @@ static bool IsSupportedShape(const Shape& shape) { ShapeUtil::ForEachSubshape( shape, [&](const Shape& subshape, const ShapeIndex& index) { if (subshape.IsArray()) { - if (!IsSupportedType(subshape.element_type())) { + if (!IsSupportedTilingType(subshape.element_type())) { is_supported = false; } } diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.h b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.h index d2f88d17d85b74..6a8eaaf96b7ac2 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.h +++ b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.h @@ -32,6 +32,8 @@ limitations under the License. 
namespace xla::cpu { +bool IsSupportedTilingType(PrimitiveType type); + absl::StatusOr> GetTilingIfSupported( mlir::MLIRContext& context, const HloFusionInstruction& fusion); diff --git a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter_stub.cc b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter_stub.cc index 37f2abadb37ce4..87e14516c1d740 100644 --- a/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter_stub.cc +++ b/third_party/xla/xla/backends/cpu/codegen/tiled/tiled_fusion_emitter_stub.cc @@ -30,6 +30,8 @@ limitations under the License. namespace xla::cpu { +bool IsSupportedTilingType(PrimitiveType type) { return false; } + absl::StatusOr> GetTilingIfSupported( mlir::MLIRContext& context, const HloFusionInstruction& fusion) { return absl::UnimplementedError("not supported for this build configuration"); diff --git a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc index 20385e34237e0d..e3646f2fd624b2 100644 --- a/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc +++ b/third_party/xla/xla/backends/cpu/testlib/kernel_runner_extension.cc @@ -271,7 +271,8 @@ NB_MODULE(_extension, kernel_runner_module) { kernel_runner_module.def( "run_fusion_wrapper_pass", [](std::unique_ptr> hlo_module) { - FusionWrapper fusion_wrapper(true); + FusionWrapper fusion_wrapper(/*using_new_fusion_emitter=*/true, + /*use_tiled_emitter=*/true); absl::StatusOr result = fusion_wrapper.Run(hlo_module.get()); if (!result.ok()) { throw std::runtime_error(std::string(result.status().message())); diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index f9d93965489130..7154ec3e7ff9c5 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -1097,6 +1097,7 @@ cc_library( srcs = ["fusion_wrapper.cc"], hdrs = ["fusion_wrapper.h"], deps = [ + 
"//xla/backends/cpu/codegen/tiled:tiled_fusion_emitter", "//xla/codegen/emitters:fusion_wrapper_base", "//xla/hlo/ir:hlo", "@com_google_absl//absl/strings:string_view", diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index a6117a8169ddc3..12b1e38459b79e 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -1017,7 +1017,9 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( if (is_fusion_emitters) { bool use_experimental_loop_fusion = options::UseExperimentalLoopFusion(module->config()); - pipeline.AddPass(use_experimental_loop_fusion); + bool use_tiled_emitter = options::EnableTiledEmitter(module->config()); + pipeline.AddPass(use_experimental_loop_fusion, + use_tiled_emitter); } AliasInfo alias_info; diff --git a/third_party/xla/xla/service/cpu/fusion_wrapper.cc b/third_party/xla/xla/service/cpu/fusion_wrapper.cc index af4cf569643a95..1bef382eff71fe 100644 --- a/third_party/xla/xla/service/cpu/fusion_wrapper.cc +++ b/third_party/xla/xla/service/cpu/fusion_wrapper.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/cpu/fusion_wrapper.h" +#include "xla/backends/cpu/codegen/tiled/tiled_fusion_emitter.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -85,6 +86,12 @@ bool FusionWrapper::MustWrapInstruction(const HloInstruction& instruction) { case HloOpcode::kTanh: case HloOpcode::kXor: return using_new_fusion_emitter_; + case HloOpcode::kCopy: + if (use_tiled_emitter_) { + PrimitiveType type = instruction.shape().element_type(); + return IsSupportedTilingType(type); + } + return false; // The following ops are supported but the performance is not as good as the // non-fusion path. // TODO(willfroom): Remove this once the performance is improved. 
diff --git a/third_party/xla/xla/service/cpu/fusion_wrapper.h b/third_party/xla/xla/service/cpu/fusion_wrapper.h index 5f430f93afa8c7..5da07c2f3efc3f 100644 --- a/third_party/xla/xla/service/cpu/fusion_wrapper.h +++ b/third_party/xla/xla/service/cpu/fusion_wrapper.h @@ -28,8 +28,9 @@ namespace cpu { // kick in. class FusionWrapper : public emitters::FusionWrapperBase { public: - explicit FusionWrapper(bool using_new_fusion_emitter) - : using_new_fusion_emitter_(using_new_fusion_emitter) {} + explicit FusionWrapper(bool using_new_fusion_emitter, bool use_tiled_emitter) + : using_new_fusion_emitter_(using_new_fusion_emitter), + use_tiled_emitter_(use_tiled_emitter) {} ~FusionWrapper() override = default; absl::string_view name() const override { return "fusion-wrapper"; } @@ -38,6 +39,7 @@ class FusionWrapper : public emitters::FusionWrapperBase { private: bool using_new_fusion_emitter_; + bool use_tiled_emitter_; }; } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/fusion_wrapper_test.cc b/third_party/xla/xla/service/cpu/fusion_wrapper_test.cc index b8e1438ef1dc34..c81369604eb756 100644 --- a/third_party/xla/xla/service/cpu/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/cpu/fusion_wrapper_test.cc @@ -56,7 +56,7 @@ TEST_F(FusionWrapperTest, Scatter) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, ParseAndReturnVerifiedModule(hlo_string)); - FusionWrapper wrapper(false); + FusionWrapper wrapper(false, false); TF_ASSERT_OK_AND_ASSIGN(bool changed, wrapper.Run(m.get())); EXPECT_TRUE(changed); From 5854d191dc17b477b4efc7228160f3febcfd72a6 Mon Sep 17 00:00:00 2001 From: Michael Kuperstein Date: Mon, 8 Dec 2025 13:53:25 -0800 Subject: [PATCH 048/753] [XLA] Add methods to permute the operands of a fusion op. 
PiperOrigin-RevId: 841903126 --- third_party/xla/xla/hlo/ir/hlo_computation.cc | 30 ++++++++++++++++ third_party/xla/xla/hlo/ir/hlo_computation.h | 4 +++ .../xla/xla/hlo/ir/hlo_instructions.cc | 27 ++++++++++++++ third_party/xla/xla/hlo/ir/hlo_instructions.h | 4 +++ third_party/xla/xla/service/BUILD | 1 + .../xla/xla/service/hlo_instruction_test.cc | 35 +++++++++++++++++++ 6 files changed, 101 insertions(+) diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 57989017c4e43a..50d5655aeab3ab 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -593,6 +593,36 @@ absl::Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) { return absl::OkStatus(); } +absl::Status HloComputation::PermuteParameters( + absl::Span permutation) { + if (permutation.size() != num_parameters()) { + return absl::InvalidArgumentError( + "Permutation size must match the number of parameters."); + } + if (permutation.size() == 1) { + return absl::OkStatus(); + } + + std::vector> new_param_instructions( + num_parameters()); + for (int64_t i = 0; i < num_parameters(); ++i) { + int64_t new_param_number = permutation[i]; + new_param_instructions[new_param_number] = HloInstruction::CreateParameter( + new_param_number, param_instructions_[i]->shape(), + param_instructions_[i]->name()); + } + + for (int64_t i = 0; i < num_parameters(); ++i) { + ReplaceParameter(i, std::move(new_param_instructions[permutation[i]])); + } + + absl::c_sort(param_instructions_, + [](const HloInstruction* a, const HloInstruction* b) { + return a->parameter_number() < b->parameter_number(); + }); + return absl::OkStatus(); +} + bool HloComputation::IsSafelyRemovable( const HloInstruction* instruction, bool ignore_control_dependency, std::optional< diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index f13f1a8a937aa0..64f3ca34d84a98 
100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -984,6 +984,10 @@ class HloComputation { void ClearCalledComputations(); + // Permutes the parameter numbers of this computation according to the + // provided permutation. + absl::Status PermuteParameters(absl::Span permutation); + private: friend class HloModule; diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 55c3b05fe0d10f..526165754c9794 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -2597,6 +2597,33 @@ absl::Status HloFusionInstruction::DeduplicateFusionOperands() { return absl::OkStatus(); } +absl::Status HloFusionInstruction::PermuteFusionOperands( + absl::Span permutation) { + if (permutation.size() != operand_count()) { + return absl::InvalidArgumentError( + "Permutation size must match the number of operands."); + } + std::vector seen(permutation.size(), false); + for (int64_t i = 0; i < permutation.size(); ++i) { + if (permutation[i] < 0 || permutation[i] >= operand_count() || + seen[permutation[i]]) { + return absl::InvalidArgumentError( + "Argument is not a permutation of operand indices."); + } + seen[permutation[i]] = true; + } + + TF_RETURN_IF_ERROR( + fused_instructions_computation()->PermuteParameters(permutation)); + InstructionVector new_operands(operand_count()); + for (int64_t i = 0; i < operand_count(); ++i) { + new_operands[permutation[i]] = mutable_operand(i); + } + RemoveAllOperands(); + AppendOperands(new_operands); + return absl::OkStatus(); +} + HloCallInstruction::HloCallInstruction(const Shape& shape, HloInstruction* called_computation_root) : HloCallableInstruction(HloOpcode::kCall, shape) { diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index 7c6753c80071a6..b73f54e0b830f9 100644 --- 
a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -1588,6 +1588,10 @@ class HloFusionInstruction : public HloCallableInstruction { // If multiple operands are the same instruction, keeps only one of them. absl::Status DeduplicateFusionOperands(); + // Permutes the operands according to the provided permutation. + // The fusion computation is also adjusted accordingly. + absl::Status PermuteFusionOperands(absl::Span permutation); + static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kFusion; } diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b87bc885903e10..52e39430e85f9d 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -925,6 +925,7 @@ xla_cc_test( "//xla/service/gpu:backend_configs_cc", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index fddfda34a81e12..aa90adc01f6e8a 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -51,6 +51,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/proto/proto_matchers.h" #include "xla/util.h" #include "xla/window_util.h" @@ -3381,5 +3382,39 @@ TEST_F(HloInstructionTest, DifferentResultAccuracy) { EXPECT_FALSE(exp1->equal_result_accuracy(exp2)); } +TEST_F(HloInstructionTest, FusionPermuteOperandsTest) { + constexpr char kHloString[] = R"( + HloModule test_module + fusion_computation { + p0 = f32[] parameter(0) + p1 = f32[32] parameter(1) + p2 = f32[32,32] parameter(2) + bcast0 = f32[32,32] broadcast(p0), dimensions={} + bcast1 = f32[32,32] broadcast(p1), dimensions={0} + sub = f32[32,32] subtract(bcast0, bcast1) + ROOT add = f32[32,32] add(sub, p2) + } + + ENTRY reduce { + p0 = f32[] parameter(0) + p1 = f32[32] parameter(1) + p2 = f32[32,32] parameter(2) + ROOT root = f32[32,32] fusion(p0, p1, p2), kind=kLoop, calls=fusion_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + HloFusionInstruction* fusion = Cast( + module->entry_computation()->root_instruction()); + EXPECT_OK(fusion->PermuteFusionOperands({1, 2, 0})); + + EXPECT_THAT(fusion, GmockMatch(m::Fusion(m::Parameter(2), m::Parameter(0), + m::Parameter(1)))); + HloComputation* fusion_computation = fusion->fused_instructions_computation(); + EXPECT_THAT(fusion_computation->root_instruction(), + GmockMatch(m::Add(m::Subtract(m::Broadcast(m::Parameter(1)), + m::Broadcast(m::Parameter(2))), + m::Parameter(0)))); +} + } // namespace } // namespace xla From c78a6691190e8a5428f54515888123fc2e63c50b Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Mon, 8 Dec 2025 14:16:03 -0800 Subject: [PATCH 049/753] Move helper functions to an anonymous namespace. This change moves `getScalarLimitOfFloatType` and `getScalarLimitOfIntegerType` into an anonymous namespace to limit their visibility to the current translation unit. 
PiperOrigin-RevId: 841912322 --- .../xla/xla/mlir_hlo/utils/hlo_utils.cc | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc b/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc index af292de860e3f0..4e2a8c6358a16f 100644 --- a/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc +++ b/third_party/xla/xla/mlir_hlo/utils/hlo_utils.cc @@ -17,21 +17,71 @@ limitations under the License. #include #include +#include +#include +#include #include #include #include #include +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" namespace mlir { namespace hlo { +namespace { +APFloat getScalarLimitOfFloatType(FloatType floatTy, ScalarLimit limit) { + auto& semantics = floatTy.getFloatSemantics(); + switch (limit) { + case kLowest: + return APFloat::getLargest(semantics, /*negative=*/true); + case kInfinityLowest: + return APFloat::getInf(semantics, /*negative=*/true); + case kMax: + return APFloat::getLargest(semantics, /*negative=*/false); + case kInfinityMax: + return APFloat::getInf(semantics, /*negative=*/false); + } + llvm_unreachable("invalid limit"); +} + +// Returns the limit value of the given integer type. +// +// `limit` selects which limit to return. Integer types have no infinity, so +// kInfinityLowest and kInfinityMax yield the same values as kLowest and kMax. +// Unsigned and 1-bit (boolean) types use the unsigned range. 
+APInt getScalarLimitOfIntegerType(IntegerType integerTy, ScalarLimit limit) { + unsigned width = integerTy.getWidth(); + bool isBool = (width == 1); + switch (limit) { + case kLowest: + case kInfinityLowest: + if (integerTy.isUnsigned() || isBool) { + return APInt::getMinValue(width); + } else { + return APInt::getSignedMinValue(width); + } + + case kMax: + case kInfinityMax: + if (integerTy.isUnsigned() || isBool) { + return APInt::getMaxValue(width); + } else { + return APInt::getSignedMaxValue(width); + } + } + llvm_unreachable("invalid limit"); +} +} // namespace static constexpr size_t kPaddingSize = 64; @@ -110,50 +160,6 @@ DenseElementsAttr getScalarNegZeroOfType(Type ty) { llvm_unreachable("unsupported type"); } -static APFloat getScalarLimitOfFloatType(FloatType floatTy, ScalarLimit limit) { - auto& semantics = floatTy.getFloatSemantics(); - switch (limit) { - case kLowest: - return APFloat::getLargest(semantics, /*negative=*/true); - case kInfinityLowest: - return APFloat::getInf(semantics, /*negative=*/true); - case kMax: - return APFloat::getLargest(semantics, /*negative=*/false); - case kInfinityMax: - return APFloat::getInf(semantics, /*negative=*/false); - } - llvm_unreachable("invalid limit"); -} - -// Returns a scalar value for the given integer type. -// -// The argument 'scalar' describes which scalar value to return. `integer_value` -// is used to specify the integer value for kInteger. For any other scalar, -// integer_value is ignored. 
-static APInt getScalarLimitOfIntegerType(IntegerType integerTy, - ScalarLimit limit) { - unsigned width = integerTy.getWidth(); - bool isBool = (width == 1); - switch (limit) { - case kLowest: - case kInfinityLowest: - if (integerTy.isUnsigned() || isBool) { - return APInt::getMinValue(width); - } else { - return APInt::getSignedMinValue(width); - } - - case kMax: - case kInfinityMax: - if (integerTy.isUnsigned() || isBool) { - return APInt::getMaxValue(width); - } else { - return APInt::getSignedMaxValue(width); - } - } - llvm_unreachable("invalid limit"); -} - DenseElementsAttr getScalarLimitOfType(Type ty, ScalarLimit limit) { RankedTensorType scalarTy = RankedTensorType::get({}, ty); if (auto floatTy = mlir::dyn_cast(ty)) { From f1d8c83302c361ce4d3bd15a0b4112e844654634 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Mon, 8 Dec 2025 14:17:50 -0800 Subject: [PATCH 050/753] Generalize CommonPjRtClient::PrepareArguments for processing all the input argument handles. This requires introducing a new EventSet concept for collecting the definition and device events before passing all of this to the internal Execute() function. 
PiperOrigin-RevId: 841913034 --- third_party/xla/xla/pjrt/BUILD | 3 + .../xla/pjrt/abstract_tracked_device_buffer.h | 12 ++ .../xla/xla/pjrt/common_pjrt_client.cc | 136 ++++++++++++++++++ third_party/xla/xla/pjrt/common_pjrt_client.h | 11 ++ third_party/xla/xla/pjrt/device_event.h | 7 + 5 files changed, 169 insertions(+) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 6b382142dbb42f..0f895ff6047f2e 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -143,7 +143,9 @@ cc_library( ":device_event", ":host_callback", ":pjrt_client", + ":pjrt_executable", ":raw_buffer", + ":utils", "//xla:future", "//xla:literal", "//xla:shape_util", @@ -155,6 +157,7 @@ cc_library( "//xla/tsl/platform:errors", "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h b/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h index 85975e95db0a28..03436e55390afb 100644 --- a/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/abstract_tracked_device_buffer.h @@ -98,6 +98,18 @@ class AbstractTrackedDeviceBuffer { "WaitUntilBufferReadyOnStream is only implemented for GPU."); } + // TODO(parkers): definition events are fixed, so we should just store them + // directly. + // Returns true if there is an error in any of the events. 
+ virtual bool AddDefinitionEventsToSet(PjRtDeviceEventSet& events) { + LOG(FATAL) << "TODO IMPLEMENT: AddDefinitionEventsToSet."; + return false; + } + + virtual void AddUsageEventsToSet(PjRtDeviceEventSet& events) { + LOG(FATAL) << "TODO IMPLEMENT: AddUsageEventsToSet."; + } + protected: void ReleaseDeviceMemory() { raw_buffer_ = tsl::RCReference(); diff --git a/third_party/xla/xla/pjrt/common_pjrt_client.cc b/third_party/xla/xla/pjrt/common_pjrt_client.cc index e737e773ff85f6..d3756104d08943 100644 --- a/third_party/xla/xla/pjrt/common_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/common_pjrt_client.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" @@ -47,7 +48,9 @@ limitations under the License. #include "xla/pjrt/device_event.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/raw_buffer.h" +#include "xla/pjrt/utils.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -426,6 +429,139 @@ void CommonPjRtClient::ScheduleRemoteSend( usage_event_promise->SetError(error); } +absl::Status CommonPjRtClient::PrepareArguments( + const ExecuteOptions& options, + absl::Span argument_handles, + absl::Span donated_params, PjRtDeviceEventSet& extra_deps, + PjRtDeviceEventSet& control_deps, + absl::InlinedVector, 4>& + input_buffers, + absl::InlinedVector& device_buffers, + PjRtDevice* device, int replica, int partition, + absl::Span parameter_device_shapes, bool& is_error) { + input_buffers.reserve(argument_handles.size()); + device_buffers.reserve(argument_handles.size()); + auto donate_it = donated_params.begin(); + { + tsl::profiler::TraceMe t2("Handle inputs"); + // State for `TestBufferDonationClashes`. 
+ absl::flat_hash_map> donation_clashes; + donation_clashes.reserve(argument_handles.size()); + // The first element is the argument index of the donated buffer, and the + // second element is the size in bytes of the donated buffer. + std::vector> donated_buffer_stats; + for (int i = 0; i < argument_handles.size(); ++i) { + PjRtBuffer* handle = argument_handles[i]; + auto* tfrt_buffer = tensorflow::down_cast(handle); + if (tfrt_buffer->device() != device) { + return InvalidArgument( + "Buffer passed to Execute() as argument %d to replica %d is on " + "device %s, but replica is assigned to device %s.", + i, replica, tfrt_buffer->device()->DebugString(), + device->DebugString()); + } + const bool donated_param = + donate_it != donated_params.end() && *donate_it == i; + const bool donation_denied_at_runtime = + options.non_donatable_input_indices.contains(i); + if (donated_param && donation_denied_at_runtime && + tfrt_buffer->on_device_shape().has_layout() && + tfrt_buffer->on_device_shape().layout().memory_space() == + Layout::kHostMemorySpace) { + return absl::UnimplementedError( + "pinned_host buffers do not support donation denial at runtime via " + "`ExecuteOptions::non_donatable_input_indices`"); + } + bool must_donate = donated_param && !donation_denied_at_runtime; + if (must_donate) { + ++donate_it; + if (VLOG_IS_ON(1)) { + TF_ASSIGN_OR_RETURN(size_t on_device_size, + tfrt_buffer->GetOnDeviceSizeInBytes()); + donated_buffer_stats.emplace_back(std::make_pair(i, on_device_size)); + } + } + TF_RETURN_IF_ERROR(TestBufferDonationClashes( + tfrt_buffer, donation_clashes, must_donate, i, replica, partition)); + device_buffers.emplace_back(tfrt_buffer->GetBufferWithHold( + must_donate ? 
CommonPjRtBuffer::ScopedHold::kDonation + : CommonPjRtBuffer::ScopedHold::kUsage)); + CommonPjRtBuffer::ScopedHold& hold = device_buffers.back(); + if (!hold.ok()) { + return InvalidArgument( + "Invalid buffer passed to Execute() as argument %d to replica %d: " + "%s", + i, replica, hold.status().ToString()); + } + auto* device_buffer = hold.buffer(); + + const bool is_handle_dynamic_shape = + handle->on_device_shape().is_dynamic(); + + const Shape& expected_shape = parameter_device_shapes[i]; + if (device_buffer->raw_buffer()) { + tsl::RCReference actual_buffer = + device_buffer->raw_buffer(); + if (is_handle_dynamic_shape && !expected_shape.is_dynamic()) { + TF_ASSIGN_OR_RETURN(auto handle_logical_device_shape, + handle->logical_on_device_shape()); + auto status_or_buffer = + actual_buffer->RemoveDynamicShapeMetadataIfPresent( + handle_logical_device_shape); + + if (!status_or_buffer.ok()) { + absl::Status status = status_or_buffer.status(); + tsl::errors::AppendToMessage( + &status, absl::StrCat("; Error when preparing the input buffer " + "to Execute() as argument ", + i, " to replica ", replica)); + return status; + } + actual_buffer = std::move(status_or_buffer).value(); + } + input_buffers.push_back(std::move(actual_buffer)); + } else { + is_error = true; + } + + // Definition events are never modified after buffer construction. + is_error |= device_buffer->AddDefinitionEventsToSet(extra_deps); + // If we are trying to donate this buffer, we must wait on its usage + // events as well as its definition events to ensure that all reads on + // this buffer (e.g., d2h transfer) have been completed before it can be + // mutated. Usage holds on this buffer are excluded during a donation hold + // so we know that its usage events won't be modified while we are + // enqueueing, but we ignore any errors from usage events. 
+ if (must_donate) { + device_buffer->AddUsageEventsToSet(control_deps); + } + } + // Debug logging of buffer donation and input buffer shapes and size. + if (VLOG_IS_ON(1)) { + // Buffer donation information. + if (!argument_handles.empty()) { + LOG(INFO) << donated_buffer_stats.size() << " arguments out of total " + << argument_handles.size() << " arguments will be donated."; + for (auto [index, buffer_size] : donated_buffer_stats) { + LOG(INFO) << "Argument " << index << " with size " << buffer_size + << " will be donated."; + } + } + // Input buffers shape and size. + for (int i = 0; i < input_buffers.size(); ++i) { + size_t buffer_size = input_buffers[i]->GetOnDeviceSizeInBytes(); + TF_ASSIGN_OR_RETURN(Shape actual_input_shape, + argument_handles[i]->logical_on_device_shape()); + VLOG(2) << "input buffer with index " << i + << " has shape: " << actual_input_shape.ToString() + << " and size: " << buffer_size; + } + } + } + + return absl::OkStatus(); +} + absl::StatusOr, 4>> CommonPjRtClient::AllocateOutputBuffersWithInputReuse( const Shape& output_device_shape, diff --git a/third_party/xla/xla/pjrt/common_pjrt_client.h b/third_party/xla/xla/pjrt/common_pjrt_client.h index 43cc61bc024696..5470a087376444 100644 --- a/third_party/xla/xla/pjrt/common_pjrt_client.h +++ b/third_party/xla/xla/pjrt/common_pjrt_client.h @@ -236,6 +236,17 @@ class CommonPjRtClient : public PjRtClient { Future serialized_descriptor, PjRtBuffer::RemoteSendCallback on_done); + static absl::Status PrepareArguments( + const ExecuteOptions& options, + absl::Span argument_handles, + absl::Span donated_params, PjRtDeviceEventSet& extra_deps, + PjRtDeviceEventSet& control_deps, + absl::InlinedVector, 4>& + input_buffers, + absl::InlinedVector& device_buffers, + PjRtDevice* device, int replica, int partition, + absl::Span parameter_device_shapes, bool& is_error); + absl::StatusOr, 4>> AllocateOutputBuffersWithInputReuse( const Shape& output_device_shape, diff --git 
a/third_party/xla/xla/pjrt/device_event.h b/third_party/xla/xla/pjrt/device_event.h index 9aa231ebca926f..5e307e0dcc5cc9 100644 --- a/third_party/xla/xla/pjrt/device_event.h +++ b/third_party/xla/xla/pjrt/device_event.h @@ -106,6 +106,13 @@ class PjRtDeviceEventPromise : public PjRtDeviceEventOrPromise { virtual void SetReady() = 0; }; +// A collection of events. This is not an event itself because we may want to +// add events in the future. +class PjRtDeviceEventSet { + public: + virtual ~PjRtDeviceEventSet() = default; +}; + } // namespace xla #endif // XLA_PJRT_DEVICE_EVENT_H_ From cb1ad8a22b10a96c3308b16749dd6737b634b929 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Mon, 8 Dec 2025 14:18:14 -0800 Subject: [PATCH 051/753] [XLA:CPU][XTile] Rewrite llvm vectorized log to polynomial approximations. This is needed to get the same numerics from the tiled & scalar emitter. PiperOrigin-RevId: 841913170 --- .../xla/xla/backends/cpu/codegen/polynomial_approximations.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc b/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc index 232bfe0488ba26..0c8084568e41c5 100644 --- a/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc +++ b/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc @@ -541,6 +541,10 @@ void RewriteToPolynomialApproximations(llvm::Module* module, rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1); rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1); + rewrite_calls("llvm.log.v2f32", GenerateVF32Log, /*vector_width=*/2); + rewrite_calls("llvm.log.v4f32", GenerateVF32Log, /*vector_width=*/4); + rewrite_calls("llvm.log.v8f32", GenerateVF32Log, /*vector_width=*/8); + rewrite_calls("llvm.log.v16f32", GenerateVF32Log, /*vector_width=*/16); rewrite_calls(kLogV4F32Sym, GenerateVF32Log, /*vector_width=*/4); rewrite_calls(kLogV8F32Sym, GenerateVF32Log, 
/*vector_width=*/8); rewrite_calls(kLogV16F32Sym, GenerateVF32Log, /*vector_width=*/16); From b59325607d1ce30065a0a6552e7feb8d7d3dbadc Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Mon, 8 Dec 2025 15:18:41 -0800 Subject: [PATCH 052/753] Make a variant of Get in AttributeMap that returns a variant type. PiperOrigin-RevId: 841934935 --- third_party/xla/xla/python/ifrt/BUILD | 1 - third_party/xla/xla/python/ifrt/attribute_map.h | 12 ++++++++---- .../xla/xla/python/ifrt/attribute_map_test.cc | 6 +----- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 852c0d29384626..1efa12d993c918 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -179,7 +179,6 @@ xla_cc_test( "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", - "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], diff --git a/third_party/xla/xla/python/ifrt/attribute_map.h b/third_party/xla/xla/python/ifrt/attribute_map.h index 714b437c8f092b..b1823d2d88b562 100644 --- a/third_party/xla/xla/python/ifrt/attribute_map.h +++ b/third_party/xla/xla/python/ifrt/attribute_map.h @@ -95,15 +95,19 @@ class AttributeMap { template absl::StatusOr Get(const std::string& key) const { - if constexpr (std::is_same_v || - std::is_same_v) { + if constexpr (std::is_same_v) { + auto it = map_.find(key); + if (it == map_.end()) { + return absl::NotFoundError(absl::StrCat("Key not found: ", key)); + } + return it->second; + } else if constexpr (std::is_same_v) { return Get(key); } else if constexpr (std::is_same_v) { return Get(key); } else if constexpr (std::is_same_v) { return Get(key); - } else if constexpr (std::is_same_v> || - std::is_same_v>) { + } else if constexpr (std::is_same_v>) { return Get(key); } else if constexpr (std::is_same_v) { return 
Get(key); diff --git a/third_party/xla/xla/python/ifrt/attribute_map_test.cc b/third_party/xla/xla/python/ifrt/attribute_map_test.cc index 96069fdfb6ee74..c8425ec79038a4 100644 --- a/third_party/xla/xla/python/ifrt/attribute_map_test.cc +++ b/third_party/xla/xla/python/ifrt/attribute_map_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/status/status_matchers.h" -#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/python/ifrt/serdes_test_util.h" #include "xla/python/ifrt/serdes_version.h" @@ -67,19 +66,16 @@ TEST(AttributeMapTest, Get) { }); EXPECT_THAT(map.Get("string"), IsOkAndHolds("value")); - EXPECT_THAT(map.Get("string"), IsOkAndHolds("value")); EXPECT_THAT(map.Get("bool"), IsOkAndHolds(true)); EXPECT_THAT(map.Get("int64"), IsOkAndHolds(123)); EXPECT_THAT(map.Get>("int64_list"), IsOkAndHolds(std::vector{1, 2})); - EXPECT_THAT(map.Get>("int64_list"), - IsOkAndHolds(std::vector{1, 2})); EXPECT_THAT(map.Get("float"), IsOkAndHolds(1.23f)); EXPECT_THAT(map.Get("float"), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Value type mismatch for key: float"))); - EXPECT_THAT(map.Get>("string"), + EXPECT_THAT(map.Get>("string"), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Value type mismatch for key: string"))); } From 4cd611b520678b3b7d39094a6278acaf8b9b3370 Mon Sep 17 00:00:00 2001 From: Krishna Haridasan Date: Mon, 8 Dec 2025 17:01:13 -0800 Subject: [PATCH 053/753] Add ForEach method to IFRT AttributeMap PiperOrigin-RevId: 841969101 --- third_party/xla/xla/python/ifrt/BUILD | 2 +- .../xla/xla/python/ifrt/attribute_map.h | 17 +++++++++++++++-- .../xla/xla/python/ifrt/attribute_map_test.cc | 18 +++++++++--------- .../pjrt_ifrt/pjrt_attribute_map_util.cc | 10 +++++----- .../pjrt_ifrt/pjrt_attribute_map_util_test.cc | 4 ++-- .../pjrt_ifrt/xla_executable_impl_test_lib.cc | 2 +- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git 
a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 1efa12d993c918..c411c99ab65d9c 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -160,11 +160,11 @@ cc_library( ":serdes_version", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/python/ifrt/attribute_map.h b/third_party/xla/xla/python/ifrt/attribute_map.h index b1823d2d88b562..e53be413eb1c4a 100644 --- a/third_party/xla/xla/python/ifrt/attribute_map.h +++ b/third_party/xla/xla/python/ifrt/attribute_map.h @@ -26,11 +26,10 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" #include "xla/python/ifrt/attribute_map.pb.h" #include "xla/python/ifrt/serdes_default_version_accessor.h" #include "xla/python/ifrt/serdes_version.h" @@ -141,6 +140,20 @@ class AttributeMap { bool IsEmpty() const { return map_.empty(); } + // Invokes `f` for each key-value pair in the attribute map. 
+ void ForEach( + absl::FunctionRef f) const { + for (const auto& [key, value] : map_) { + f(key, value); + } + } + + bool operator==(const AttributeMap& other) const { + return map_ == other.map_; + } + + size_t size() const { return map_.size(); } + private: template absl::StatusOr Get(const std::string& key) const { diff --git a/third_party/xla/xla/python/ifrt/attribute_map_test.cc b/third_party/xla/xla/python/ifrt/attribute_map_test.cc index c8425ec79038a4..c658838e2af1c3 100644 --- a/third_party/xla/xla/python/ifrt/attribute_map_test.cc +++ b/third_party/xla/xla/python/ifrt/attribute_map_test.cc @@ -45,14 +45,14 @@ TEST(AttributeMapTest, MapElements) { {"float", AttributeMap::FloatValue(1.23f)}, }); - EXPECT_EQ(map.map(), AttributeMap::Map({ - {"string", AttributeMap::StringValue("value")}, - {"bool", AttributeMap::BoolValue(true)}, - {"int64", AttributeMap::Int64Value(123)}, - {"int64_list", AttributeMap::Int64ListValue( - {int64_t{1}, int64_t{2}})}, - {"float", AttributeMap::FloatValue(1.23f)}, - })) + EXPECT_EQ(map, AttributeMap({ + {"string", AttributeMap::StringValue("value")}, + {"bool", AttributeMap::BoolValue(true)}, + {"int64", AttributeMap::Int64Value(123)}, + {"int64_list", + AttributeMap::Int64ListValue({int64_t{1}, int64_t{2}})}, + {"float", AttributeMap::FloatValue(1.23f)}, + })) << map.DebugString(); } @@ -101,7 +101,7 @@ TEST_P(AttributeMapSerDesTest, ToFromProto) { TF_ASSERT_OK_AND_ASSIGN(auto map_copy, AttributeMap::FromProto(map.ToProto(version()))); - EXPECT_EQ(map_copy.map(), map.map()) << map_copy.DebugString(); + EXPECT_EQ(map_copy, map) << map_copy.DebugString(); } INSTANTIATE_TEST_SUITE_P( diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc index af2a07cb85d92f..a28a1cfa8cf481 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util.cc @@ -59,12 +59,12 @@ 
AttributeMap FromPjRtAttributeMap( absl::flat_hash_map ToPjRtAttributeMap( AttributeMap attributes) { absl::flat_hash_map result; - result.reserve(attributes.map().size()); - for (auto& item : attributes.map()) { + result.reserve(attributes.size()); + attributes.ForEach([&](const std::string& key, + const AttributeMap::Value& value) { std::visit( [&](auto& value) { using T = std::decay_t; - const auto& key = item.first; if constexpr (std::is_same_v) { result.insert({key, std::move(value.value)}); } else if constexpr (std::is_same_v) { @@ -78,8 +78,8 @@ absl::flat_hash_map ToPjRtAttributeMap( result.insert({key, value.value}); } }, - item.second); - } + value); + }); return result; } diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc index afee66155aa4ad..dd8742d16610ad 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_attribute_map_util_test.cc @@ -38,8 +38,8 @@ TEST(PjRtAttributeMapUtilTest, FromPjRtAttributeMap) { {"float", xla::PjRtValueType(1.23f)}, }); - EXPECT_EQ(FromPjRtAttributeMap(pjrt_map).map(), - AttributeMap::Map({ + EXPECT_EQ(FromPjRtAttributeMap(pjrt_map), + AttributeMap({ {"string", AttributeMap::StringValue("value")}, {"bool", AttributeMap::BoolValue(true)}, {"int64", AttributeMap::Int64Value(123)}, diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 31cc3645eaead6..54007917478279 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -294,7 +294,7 @@ TEST_P(LoadedExecutableImplTest, Analysis) { TF_ASSERT_OK_AND_ASSIGN(const auto cost_analysis, executable->GetCostAnalysis()); - EXPECT_THAT(cost_analysis.map(), Not(IsEmpty())); + 
EXPECT_FALSE(cost_analysis.IsEmpty()); } TEST_P(LoadedExecutableImplTest, GetDonatableInputIndices) { From b8a2067837dab89276b5f0d4f2f104b4ccf3e975 Mon Sep 17 00:00:00 2001 From: Maxim Ermilov Date: Mon, 8 Dec 2025 18:32:31 -0800 Subject: [PATCH 054/753] Add Shape to MemzeroThunk/MemzeroCmd buffer_uses Modify Thunk's serialization PiperOrigin-RevId: 841995197 --- third_party/xla/xla/backends/gpu/runtime/BUILD | 8 +++++++- .../backends/gpu/runtime/command_buffer_cmd.cc | 15 ++++++++------- .../xla/backends/gpu/runtime/command_buffer_cmd.h | 6 +++--- .../gpu/runtime/command_buffer_cmd_test.cc | 7 ++++++- .../gpu/runtime/command_buffer_thunk_test.cc | 5 ++++- .../xla/xla/backends/gpu/runtime/memset_thunk.cc | 9 +++++---- .../xla/xla/backends/gpu/runtime/memset_thunk.h | 10 +++++----- .../xla/backends/gpu/runtime/memset_thunk_test.cc | 9 ++++++++- .../xla/xla/backends/gpu/runtime/thunk.proto | 2 +- 9 files changed, 47 insertions(+), 24 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index c40f279f9212dd..856ea1bc586623 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -88,6 +88,7 @@ cc_library( "//xla:util", "//xla/backends/gpu/collectives:gpu_clique_key", "//xla/core/collectives:communicator", + "//xla/core/collectives:reduction_kind", "//xla/ffi:attribute_map", "//xla/ffi:call_frame", "//xla/ffi:execution_state", @@ -101,7 +102,6 @@ cc_library( "//xla/runtime:object_pool", "//xla/runtime:resource_use", "//xla/service:buffer_assignment", - "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:custom_call_status_internal", "//xla/service:custom_call_status_public_headers", @@ -119,6 +119,7 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:tensor_map", 
"//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor/gpu:tma_metadata", "//xla/tsl/lib/gtl:int_type", @@ -152,7 +153,9 @@ xla_test( deps = [ ":command_buffer_cmd", ":copy_thunk", + ":shaped_slice", ":thunk", + "//xla:shape_util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:executable", @@ -359,6 +362,7 @@ xla_test( ":gpublas_lt_matmul_thunk", ":memset_thunk", ":sequential_thunk", + ":shaped_slice", ":thunk", "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -377,6 +381,7 @@ xla_test( "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description", "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_args", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", @@ -1190,6 +1195,7 @@ cc_library( srcs = ["memset_thunk.cc"], hdrs = ["memset_thunk.h"], deps = [ + ":shaped_slice", ":thunk", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc index 37428162a4b910..866e9a9f0b870b 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -62,6 +62,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/thunk.h" #include "xla/backends/gpu/runtime/while_thunk.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/reduction_kind.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/ffi/call_frame.h" @@ -74,7 +75,6 @@ limitations under the License. 
#include "xla/runtime/execution_graph.h" #include "xla/runtime/resource_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" @@ -95,6 +95,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/tensor_map.h" #include "xla/stream_executor/trace_command_buffer_factory.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -1369,7 +1370,7 @@ CommandBufferCmd::BufferUseVector MemcpyDeviceToDeviceCmd::buffers() const { // MemzeroCmd //===----------------------------------------------------------------------===// -MemzeroCmd::MemzeroCmd(BufferAllocation::Slice dst) +MemzeroCmd::MemzeroCmd(ShapedSlice dst) : CommandBufferCmd(CommandBufferCmdType::kMemzeroCmd), dst_(dst) {} absl::StatusOr MemzeroCmd::Record( @@ -1377,12 +1378,12 @@ absl::StatusOr MemzeroCmd::Record( const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer) { se::DeviceAddressBase dst = - execute_params.buffer_allocations->GetDeviceAddress(dst_); + execute_params.buffer_allocations->GetDeviceAddress(dst_.slice); VLOG(5) << "MemzeroCmd:"; VLOG(5) << " Dst: " << dst_ << " (" << dst.opaque() << ")"; - if (dst_.size() == 0) { + if (dst_.slice.size() == 0) { VLOG(5) << "Skip recording MemzeroCmd command of 0 bytes"; return nullptr; } @@ -1391,17 +1392,17 @@ absl::StatusOr MemzeroCmd::Record( std::move(record_action), [&](absl::Span dependencies) { return command_buffer->CreateMemset(&dst, uint8_t{0}, - /*num_elements=*/dst_.size(), + /*num_elements=*/dst_.slice.size(), dependencies); }, [&](const se::CommandBuffer::Command* command) { return command_buffer->UpdateMemset(command, &dst, uint8_t{0}, - /*num_elements=*/dst_.size()); + 
/*num_elements=*/dst_.slice.size()); }); } CommandBufferCmd::BufferUseVector MemzeroCmd::buffers() const { - return {BufferUse::Write(dst_)}; + return {BufferUse::Write(dst_.slice, dst_.shape)}; } //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 40a5b9cff1a7c1..70114c4b8a2e00 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -46,6 +46,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/p2p_thunk_common.h" #include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" +#include "xla/core/collectives/reduction_kind.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/call_frame.h" @@ -56,7 +57,6 @@ limitations under the License. #include "xla/runtime/object_pool.h" #include "xla/runtime/resource_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" @@ -767,7 +767,7 @@ class MemcpyDeviceToDeviceCmd : public CommandBufferCmd { class MemzeroCmd : public CommandBufferCmd { public: - explicit MemzeroCmd(BufferAllocation::Slice dst); + explicit MemzeroCmd(ShapedSlice dst); absl::StatusOr Record( const Thunk::ExecuteParams& execute_params, @@ -777,7 +777,7 @@ class MemzeroCmd : public CommandBufferCmd { BufferUseVector buffers() const override; private: - BufferAllocation::Slice dst_; + ShapedSlice dst_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc 
index 5dffa09d49e184..c8e59ba1093233 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/backends/gpu/runtime/copy_thunk.h" +#include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" @@ -36,6 +37,8 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/platform_util.h" #include "xla/service/service_executable_run_options.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_address.h" #include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" @@ -721,6 +724,8 @@ TEST(CommandBufferCmdTest, NestedChildCmdCreateAndUpdate) { // Prepare device memory for three buffers. int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; + Shape shape = ShapeUtil::MakeShape(S32, {length}); + se::DeviceAddress a = stream_executor->AllocateArray(length); se::DeviceAddress b = @@ -763,7 +768,7 @@ TEST(CommandBufferCmdTest, NestedChildCmdCreateAndUpdate) { CommandBufferCmdSequence outer_seq; outer_seq.Emplace(std::move(middle_executor)); // Add a couple more commands at the outer level that still don't affect `c`. 
- outer_seq.Emplace(slice_b); + outer_seq.Emplace(ShapedSlice{slice_b, shape}); outer_seq.Emplace(); TF_ASSERT_OK_AND_ASSIGN( CommandBufferCmdExecutor outer_executor, diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc index d4a472d5c542e5..fff417ba5dacbe 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_thunk_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/backends/gpu/runtime/memset_thunk.h" #include "xla/backends/gpu/runtime/sequential_thunk.h" +#include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" @@ -54,6 +55,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_args.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -222,6 +224,7 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { int64_t length = 4; int64_t byte_length = sizeof(int32_t) * length; + Shape shape = ShapeUtil::MakeShape(S32, {length}); // Prepare arguments: a=42 se::DeviceAddress a = @@ -234,7 +237,7 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { // Prepare commands sequence for constructing command buffer. 
CommandBufferCmdSequence commands; - commands.Emplace(slice_a); + commands.Emplace(ShapedSlice{slice_a, shape}); TF_ASSERT_OK_AND_ASSIGN( CommandBufferCmdExecutor executor, CommandBufferCmdExecutor::Create(std::move(commands), serialize)); diff --git a/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc index a370a04753620b..75508e2365b0b8 100644 --- a/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/memset_thunk.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/service/buffer_assignment.h" #include "xla/stream_executor/device_address.h" #include "xla/tsl/platform/statusor.h" @@ -30,16 +31,16 @@ namespace gpu { absl::Status MemzeroThunk::ExecuteOnStream(const ExecuteParams& params) { se::DeviceAddressBase dest_data = - params.buffer_allocations->GetDeviceAddress(dest_); + params.buffer_allocations->GetDeviceAddress(dest_.slice); return params.stream->MemZero(&dest_data, dest_data.size()); } absl::StatusOr> MemzeroThunk::FromProto( ThunkInfo thunk_info, const MemzeroThunkProto& thunk_proto, absl::Span buffer_allocations) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest, - BufferAllocation::Slice::FromProto( - thunk_proto.dest_buffer(), buffer_allocations)); + TF_ASSIGN_OR_RETURN( + ShapedSlice dest, + ShapedSlice::FromProto(thunk_proto.dest_buffer(), buffer_allocations)); return std::make_unique(std::move(thunk_info), dest); } diff --git a/third_party/xla/xla/backends/gpu/runtime/memset_thunk.h b/third_party/xla/xla/backends/gpu/runtime/memset_thunk.h index 6b627180c9f5af..aec432fa785d82 100644 --- a/third_party/xla/xla/backends/gpu/runtime/memset_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/memset_thunk.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/gpu/runtime/shaped_slice.h" #include "xla/backends/gpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" @@ -35,17 +36,16 @@ namespace gpu { // Thunk that zeroes out a given chunk of memory. class MemzeroThunk : public Thunk { public: - explicit MemzeroThunk(ThunkInfo thunk_info, - const BufferAllocation::Slice& dest) + explicit MemzeroThunk(ThunkInfo thunk_info, const ShapedSlice& dest) : Thunk(Kind::kMemzero, thunk_info), dest_(dest) {} absl::Status ExecuteOnStream(const ExecuteParams& params) override; - const BufferAllocation::Slice& destination() const { return dest_; } + const ShapedSlice& destination() const { return dest_; } BufferUses buffer_uses() const override { return { - BufferUse::Write(dest_), + BufferUse::Write(dest_.slice, dest_.shape), }; } @@ -56,7 +56,7 @@ class MemzeroThunk : public Thunk { absl::StatusOr ToProto() const override; private: - const BufferAllocation::Slice dest_; + const ShapedSlice dest_; }; // Thunk that sets a given chunk of memory to a particular 32-bit value. 
The diff --git a/third_party/xla/xla/backends/gpu/runtime/memset_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/memset_thunk_test.cc index 0eb1bc60ff2cb3..67a6d8044ab5a4 100644 --- a/third_party/xla/xla/backends/gpu/runtime/memset_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/memset_thunk_test.cc @@ -41,7 +41,14 @@ TEST(MemzeroThunkTest, ProtoRoundTrip) { execution_stream_id: 2 } memzero_thunk { - dest_buffer { offset: 0 size: 4 buffer_allocation_index: 0 } + dest_buffer { + slice { offset: 0 size: 4 buffer_allocation_index: 0 } + shape { + dimensions: 1 + element_type: F32 + is_dynamic_dimension: false + } + } } )pb"); std::vector buffer_allocations = { diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk.proto b/third_party/xla/xla/backends/gpu/runtime/thunk.proto index c34eabae9e45f4..7b9bbf093b6863 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk.proto +++ b/third_party/xla/xla/backends/gpu/runtime/thunk.proto @@ -191,7 +191,7 @@ message DynamicSliceThunkProto { } message MemzeroThunkProto { - xla.buffer_assignment.BufferAllocationSliceProto dest_buffer = 1; + ShapedSliceProto dest_buffer = 1; } message Memset32BitValueThunkProto { From eb43894883764e9732a3118912e454c520969659 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 21:20:16 -0800 Subject: [PATCH 055/753] Automated Code Change PiperOrigin-RevId: 842047558 --- third_party/xla/xla/runtime/buffer_use.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/runtime/buffer_use.cc b/third_party/xla/xla/runtime/buffer_use.cc index 23030c71cffac6..10aeab882b9e27 100644 --- a/third_party/xla/xla/runtime/buffer_use.cc +++ b/third_party/xla/xla/runtime/buffer_use.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" From 8828f2a418fa1f9c2dedc32294529c135f050251 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 22:34:14 -0800 Subject: [PATCH 056/753] Automated Code Change PiperOrigin-RevId: 842069505 --- third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index c6318aaa7f02d1..cfc07e10bfb631 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -22,7 +22,6 @@ #include #include #include -#include #include #include "absl/algorithm/container.h" From 9e24441327c16113bb1e1f578616209dc8ae1bf7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 8 Dec 2025 22:50:48 -0800 Subject: [PATCH 057/753] [stream_executor] Switch SE TPU backend to se::DeviceMemoryAddress Renaming types in preparation for introducing physical memory allocation concept to SE. 
PiperOrigin-RevId: 842074373 --- tensorflow/core/tpu/tpu_execute.cc | 4 ++-- .../xla/stream_executor/tpu/c_api_conversions.cc | 12 ++++++------ .../xla/stream_executor/tpu/c_api_conversions.h | 8 ++++---- .../stream_executor/tpu/c_api_conversions_test.cc | 2 +- .../xla/xla/stream_executor/tpu/c_api_decl.h | 14 ++++---------- .../xla/xla/stream_executor/tpu/tpu_executable.cc | 6 +++--- .../tpu/tpu_executable_interface.cc | 8 ++++---- .../xla/xla/stream_executor/tpu/tpu_executor.h | 4 ++-- .../xla/stream_executor/tpu/tpu_executor_c_api.h | 8 ++++---- .../stream_executor/tpu/tpu_executor_init_fns.inc | 2 +- .../stream_executor/tpu/tpu_executor_interface.h | 9 +++------ 11 files changed, 34 insertions(+), 43 deletions(-) diff --git a/tensorflow/core/tpu/tpu_execute.cc b/tensorflow/core/tpu/tpu_execute.cc index 865683dcb430cf..a8edf650bc1718 100644 --- a/tensorflow/core/tpu/tpu_execute.cc +++ b/tensorflow/core/tpu/tpu_execute.cc @@ -474,7 +474,7 @@ absl::StatusOr TPUExecute( VLOG(1) << "TPUExecute: Updating TPUEmbedding memory addresses on " << device_ordinal; - SE_DeviceMemoryBase* device_memory_addrs = nullptr; + SE_DeviceAddressBase* device_memory_addrs = nullptr; size_t device_memory_addrs_count; auto device_memory_cleanup = absl::MakeCleanup([device_memory_addrs, device_ordinal]() { @@ -501,7 +501,7 @@ absl::StatusOr TPUExecute( for (int i = 0; i < device_memory_addrs_count; ++i) { xla::ShapeTree tree( xla::ShapeUtil::MakeOpaqueShape()); - const SE_DeviceMemoryBase& addr = device_memory_addrs[i]; + const SE_DeviceAddressBase& addr = device_memory_addrs[i]; VLOG(2) << absl::StrFormat("Device memory addr[%i] = {%p, %llu, %llu}", i, addr.opaque, addr.size, addr.payload); *tree.mutable_element({}) = ApiConverter::FromC(addr); diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc index 58eb6c2c3033f9..cb53f7c79336dc 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc 
+++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc @@ -158,9 +158,9 @@ xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer) { return xla_shaped_buffer; } -SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceAddress& mem, - bool aliased) { - SE_MaybeOwningDeviceMemory se_mem; +SE_MaybeOwningDeviceAddress ToC(xla::MaybeOwningDeviceAddress& mem, + bool aliased) { + SE_MaybeOwningDeviceAddress se_mem; se_mem.owned = mem.HasOwnership(); se_mem.memory = ApiConverter::ToC(mem.AsDeviceAddress()); if (mem.HasOwnership()) { @@ -181,7 +181,7 @@ SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceAddress& mem, } xla::MaybeOwningDeviceAddress FromC( - SE_MaybeOwningDeviceMemory* se_mem, + SE_MaybeOwningDeviceAddress* se_mem, stream_executor::DeviceAddressAllocator* allocator) { if (se_mem->owned) { return xla::MaybeOwningDeviceAddress(stream_executor::OwningDeviceAddress( @@ -244,8 +244,8 @@ stream_executor::DeviceAddressAllocator* FromC( c_allocator.ctx); } -SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceAddress* mem) { - SE_MaybeOwningDeviceMemory se_mem; +SE_MaybeOwningDeviceAddress ToC(stream_executor::OwningDeviceAddress* mem) { + SE_MaybeOwningDeviceAddress se_mem; se_mem.device_ordinal = mem->device_ordinal(); se_mem.memory = ApiConverter::ToC(mem->Release()); se_mem.allocator = ApiConverter::ToC(mem->allocator()); diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h index da3db36c17a1d2..cdfcab80fabb69 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.h @@ -118,7 +118,7 @@ struct TpuEmbeddingEngineParametersData { std::unique_ptr Create(int num_tables); xla::MaybeOwningDeviceAddress FromC( - SE_MaybeOwningDeviceMemory* se_mem, + SE_MaybeOwningDeviceAddress* se_mem, stream_executor::DeviceAddressAllocator* allocator); // DeviceAddressAllocator @@ -128,12 +128,12 @@ 
stream_executor::DeviceAddressAllocator* FromC( const SE_DeviceAddressAllocator& c_allocator); // OwningDeviceAddress -SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceAddress* mem); +SE_MaybeOwningDeviceAddress ToC(stream_executor::OwningDeviceAddress* mem); // mem.HasOwnership() may be true if the buffer is aliased and shouldn't be // released. 'aliased' should be true in this case. 'aliased' has no effect if // 'mem' is unowned. -SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceAddress& mem, - bool aliased); +SE_MaybeOwningDeviceAddress ToC(xla::MaybeOwningDeviceAddress& mem, + bool aliased); // HloModule XLA_HloModule ToC(const xla::HloModule& module); diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc index 05ec51c5e79ea8..c96e6be263d884 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -308,7 +308,7 @@ TEST(XlaHloModule, ToAndFromC) { } // TODO(b/290654348): SE_DeviceAddressBase, SE_DeviceAddressAllocator, -// SE_MaybeOwningDeviceMemory +// SE_MaybeOwningDeviceAddress } // namespace diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 834a3da9f4ed0d..a42221294fa16c 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -72,15 +72,11 @@ typedef struct SE_DeviceAddressBase { uint64_t payload; } SE_DeviceAddressBase; -typedef SE_DeviceAddressBase SE_DeviceMemoryBase; - typedef struct SE_ScopedDeviceAddress { SE_DeviceAddressBase wrapped; int device_ordinal; } SE_ScopedDeviceAddress; -typedef SE_ScopedDeviceAddress SE_ScopedDeviceMemory; - typedef struct SE_AllocatorStats { int64_t num_allocs; int64_t bytes_in_use; @@ -117,8 +113,6 @@ typedef struct SE_DeviceAddressAllocator { SE_DeallocateFn deallocate; } 
SE_DeviceAddressAllocator; -typedef SE_DeviceAddressAllocator SE_DeviceMemoryAllocator; - typedef struct SE_DeviceDescription { char* device_vendor; char* platform_version; @@ -175,14 +169,14 @@ typedef struct SE_ExecutableRunOptions { typedef struct SE_ExecutableSerializationHandle SE_ExecutableSerializationHandle; -typedef struct SE_MaybeOwningDeviceMemory { +typedef struct SE_MaybeOwningDeviceAddress { SE_DeviceAddressBase memory; bool owned; // Set if owned int device_ordinal; SE_DeviceAddressAllocator allocator; -} SE_MaybeOwningDeviceMemory; +} SE_MaybeOwningDeviceAddress; typedef struct IntList { union { @@ -279,7 +273,7 @@ typedef struct XLA_Literal { typedef struct XLA_MaybeOwningDeviceAddressShapeTree { XLA_Shape shape; - SE_MaybeOwningDeviceMemory* buffers; + SE_MaybeOwningDeviceAddress* buffers; } XLA_MaybeOwningDeviceAddressShapeTree; typedef struct XLA_ShapeIndex { @@ -296,7 +290,7 @@ typedef struct SE_ExecutionInput { typedef struct SE_ExecutionOutput { XLA_ShapedBuffer result; - SE_MaybeOwningDeviceMemory* to_be_released; + SE_MaybeOwningDeviceAddress* to_be_released; int to_be_released_size; XLA_ShapeIndex* aliased_indices; int aliased_indices_size; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc index 3ef4c531a066fd..b5f5c6d80017ab 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable.cc @@ -102,13 +102,13 @@ absl::StatusOr TpuExecutable::ExecuteAsyncOnStream( ApiConverter::ToC(arg.shape(), &se_args[i]->shape_tree.shape); auto* arg_buffers = arg.MutableBuffers(); - absl::InlinedVector se_buffers; + absl::InlinedVector se_buffers; for (auto& pair : *arg_buffers) { bool aliased = arg.unowned_indices().count(pair.first) > 0; se_buffers.push_back(ApiConverter::ToC(pair.second, aliased)); } se_args[i]->shape_tree.buffers = - new SE_MaybeOwningDeviceMemory[se_buffers.size()]; + new 
SE_MaybeOwningDeviceAddress[se_buffers.size()]; for (int j = 0; j < se_buffers.size(); ++j) { se_args[i]->shape_tree.buffers[j] = se_buffers[j]; } @@ -166,7 +166,7 @@ absl::StatusOr TpuExecutable::ExecuteAsyncOnStream( .Release() .value()); } - ExecutorApiFn()->TpuExecutable_FreeMaybeOwningDeviceMemoryArrayFn( + ExecutorApiFn()->TpuExecutable_FreeMaybeOwningDeviceAddressArrayFn( se_execution_output.to_be_released); return output; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc index ab8616ddc8ecc4..f8080a29d01fb3 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executable_interface.cc @@ -212,7 +212,7 @@ absl::StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( std::vector memory_bases; memory_bases.reserve(arguments.size()); for (auto& argument : arguments) { - memory_bases.push_back(argument.Buffer({}).AsDeviceMemoryBase()); + memory_bases.push_back(argument.Buffer({}).AsDeviceAddress()); } se::Stream* stream = run_options->stream(); @@ -240,16 +240,16 @@ absl::StatusOr TpuExecutableInterface::ExecuteAsyncOnStream( // data from fast memory instead of fresh data in large memory. auto it = arguments[parameter].MutableBuffers()->find({index}); CHECK(it != arguments[parameter].MutableBuffers()->end()); - CHECK(!it->second.AsDeviceMemoryBase().is_null()); + CHECK(!it->second.AsDeviceAddress().is_null()); CHECK(offset); bool is_prefetch_output_alias = absl::c_any_of(result.Result().buffers(), [&](auto index_addr_pair) { return index_addr_pair.second.IsSameAs( - it->second.AsDeviceMemoryBase()); + it->second.AsDeviceAddress()); }); cross_program_prefetch_addrs.emplace_back( is_prefetch_output_alias ? stream_executor::DeviceAddressBase() - : it->second.AsDeviceMemoryBase()); + : it->second.AsDeviceAddress()); cross_program_prefetch_offsets.emplace_back( is_prefetch_output_alias ? 
std::numeric_limits::max() : *offset); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index 8209ec55e0b12c..7edaba7a8e9fdd 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -118,8 +118,8 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { // to TpuExecutorInterface. absl::StatusOr> - CreateTemporaryDeviceMemory(int64_t memory_space, int64_t byte_offset, - int64_t size) override { + CreateTemporaryDeviceAddress(int64_t memory_space, int64_t byte_offset, + int64_t size) override { LOG(FATAL) << "Unimplemented."; } diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h index ce57d254450d4e..7f2c1b02e094b2 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h @@ -298,12 +298,12 @@ TFTPU_CAPI_EXPORT void TpuExecutable_ExecuteAsyncOnStream( TFTPU_CAPI_EXPORT void TpuExecutable_FreeXlaShapeIndexArray( XLA_ShapeIndex* array); -// This frees the SE_MaybeOwningDeviceMemory* array allocated when se_output is +// This frees the SE_MaybeOwningDeviceAddress* array allocated when se_output is // returned by TpuExecutable_ExecuteAsyncOnStream. // Note that this only frees the heap-allocated array itself, and does not // free any of the underlying device memory. 
-TFTPU_CAPI_EXPORT void TpuExecutable_FreeMaybeOwningDeviceMemoryArray( - SE_MaybeOwningDeviceMemory* array); +TFTPU_CAPI_EXPORT void TpuExecutable_FreeMaybeOwningDeviceAddressArray( + SE_MaybeOwningDeviceAddress* array); TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable, const char** fingerprint, @@ -479,7 +479,7 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream); TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_FreeXlaShapeIndexArray); - TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_FreeMaybeOwningDeviceMemoryArray); + TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_FreeMaybeOwningDeviceAddressArray); TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint); TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Serialize); TFTPU_ADD_FN_IN_STRUCT(TpuExecutableSerialize_GetByteSize); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc b/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc index 5bc6a8ac9c4086..ee02abad1bf401 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc @@ -125,7 +125,7 @@ absl::Status SetExecutorStructFn( TFTPU_SET_FN(executor_fn, TpuCompiler_DefaultDeviceShapeRepresentation); TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream); TFTPU_SET_FN(executor_fn, TpuExecutable_FreeXlaShapeIndexArray); - TFTPU_SET_FN(executor_fn, TpuExecutable_FreeMaybeOwningDeviceMemoryArray); + TFTPU_SET_FN(executor_fn, TpuExecutable_FreeMaybeOwningDeviceAddressArray); TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint); TFTPU_SET_FN(executor_fn, TpuExecutable_Serialize); TFTPU_SET_FN(executor_fn, TpuExecutableSerialize_GetByteSize); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_interface.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor_interface.h index 6012bb3752dd4f..db95ca86242a95 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_interface.h +++ 
b/third_party/xla/xla/stream_executor/tpu/tpu_executor_interface.h @@ -42,15 +42,12 @@ class TpuExecutorInterface : public stream_executor::StreamExecutorCommon { class TemporaryDeviceAddress { public: virtual ~TemporaryDeviceAddress() {} - virtual stream_executor::DeviceAddressBase AsDeviceMemoryBase() const = 0; + virtual stream_executor::DeviceAddressBase AsDeviceAddress() const = 0; }; - using TemporaryDeviceMemory ABSL_DEPRECATE_AND_INLINE() = - TemporaryDeviceAddress; - virtual absl::StatusOr> - CreateTemporaryDeviceMemory(int64_t memory_space, int64_t byte_offset, - int64_t size) { + CreateTemporaryDeviceAddress(int64_t memory_space, int64_t byte_offset, + int64_t size) { LOG(FATAL) << "Unimplemented."; } From 7b476a6047be6de8899eb1cae660ac84ea3bc3db Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 8 Dec 2025 23:04:22 -0800 Subject: [PATCH 058/753] Automated Code Change PiperOrigin-RevId: 842078054 --- third_party/xla/xla/service/spmd/BUILD | 1 + third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h | 1 + 2 files changed, 2 insertions(+) diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index ff5fdaf4ea5b0b..30ab20ab0b5cc6 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -279,6 +279,7 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service:call_graph", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", ], diff --git a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h index fc0e1269962e53..3c39561468aa99 100644 --- a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" From 419991f7da1549cadeecb5a1ee3166075fb57726 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 8 Dec 2025 23:42:08 -0800 Subject: [PATCH 059/753] [xla:gpu] Document CommandBufferCmd statelessness and StateManager PiperOrigin-RevId: 842088918 --- .../backends/gpu/runtime/command_buffer_cmd.h | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h index 70114c4b8a2e00..31dc41b43596d0 100644 --- a/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -118,6 +118,8 @@ std::string CommandBufferCmdString(CommandBufferCmdType type); // CommandBufferCmd //===----------------------------------------------------------------------===// +using ResourceUseVector = absl::InlinedVector; + // Command is a Thunk counterpart that instead of launching operations directly // on the underlying device records them into command buffers. // @@ -127,9 +129,41 @@ std::string CommandBufferCmdString(CommandBufferCmdType type); // // Commands must be thread safe as they can be recorded into multiple command // buffers concurrently on different stream executors. - -using ResourceUseVector = absl::InlinedVector; - +// +// IMPORTANT: In contrast to GPU thunks, commands MUST be stateless. Thunk state +// typically belongs to the Thunk instance itself, and tends to be kept in +// synchronized hash maps keyed by `se::StreamExecutor*` pointer. 
Commands on +// the other hand should attach state to the underlying command buffer, and +// because the number of command buffers that can be instantiated from a command +// sequence is unbounded (as we have an eviction policy for command buffers), +// keeping a state in a map inside the command will lead to memory leaks. +// +// Commands have an external state manager, which is responsible for managing +// the lifetime of command state. See `State` and `StateManager` classes below. +// +// To make command stateful, it needs a `params.state` indirection: +// +// class MyCommand : public CommandBufferCmd { +// public: +// +// // Container for mutable state required for command execution. +// struct MyState : CommandBufferCmd::State { +// ... +// }; +// +// absl::StatusOr Record(...) override { +// // Attach a new instance of `MyState` to the `command_buffer`. When +// // command buffer will be destroyed, the state will be destroyed as +// // well automatically by XLA runtime. If this command will be recorded +// // into another command buffer, the state will be re-created +// // automatically using the provided callback. +// MyState* my_state = record_params.state.GetOrCreate(this, +// command_buffer, [&] { // create MyState for a `command_buffer` }); +// ... +// } +// +// }; +// class CommandBufferCmd { public: explicit CommandBufferCmd( @@ -156,6 +190,8 @@ class CommandBufferCmd { // Externally managed state (owned and synchronized by CommandBufferThunk) // allows commands to attach a piece of information to command buffer in a // safe and performant way. + // + // See example above next to `CommandBufferCmd` definition. class State { public: virtual ~State() = default; From d4e51d88fb7334d79da06cdb40f5006478c44157 Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Tue, 9 Dec 2025 00:25:12 -0800 Subject: [PATCH 060/753] Replace std::copy with absl::c_copy for readability. This change refactors usages of std::copy(container.begin(), container.end(), ...) 
to the more compact absl::c_copy(container, ...). This improves readability and reduces verbosity. Necessary includes for absl/algorithm/container.h have been added where required. PiperOrigin-RevId: 842102140 --- .../gpu/runtime/select_k_exec_raft_test.cc | 2 +- third_party/xla/xla/service/BUILD | 11 ++++++++--- .../xla/xla/service/cpu/onednn_memory_util.cc | 3 +-- third_party/xla/xla/service/gpu/transforms/BUILD | 1 + .../xla/service/gpu/transforms/async_wrapper.cc | 7 +++---- .../xla/service/gpu/triton_tiling_propagation.cc | 10 +++++----- third_party/xla/xla/service/hlo_sharding_test.cc | 10 ++++++++-- third_party/xla/xla/service/shape_inference.cc | 15 ++++++--------- .../xla/xla/service/triangular_solve_expander.cc | 7 +++---- third_party/xla/xla/stream_executor/dnn.cc | 6 ++---- .../xla/stream_executor/tpu/c_api_conversions.cc | 5 ++--- 11 files changed, 40 insertions(+), 37 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft_test.cc b/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft_test.cc index 623a0dd23c61f9..5cb759a96f3f01 100644 --- a/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/select_k_exec_raft_test.cc @@ -106,7 +106,7 @@ void RunSelectKTest() { std::vector h_data_in(batch * n); for (int j = 0; j < batch; ++j) { std::shuffle(topk.begin(), topk.end(), gen); - std::copy(topk.begin(), topk.end(), h_data_in.begin() + j * n); + absl::c_copy(topk, h_data_in.begin() + j * n); } // Compute golden Top-K values for verification diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 52e39430e85f9d..b5d097d79b4715 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -940,6 +940,9 @@ xla_cc_test( name = "hlo_sharding_test", srcs = ["hlo_sharding_test.cc"], deps = [ + "//xla:array", + "//xla:array3d", + "//xla:array4d", "//xla:shape_tree", "//xla:shape_util", 
"//xla:xla_data_proto_cc", @@ -950,8 +953,11 @@ xla_cc_test( "//xla/hlo/testlib:test_helpers", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/platform:statusor", "//xla/tsl/util/proto:proto_matchers", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/hash", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) @@ -2280,7 +2286,6 @@ cc_library( hdrs = ["triangular_solve_expander.h"], deps = [ ":hlo_creation_utils", - ":hlo_module_config", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", @@ -2292,14 +2297,14 @@ cc_library( "//xla/hlo/builder/lib:slicing", "//xla/hlo/ir:hlo", "//xla/hlo/transforms/expanders:op_expander_pass", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.cc b/third_party/xla/xla/service/cpu/onednn_memory_util.cc index 2233c84d814d38..94cf1d16e8cfac 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.cc @@ -59,8 +59,7 @@ MemrefInfoHandler CreateMemrefFromShape(const Shape& shape, const void* buf) { result->dtype = shape.element_type(); result->rank = shape.dimensions().size(); auto dimensions = shape.dimensions(); - std::copy(dimensions.begin(), dimensions.end(), - absl::MakeSpan(result->dims).begin()); + absl::c_copy(dimensions, absl::MakeSpan(result->dims).begin()); int64_t stride = 1; for (int i : shape.layout().minor_to_major()) { diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 
e76b5f9ee12da0..e4ed891c6391b9 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -160,6 +160,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/tsl/platform:errors", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc index a96615edca6eb3..7529a1c4be4d45 100644 --- a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc @@ -15,10 +15,10 @@ limitations under the License. #include "xla/service/gpu/transforms/async_wrapper.h" -#include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -70,9 +70,8 @@ absl::StatusOr AsyncWrapper::RunImpl( // instructions that can potentially be made async. if (HloPredicateIsOp(instruction)) { - std::copy(instruction->called_computations().begin(), - instruction->called_computations().end(), - std::back_inserter(computations)); + absl::c_copy(instruction->called_computations(), + std::back_inserter(computations)); } } } diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 437a2269739cf7..926307d6f0c0e7 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "xla/service/gpu/triton_tiling_propagation.h" -#include #include #include #include @@ -260,7 +259,9 @@ TensorIterationSpec DimensionOrder::ToTensorIterationSpec() const { // We should not remove the only fragment in a dimension, because if it is // removed, the dimension will be removed from the TensorIterationSpec. - if (dim_spec.size() <= 1) continue; + if (dim_spec.size() <= 1) { + continue; + } TensorIterationSpec::DimIterationSpec filtered_dim_spec; absl::c_copy_if(dim_spec, std::back_inserter(filtered_dim_spec), @@ -575,9 +576,8 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( std::vector& dst = dst_dim_fragment_orders[dim_index]; dst.reserve(dim_sequence.size()); for (const int src : dim_sequence) { - std::copy(src_to_dst[&src_fragments_order[src]].cbegin(), - src_to_dst[&src_fragments_order[src]].cend(), - std::back_inserter(dst)); + absl::c_copy(src_to_dst[&src_fragments_order[src]], + std::back_inserter(dst)); } } diff --git a/third_party/xla/xla/service/hlo_sharding_test.cc b/third_party/xla/xla/service/hlo_sharding_test.cc index ee87360d9c2c2c..bb97ca0578e0a7 100644 --- a/third_party/xla/xla/service/hlo_sharding_test.cc +++ b/third_party/xla/xla/service/hlo_sharding_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include @@ -21,16 +20,23 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/hash/hash.h" +#include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/array.h" +#include "xla/array3d.h" +#include "xla/array4d.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/test.h" #include "xla/hlo/testlib/test_helpers.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/tsl/util/proto/proto_matchers.h" #include "xla/xla_data.pb.h" @@ -42,7 +48,7 @@ using ::tsl::proto_testing::EqualsProto; Array MakeArray(absl::Span dimensions, absl::Span contents) { Array a(dimensions); - std::copy(contents.begin(), contents.end(), a.begin()); + absl::c_copy(contents, a.begin()); return a; } diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index b5dc87cbf44ee7..8bc3ccb4987186 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -2189,22 +2189,19 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { std::vector input_dnums(num_dims); input_dnums[0] = dnums.input_batch_dimension(); input_dnums[1] = dnums.input_feature_dimension(); - std::copy(dnums.input_spatial_dimensions().begin(), - dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); + absl::c_copy(dnums.input_spatial_dimensions(), input_dnums.begin() + 2); absl::c_sort(input_dnums); std::vector window_dnums(num_dims); window_dnums[0] = dnums.kernel_input_feature_dimension(); window_dnums[1] = dnums.kernel_output_feature_dimension(); - std::copy(dnums.kernel_spatial_dimensions().begin(), - dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); + absl::c_copy(dnums.kernel_spatial_dimensions(), window_dnums.begin() + 2); 
absl::c_sort(window_dnums); std::vector output_dnums(num_dims); output_dnums[0] = dnums.output_batch_dimension(); output_dnums[1] = dnums.output_feature_dimension(); - std::copy(dnums.output_spatial_dimensions().begin(), - dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); + absl::c_copy(dnums.output_spatial_dimensions(), output_dnums.begin() + 2); absl::c_sort(output_dnums); std::vector expected_dnums(num_dims); @@ -3590,9 +3587,9 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { std::vector dimensions(operand.dimensions().size() + broadcast_sizes.size()); - std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin()); - std::copy(operand.dimensions().begin(), operand.dimensions().end(), - dimensions.begin() + broadcast_sizes.size()); + absl::c_copy(broadcast_sizes, dimensions.begin()); + absl::c_copy(operand.dimensions(), + dimensions.begin() + broadcast_sizes.size()); TF_ASSIGN_OR_RETURN(Shape result, ShapeUtil::MakeValidatedShape( operand.element_type(), dimensions)); diff --git a/third_party/xla/xla/service/triangular_solve_expander.cc b/third_party/xla/xla/service/triangular_solve_expander.cc index 049249aa5b0481..5c8577a47eca98 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.cc +++ b/third_party/xla/xla/service/triangular_solve_expander.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -36,13 +37,11 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla { @@ -120,7 +119,7 @@ XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) { TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks)); auto shape_dims = blocks_shape.dimensions(); auto last_blocks_dims = std::vector(ndims); - std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin()); + absl::c_copy(shape_dims, last_blocks_dims.begin()); last_blocks_dims.insert(last_blocks_dims.end() - 2, 1); last_blocks = Reshape(last_blocks, last_blocks_dims); diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index f38a2597972d75..d837220e6f4fe4 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -691,8 +691,7 @@ std::vector BatchDescriptor::full_dims( std::vector bdyx_dims(ndims() + 2); bdyx_dims[0] = count(); bdyx_dims[1] = feature_map_count(); - std::copy(spatial_size().begin(), spatial_size().end(), - bdyx_dims.begin() + 2); + absl::c_copy(spatial_size(), bdyx_dims.begin() + 2); return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout); } @@ -831,8 +830,7 @@ std::vector FilterDescriptor::full_dims( std::vector oiyx_dims(ndims() + 2); oiyx_dims[0] = output_feature_map_count(); oiyx_dims[1] = input_feature_map_count(); - std::copy(input_filter_dims().begin(), input_filter_dims().end(), - oiyx_dims.begin() + 2); + absl::c_copy(input_filter_dims(), oiyx_dims.begin() + 2); return ReorderDims(oiyx_dims, FilterLayout::kOutputInputYX, layout); } diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc 
b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc index cb53f7c79336dc..b4aefe96cc6995 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/stream_executor/tpu/c_api_conversions.h" -#include #include #include #include @@ -66,9 +65,9 @@ static void CreateVectorBase(const absl::Span src, DstList* dst) { dst->size = src.size(); if (dst->size > TPU_C_API_MAX_INLINED) { dst->heap = new Dst[dst->size]; - std::copy(src.begin(), src.end(), dst->heap); + absl::c_copy(src, dst->heap); } else { - std::copy(src.begin(), src.end(), dst->inlined); + absl::c_copy(src, dst->inlined); } } From 06a58e00b8afa55222f4efaeadd4770b3c712b9e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Dec 2025 00:29:19 -0800 Subject: [PATCH 061/753] Automated Code Change PiperOrigin-RevId: 842103155 --- tensorflow/lite/tools/utils.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/lite/tools/utils.cc b/tensorflow/lite/tools/utils.cc index 6173ec1b112203..96b8bf8689e610 100644 --- a/tensorflow/lite/tools/utils.cc +++ b/tensorflow/lite/tools/utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive From 51058956fec5e3eda3d2e81338b26244378f2bcf Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 9 Dec 2025 01:03:12 -0800 Subject: [PATCH 062/753] compat: Update forward compatibility horizon to 2025-12-09 PiperOrigin-RevId: 842113241 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 019f2360af662e..949094ad18d927 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 12, 8) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 12, 9) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From e534b1654b447e26e2e22e78afe236067cc46d62 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Dec 2025 01:04:04 -0800 Subject: [PATCH 063/753] Update GraphDef version to 2436. PiperOrigin-RevId: 842113538 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 5448bf12c3dcfe..b483429b89ccd9 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -93,7 +93,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 2435 // Updated: 2025/12/8 +#define TF_GRAPH_DEF_VERSION 2436 // Updated: 2025/12/9 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From 6cfe46b1e50d9be83db107c9fac159775e74b6a7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 9 Dec 2025 01:17:28 -0800 Subject: [PATCH 064/753] Reverts dbd604a06c501bb6dcfe9448a4582ef586539855 PiperOrigin-RevId: 842117724 --- third_party/xla/xla/debug_options_flags.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index fe8c14f18dbd8c..e7473940048f0a 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -416,7 +416,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_autotune_gemm_rtol(0.1f); - opts.set_xla_enable_command_buffers_during_profiling(true); + // TODO(b/355487968): Remove this flag once all data will be presented in + // xprof with command buffers. + opts.set_xla_enable_command_buffers_during_profiling(false); opts.set_xla_gpu_cudnn_gemm_max_plans(5); From b722a687c6faa1aff94584b6cb3cdefc69d2256f Mon Sep 17 00:00:00 2001 From: Alex Pivovarov Date: Tue, 9 Dec 2025 01:24:24 -0800 Subject: [PATCH 065/753] Add nullptr comparison and boolean conversion to MaybeOwning. This change introduces `operator==`, `operator!=` for comparing `MaybeOwning` with `nullptr_t`, and an `explicit operator bool()` to `MaybeOwning`. These allow for more idiomatic checks against null. Updated several call sites to use these new operators, simplifying expressions like `obj.get() == nullptr` to `obj == nullptr` or `obj.get() != nullptr` to `obj`. 
PiperOrigin-RevId: 842119942 --- .../xla/backends/cpu/nanort/ifrt_client.cc | 2 +- third_party/xla/xla/hlo/ir/hlo_module.cc | 2 +- third_party/xla/xla/maybe_owning.h | 26 +++++++++++++++++++ .../xla/xla/service/gpu/gpu_compiler.cc | 6 ++--- .../xla/xla/tsl/concurrency/async_value_ref.h | 2 +- 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc index 733dd00f6eb2f7..f3731f64821230 100644 --- a/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc +++ b/third_party/xla/xla/backends/cpu/nanort/ifrt_client.cc @@ -511,7 +511,7 @@ class NanoArray final : public NanoValue { OwnedDataPtr owned_data( tsl::port::AlignedMalloc(std::max(size, Align()), Align()), [](void* ptr) { tsl::port::AlignedFree(ptr); }); - if (ABSL_PREDICT_FALSE(owned_data.get() == nullptr)) { + if (ABSL_PREDICT_FALSE(owned_data == nullptr)) { return Internal("Failed to allocate memory for NanoArray. Errno: %s", strerror(errno)); } diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 4656f2a442ca91..fe9502af64f15c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -1751,7 +1751,7 @@ void HloModule::OriginalValueRecoveryTable::AddRecoveryComputation( std::optional* new_original_array = new_inst->original_value()->mutable_original_array(shape_index); if (!*new_original_array) { - if (recovery_computation->get() == nullptr) { + if (*recovery_computation == nullptr) { // If the recovery computation is a nullptr, it means this is an // identity computation and we can just pass through the original array. 
new_original_array->emplace(*old_original_array); diff --git a/third_party/xla/xla/maybe_owning.h b/third_party/xla/xla/maybe_owning.h index 2b63a45543375d..04bd39a670bea3 100644 --- a/third_party/xla/xla/maybe_owning.h +++ b/third_party/xla/xla/maybe_owning.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_MAYBE_OWNING_H_ #define XLA_MAYBE_OWNING_H_ +#include #include #include @@ -76,6 +77,31 @@ class MaybeOwning final { bool OwnsPtr() const { return kOwningBitMask & ptr_and_owning_bit_; } + friend bool operator==(const MaybeOwning& mo, std::nullptr_t) { + // A MaybeOwning is considered null if its internal pointer is null. + // The get() method correctly removes the mask and returns the raw pointer. + return mo.get() == nullptr; + } + + friend bool operator==(std::nullptr_t, const MaybeOwning& mo) { + // Maintain symmetry for the comparison order + return mo.get() == nullptr; + } + + friend bool operator!=(const MaybeOwning& mo, std::nullptr_t) { + return mo.get() != nullptr; + } + + friend bool operator!=(std::nullptr_t, const MaybeOwning& mo) { + return mo.get() != nullptr; + } + + explicit operator bool() const { + // The class is considered 'true' if the underlying pointer is not null. + // We use the existing get() method, which correctly handles the mask. + return get() != nullptr; + } + private: enum : uint64_t { kOwningBitMask = 1UL, diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 55c74ba56d3ecb..a98c18b53adc89 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -56,6 +56,7 @@ limitations under the License. 
#include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/SplitModule.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "google/protobuf/text_format.h" #include "xla/backends/cpu/nanort/nanort_client.h" @@ -69,7 +70,6 @@ limitations under the License. #include "xla/core/host_offloading/hlo_host_device_type_call_wrapper.h" #include "xla/core/host_offloading/host_compute_asyncifier.h" #include "xla/hlo/analysis/alias_info.h" -#include "xla/hlo/analysis/symbolic_expr.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -2224,7 +2224,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( // function per module. If caching is not used limit the number of modules to // the number of threads. int num_modules = CountFunctions(*llvm_module); - if (thread_pool.get() != nullptr && !use_cache) { + if (thread_pool && !use_cache) { num_modules = std::max(1, std::min(thread_pool->NumThreads(), num_modules)); } if (compile_module_results.llvm_module_constants != nullptr) { @@ -2262,7 +2262,7 @@ absl::StatusOr GpuCompiler::CompileAndLink( absl::StatusOr result; }; std::vector compile_results(llvm_modules.size()); - if (thread_pool.get() != nullptr) { + if (thread_pool) { absl::BlockingCounter counter(llvm_modules.size()); for (int i = 0; i < llvm_modules.size(); ++i) { thread_pool.get_mutable()->Schedule( diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.h b/third_party/xla/xla/tsl/concurrency/async_value_ref.h index 83825c973a4e5f..1d72dbb4cd05b0 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref.h +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.h @@ -357,7 +357,7 @@ class AsyncValueRef { SetError(absl::InternalError(message_view)); } - explicit operator bool() const { return value_.get() != nullptr; } + explicit operator bool() const { return value_ != nullptr; } bool operator==(const 
AsyncValueRef& r) const { return value_ == r.value_; } bool operator!=(const AsyncValueRef& r) const { return value_ != r.value_; } From 7ccf06445fa083e296d135a6ca37753ffbf15e06 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Dec 2025 01:28:36 -0800 Subject: [PATCH 066/753] Automated Code Change PiperOrigin-RevId: 842121699 --- tensorflow/cc/BUILD | 1 + tensorflow/cc/gradients/array_grad.cc | 2 ++ tensorflow/cc/gradients/image_grad.cc | 1 + tensorflow/cc/gradients/math_grad.cc | 1 + tensorflow/cc/gradients/nn_grad.cc | 1 + 5 files changed, 6 insertions(+) diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index bfa665a09f7588..3131284b4802bd 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -359,6 +359,7 @@ cc_library( "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index f3c3fd045a3d6f..f0189c60c714e1 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -14,9 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/ops/array_ops_internal.h" diff --git a/tensorflow/cc/gradients/image_grad.cc b/tensorflow/cc/gradients/image_grad.cc index deb90eec264ee7..bb37c90b3f32a8 100644 --- a/tensorflow/cc/gradients/image_grad.cc +++ b/tensorflow/cc/gradients/image_grad.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "absl/status/status.h" diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index c785af15f95447..af39009ad3f2a5 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/status/status.h" diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 6309080492c1da..9b980bd9e8321d 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include From 25ed31c1a6c0a135239618b6d885385a6d79701b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Dec 2025 01:49:00 -0800 Subject: [PATCH 067/753] Automated Code Change PiperOrigin-RevId: 842128029 --- .../xla/backends/gpu/codegen/triton/xtile_compiler_stub_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/xtile_compiler_stub_test.cc b/third_party/xla/xla/backends/gpu/codegen/triton/xtile_compiler_stub_test.cc index 5b47611387f732..12216068683e2d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/xtile_compiler_stub_test.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/xtile_compiler_stub_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" From 41fddcf21863db7f7dbb855a5867858f8726d90a Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 9 Dec 2025 01:57:02 -0800 Subject: [PATCH 068/753] Automated Code Change PiperOrigin-RevId: 842130465 --- .../xla/xla/backends/autotuner/file_based_autotuner_cache.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/xla/xla/backends/autotuner/file_based_autotuner_cache.cc b/third_party/xla/xla/backends/autotuner/file_based_autotuner_cache.cc index 969286250aa1e9..d9aee5dcd69dd8 100644 --- a/third_party/xla/xla/backends/autotuner/file_based_autotuner_cache.cc +++ b/third_party/xla/xla/backends/autotuner/file_based_autotuner_cache.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" From f90cdb544e97accf09780a2c8810cab6a3a305fb Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Tue, 9 Dec 2025 02:01:31 -0800 Subject: [PATCH 069/753] PR #34898: [GPU] Do not float-normalize bf16 negation and abs. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34898 📝 Summary of Changes Avoid unnecessary type casts - bf16 negation and abs are supported in PTX. 🚀 Kind of Contribution ♻️ Cleanup 🧪 Unit Tests: yes 🧪 Execution Tests: no Copybara import of the project: -- 867f131cccba2df2cbc61d584ebc238cb0aceeae by Ilia Sergachev : [GPU] Do not float-normalize bf16 negation and abs. 
Merging this change closes #34898 PiperOrigin-RevId: 842132075 --- .../xla/xla/service/gpu/gpu_float_support.cc | 2 ++ .../xla/service/gpu/gpu_float_support_test.cc | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/third_party/xla/xla/service/gpu/gpu_float_support.cc b/third_party/xla/xla/service/gpu/gpu_float_support.cc index cb0477bf19b9a9..6aa7e4b1ec1f68 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support.cc @@ -131,8 +131,10 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { return compute_capability_.IsCuda(); } return false; + case HloOpcode::kAbs: case HloOpcode::kMaximum: case HloOpcode::kMinimum: + case HloOpcode::kNegate: if (LowPrecisionType() == BF16) { auto* cuda_compute_capability = compute_capability_.cuda_compute_capability(); diff --git a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc index f464b670a57701..bd88890113d1a5 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc @@ -432,6 +432,35 @@ ENTRY main { se::GpuComputeCapability{se::CudaComputeCapability::Volta()}, BF16, F32)); } +class Bf16UnaryOpTest : public FloatSupportTest, + public ::testing::WithParamInterface {}; + +TEST_P(Bf16UnaryOpTest, IsOnlyNormalizedPreAmpere) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule( + absl::Substitute(R"( +entry { + a = bf16[] parameter(0) + r = bf16[] $0(a) +})", + HloOpcodeString(GetParam())))); + EXPECT_FALSE( + Normalize(module.get(), + se::GpuComputeCapability{se::CudaComputeCapability::Hopper()}, + BF16, F32)); + EXPECT_FALSE( + Normalize(module.get(), + se::GpuComputeCapability{se::CudaComputeCapability::Ampere()}, + BF16, F32)); + EXPECT_TRUE(Normalize( + module.get(), + se::GpuComputeCapability{se::CudaComputeCapability::Volta()}, BF16, F32)); +} + 
+INSTANTIATE_TEST_SUITE_P(Bf16UnaryOps, Bf16UnaryOpTest, + ::testing::Values(HloOpcode::kNegate, + HloOpcode::kAbs)); + TEST_F(FloatSupportTest, BF16ReductionOnHopperIsOnlyNormalizedIfReducerIsUnsupported) { auto cc = se::CudaComputeCapability::Hopper(); From a337f310754bf04bb0d69706cec11380a80d8485 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 9 Dec 2025 02:19:24 -0800 Subject: [PATCH 070/753] PR #34964: Bump github/codeql-action from 4.31.6 to 4.31.7 Imported from GitHub PR https://github.com/openxla/xla/pull/34964 Bumps [github/codeql-action](https://github.com/github/codeql-action) from 4.31.6 to 4.31.7.
Release notes

Sourced from github/codeql-action's releases.

v4.31.7

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

4.31.7 - 05 Dec 2025

  • Update default CodeQL bundle version to 2.23.7. #3343

See the full CHANGELOG.md for more information.

Changelog

Sourced from github/codeql-action's changelog.

CodeQL Action Changelog

See the releases page for the relevant changes to the CodeQL CLI and language packs.

[UNRELEASED]

No user facing changes.

4.31.7 - 05 Dec 2025

  • Update default CodeQL bundle version to 2.23.7. #3343

4.31.6 - 01 Dec 2025

No user facing changes.

4.31.5 - 24 Nov 2025

  • Update default CodeQL bundle version to 2.23.6. #3321

4.31.4 - 18 Nov 2025

No user facing changes.

4.31.3 - 13 Nov 2025

  • CodeQL Action v3 will be deprecated in December 2026. The Action now logs a warning for customers who are running v3 but could be running v4. For more information, see Upcoming deprecation of CodeQL Action v3.
  • Update default CodeQL bundle version to 2.23.5. #3288

4.31.2 - 30 Oct 2025

No user facing changes.

4.31.1 - 30 Oct 2025

  • The add-snippets input has been removed from the analyze action. This input has been deprecated since CodeQL Action 3.26.4 in August 2024 when this removal was announced.

4.31.0 - 24 Oct 2025

  • Bump minimum CodeQL bundle version to 2.17.6. #3223
  • When SARIF files are uploaded by the analyze or upload-sarif actions, the CodeQL Action automatically performs post-processing steps to prepare the data for the upload. Previously, these post-processing steps were only performed before an upload took place. We are now changing this so that the post-processing steps will always be performed, even when the SARIF files are not uploaded. This does not change anything for the upload-sarif action. For analyze, this may affect Advanced Setup for CodeQL users who specify a value other than always for the upload input. #3222

4.30.9 - 17 Oct 2025

  • Update default CodeQL bundle version to 2.23.3. #3205
  • Experimental: A new setup-codeql action has been added which is similar to init, except it only installs the CodeQL CLI and does not initialize a database. Do not use this in production as it is part of an internal experiment and subject to change at any time. #3204

4.30.8 - 10 Oct 2025

No user facing changes.

... (truncated)

Commits
  • cf1bb45 Merge pull request #3344 from github/update-v4.31.7-f5c63fadd
  • f4ebe95 Update changelog for v4.31.7
  • f5c63fa Merge pull request #3343 from github/update-bundle/codeql-bundle-v2.23.7
  • a2c01e7 Add changelog note
  • ac34c13 Update default bundle to codeql-bundle-v2.23.7
  • 267c467 Merge pull request #3339 from github/dependabot/npm_and_yarn/npm-minor-77d264...
  • aeabef7 Merge branch 'main' into dependabot/npm_and_yarn/npm-minor-77d26487b0
  • 78357d3 Merge pull request #3341 from github/mbg/ci/update-cs-config-cli-tests
  • d61a6fa Update CLI config test to account for overlay db changes on PRs
  • ce27e95 Rebuild
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github/codeql-action&package-manager=github_actions&previous-version=4.31.6&new-version=4.31.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Copybara import of the project: -- 0321b497362c7e4020514d78607ca2d0069f6c89 by dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>: Bump github/codeql-action from 4.31.6 to 4.31.7 Bumps [github/codeql-action](https://github.com/github/codeql-action) from 4.31.6 to 4.31.7. - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/fe4161a26a8629af62121b670040955b330f9af2...cf1bb45a277cb3c205638b2cd5c984db1c46a412) --- updated-dependencies: - dependency-name: github/codeql-action dependency-version: 4.31.7 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Merging this change closes #34964 PiperOrigin-RevId: 842137219 --- third_party/xla/.github/workflows/scorecards-analysis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xla/.github/workflows/scorecards-analysis.yml b/third_party/xla/.github/workflows/scorecards-analysis.yml index f781a8bcb93b8a..d2bf9a77ef7ab6 100644 --- a/third_party/xla/.github/workflows/scorecards-analysis.yml +++ b/third_party/xla/.github/workflows/scorecards-analysis.yml @@ -67,6 +67,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). 
# Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@fe4161a26a8629af62121b670040955b330f9af2 # v4.31.6 + uses: github/codeql-action/upload-sarif@cf1bb45a277cb3c205638b2cd5c984db1c46a412 # v4.31.7 with: sarif_file: results.sarif From aa5af21c0e39c1a91ed074f846aa8160fd5521af Mon Sep 17 00:00:00 2001 From: Terry Sun Date: Tue, 9 Dec 2025 02:46:25 -0800 Subject: [PATCH 071/753] PR #34864: Update link in GPU flag guidance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34864 📝 Summary of Changes Update link in GPU flag guidance. 🎯 Justification The originally linked page has moved, so the link needs updating. 🚀 Kind of Contribution 📚 Documentation 📊 Benchmark (for Performance Improvements) N/A. 🧪 Unit Tests: N/A. 🧪 Execution Tests: N/A. Copybara import of the project: -- 27abae70ae87bc005166ddbedfb2c1a0bd15f3f8 by Terry Sun : update link Merging this change closes #34864 PiperOrigin-RevId: 842145979 --- third_party/xla/docs/flags_guidance.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/xla/docs/flags_guidance.md b/third_party/xla/docs/flags_guidance.md index f887b66c4fbc50..c973a1a665ca7c 100644 --- a/third_party/xla/docs/flags_guidance.md +++ b/third_party/xla/docs/flags_guidance.md @@ -79,8 +79,7 @@ data-parallel collectives (`xla_gpu_enable_pipelined_all_gather`, (`xla_gpu_enable_while_loop_double_buffering`), latency hiding scheduling (`xla_gpu_enable_latency_hiding_scheduler`), and SOL latency estimator on Hopper/Blackwell (`xla_gpu_enable_analytical_sol_latency_estimator`). See -[GPU Optimization Levels](https://openxla.org/xla/gpu_optimization_levels) for -details. +[GPU Effort Levels](https://openxla.org/xla/effort_levels) for details.
| Flag | Type | Notes | | :---- | :---- | :----- | From 2a5ffea6a109dbb85951060a83187aa8251c3097 Mon Sep 17 00:00:00 2001 From: spiao Date: Tue, 9 Dec 2025 02:48:20 -0800 Subject: [PATCH 072/753] PR #34806: [ROCm] fix the calling convention for AMD GPU Imported from GitHub PR https://github.com/openxla/xla/pull/34806 Bugfix: PR #34230 ("argument removal without building prototype") removed the call to **BuildKernelPrototypeFromUniqueName** which internally called **AnnotateFunctionAsGpuKernel** to set the correct calling convention based on the target GPU. Without this, Triton's **PTX_Kernel** calling convention was copied directly, which doesn't work on AMD GPUs and leads to "LLVM ERROR: unsupported calling convention". Fix: Added a call to **AnnotateFunctionAsGpuKernel** in **RemoveUnusedTritonAbiArguments** to properly set: PTX_Kernel (71) for NVIDIA AMDGPU_KERNEL (91) for AMD SPIR_KERNEL (76) for SPIR @xla-rotation could you review my PR, please? Copybara import of the project: -- ebd6e1fa03033bc9f6913351323fce26e1a8e4d2 by Songlin Piao : replace the manual calling convention fix with AnnotateFunctionAsGpuKernel -- 4f16d9579b11c2984c8ebe58041b0d2b9ea5ba3f by Songlin Piao : added a filecheck test Merging this change closes #34806 PiperOrigin-RevId: 842146580 --- .../backends/gpu/codegen/fusion_emitter.cc | 7 ++++- third_party/xla/xla/service/gpu/tests/BUILD | 3 +++ .../gpu/tests/triton_calling_convention.hlo | 26 +++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 third_party/xla/xla/service/gpu/tests/triton_calling_convention.hlo diff --git a/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc b/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc index ad12a2d0923948..d051a6daf3e778 100644 --- a/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc +++ b/third_party/xla/xla/backends/gpu/codegen/fusion_emitter.cc @@ -266,10 +266,15 @@ absl::StatusOr RemoveUnusedTritonAbiArguments( .getCallee();
llvm::Function* new_function = static_cast(inserted); - new_function->setCallingConv(impl_fn->getCallingConv()); new_function->copyMetadata(impl_fn, 0); new_function->setAttributes(impl_fn->getAttributes()); + // Set the correct calling convention for the target GPU. + // Triton generates PTX_Kernel CC even for AMD, so we need to use + // AnnotateFunctionAsGpuKernel to set the correct CC based on target triple. + llvm::IRBuilder<> builder(llvm_module->getContext()); + AnnotateFunctionAsGpuKernel(llvm_module, new_function, &builder); + new_function->splice(new_function->begin(), impl_fn); for (const auto& [impl_fn_arg, kernel_arg] : diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 55c8c2316833fd..284f4090f3fe0a 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -661,6 +661,7 @@ lit_test_suite_for_gpus( "slice_to_dynamic.hlo", "sorting.hlo", "sub_byte_collectives.hlo", + "triton_calling_convention.hlo", "triton_naming.hlo", "zero_clamp_abs_index.hlo", ], @@ -673,10 +674,12 @@ lit_test_suite_for_gpus( disabled_on_gpus = { "v100": [ "kernel_reuse.hlo", + "triton_calling_convention.hlo", "triton_naming.hlo", ], "p100": [ "kernel_reuse.hlo", + "triton_calling_convention.hlo", "triton_naming.hlo", ], "mi200": [ diff --git a/third_party/xla/xla/service/gpu/tests/triton_calling_convention.hlo b/third_party/xla/xla/service/gpu/tests/triton_calling_convention.hlo new file mode 100644 index 00000000000000..6a83c444793d47 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/triton_calling_convention.hlo @@ -0,0 +1,26 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK-%{PTX} %s + +// Verify that Triton kernels have the correct calling convention: +// - PTX_KERNEL (71) for NVIDIA targets +// - AMDGPU_KERNEL (91) for AMD 
targets +// CHECK-PTX: define ptx_kernel void @triton_ +// CHECK-GCN: define amdgpu_kernel void @triton_ + +HloModule TritonCallingConvention, is_scheduled=true + +triton_softmax { + param_0 = f32[4,4]{1,0} parameter(0) + ROOT exp = f32[4,4]{1,0} exponential(param_0) +} + +ENTRY main { + param_0 = f32[4,4]{1,0} parameter(0) + ROOT triton_softmax = f32[4,4]{1,0} fusion(param_0), kind=kCustom, + calls=triton_softmax, + backend_config={"fusion_backend_config":{ + "kind":"__triton", + "block_level_fusion_config":{"output_tiles":[{"sizes":["4","4"]}], + "num_warps":"1", + "num_ctas":"1", + "num_stages":"1"}}} +} From aa202263bdaa236f07f01c6a95aafb72a0c65251 Mon Sep 17 00:00:00 2001 From: Kanish Anand Date: Tue, 9 Dec 2025 02:54:53 -0800 Subject: [PATCH 073/753] Add constructor to `NamedSharding` accepting axis names. This provides a more intuitive way to create `NamedSharding` objects, especially in tests, as it's easier to work with human-readable axis names than with `AxisRef` indices. PiperOrigin-RevId: 842148670 --- third_party/xla/xla/hlo/ir/BUILD | 2 + third_party/xla/xla/hlo/ir/named_sharding.cc | 81 +++++++++++++++++++ third_party/xla/xla/hlo/ir/named_sharding.h | 19 ++++- .../xla/xla/hlo/ir/named_sharding_test.cc | 26 ++++++ 4 files changed, 126 insertions(+), 2 deletions(-) create mode 100644 third_party/xla/xla/hlo/ir/named_sharding.cc diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 6d4c270f5b856f..07b2e615d62479 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -203,12 +203,14 @@ cc_library( cc_library( name = "named_sharding", + srcs = ["named_sharding.cc"], hdrs = ["named_sharding.h"], deps = [ ":mesh_and_axis", ":tile_assignment", "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/hlo/ir/named_sharding.cc 
b/third_party/xla/xla/hlo/ir/named_sharding.cc new file mode 100644 index 00000000000000..0db5d8e916aa53 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/named_sharding.cc @@ -0,0 +1,81 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/named_sharding.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/mesh_and_axis.h" + +namespace xla { + +namespace test_utils { +// Construct sharding with given mesh. 'dim_shardings', 'replicated_axes', +// 'unreduced_axes' refer to axis names in the mesh. +// This is a test only helper function. 
+NamedSharding FromAxisNames( + Mesh mesh, absl::Span> dim_shardings, + absl::Span replicated_axes, + absl::Span unreduced_axes, + absl::Span metadata) { + std::map mesh_axis_to_index; + for (int64_t i = 0; i < mesh.axis_names().size(); ++i) { + mesh_axis_to_index[mesh.axis_names()[i]] = i; + } + + std::vector dim_shardings_; + dim_shardings_.reserve(dim_shardings.size()); + for (const auto& axes_for_dim : dim_shardings) { + std::vector axis_refs; + axis_refs.reserve(axes_for_dim.size()); + for (const std::string& axis_name : axes_for_dim) { + auto it = mesh_axis_to_index.find(axis_name); + CHECK(it != mesh_axis_to_index.end()) + << "Axis " << axis_name << " not found in mesh " << mesh.ToString(); + axis_refs.push_back(AxisRef(it->second)); + } + dim_shardings_.push_back(NamedSharding::DimensionSharding( + std::move(axis_refs), /*is_closed=*/true)); + } + + std::vector replicated_axes_; + replicated_axes_.reserve(replicated_axes.size()); + for (const std::string& axis_name : replicated_axes) { + auto it = mesh_axis_to_index.find(axis_name); + CHECK(it != mesh_axis_to_index.end()) + << "Axis " << axis_name << " not found in mesh " << mesh.ToString(); + replicated_axes_.push_back(AxisRef(it->second)); + } + + std::vector unreduced_axes_; + unreduced_axes_.reserve(unreduced_axes.size()); + for (const std::string& axis_name : unreduced_axes) { + auto it = mesh_axis_to_index.find(axis_name); + CHECK(it != mesh_axis_to_index.end()) + << "Axis " << axis_name << " not found in mesh " << mesh.ToString(); + unreduced_axes_.push_back(AxisRef(it->second)); + } + + return NamedSharding(mesh, dim_shardings_, replicated_axes_, unreduced_axes_, + metadata); +} +} // namespace test_utils +} // namespace xla diff --git a/third_party/xla/xla/hlo/ir/named_sharding.h b/third_party/xla/xla/hlo/ir/named_sharding.h index 6c93bed8d40c74..bfdc9966c0b15d 100644 --- a/third_party/xla/xla/hlo/ir/named_sharding.h +++ b/third_party/xla/xla/hlo/ir/named_sharding.h @@ -17,6 +17,7 @@ limitations 
under the License. #define XLA_HLO_IR_NAMED_SHARDING_H_ #include +#include #include #include @@ -38,8 +39,11 @@ class NamedSharding { return axes_ == other.axes_ && is_closed_ == other.is_closed_; } - explicit DimensionSharding(std::vector axes, bool is_closed) - : axes_(std::move(axes)), is_closed_(is_closed) {} + // Note that by default we assume closed sharding. + explicit DimensionSharding() : is_closed_(true) {}; + + explicit DimensionSharding(absl::Span axes, bool is_closed) + : axes_(axes.begin(), axes.end()), is_closed_(is_closed) {} absl::Span axes() const { return axes_; } @@ -118,6 +122,17 @@ class NamedSharding { std::vector metadata_; }; +// Contains test only helper functions. +namespace test_utils { +// Construct sharding with given mesh. 'dim_shardings', 'replicated_axes', +// 'unreduced_axes' refer to axis names in the mesh. +NamedSharding FromAxisNames( + Mesh mesh, absl::Span> dim_shardings, + absl::Span replicated_axes = {}, + absl::Span unreduced_axes = {}, + absl::Span metadata = {}); +} // namespace test_utils + } // namespace xla #endif // XLA_HLO_IR_NAMED_SHARDING_H_ diff --git a/third_party/xla/xla/hlo/ir/named_sharding_test.cc b/third_party/xla/xla/hlo/ir/named_sharding_test.cc index 36e9cfbbba67bb..78e3b3e3b08095 100644 --- a/third_party/xla/xla/hlo/ir/named_sharding_test.cc +++ b/third_party/xla/xla/hlo/ir/named_sharding_test.cc @@ -24,6 +24,32 @@ namespace { using DimensionSharding = NamedSharding::DimensionSharding; +TEST(NamedShardingTest, AxisNameCtor) { + Mesh mesh_abcd({2, 4, 3, 8}, {"a", "b", "c", "d"}); + AxisRef axis_a(0); + AxisRef axis_b(1); + AxisRef axis_c(2); + AxisRef axis_d(3); + + NamedSharding sharding = + test_utils::FromAxisNames(mesh_abcd, /*dim_shardings=*/{{"c"}, {"b"}}, + /*replicated_axes=*/{"a"}, + /*unreduced_axes=*/{"d"}); + DimensionSharding ds_c({axis_c}, /*is_closed=*/true); + DimensionSharding ds_b({axis_b}, /*is_closed=*/true); + EXPECT_EQ(sharding, + NamedSharding(mesh_abcd, {ds_c, ds_b}, {axis_a}, 
{axis_d})); + + NamedSharding sharding2 = test_utils::FromAxisNames( + mesh_abcd, + /*dim_shardings=*/{{"c", "a"}, {}, {"b"}}, + /*replicated_axes=*/{"d"}, /*unreduced_axes=*/{}); + DimensionSharding ds_ca({axis_c, axis_a}, /*is_closed=*/true); + EXPECT_EQ(sharding2, + NamedSharding(mesh_abcd, {ds_ca, DimensionSharding(), ds_b}, + {axis_d}, {})); +} + TEST(NamedShardingTest, Equality) { Mesh mesh_abcd({2, 4, 3, 8}, {"a", "b", "c", "d"}); From a05e35e330f70fe1920a07573e709247b09ddb15 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Dec 2025 03:05:52 -0800 Subject: [PATCH 074/753] Reverts 0752a12d8a06aaefc942eaf1f5255a6eea23ca14 PiperOrigin-RevId: 842151521 --- .../xla/hlo/analysis/hlo_dataflow_analysis.cc | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc index ca00349f4c25d5..893c233f9bd1c2 100644 --- a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/hlo_value.h" @@ -1616,6 +1617,26 @@ HloDataflowAnalysis::GetInPlaceInputOutputPairs( return alias_info->GetInPlaceInputOutputPairs(instruction); } +// Returns true if the instruction is a fusion consisting of a single copy which +// changes tiling. This is handled by the emitters and effectively are no-ops. 
+static bool IsChangeTilingCopyFusion(HloInstruction* instr) { + if (!instr->parent()->IsFusionComputation() || + instr->opcode() != HloOpcode::kFusion || + instr->called_computations().size() != 1 || instr->operand_count() != 1) { + return false; + } + // These copy fusions should only change tiling (and sometimes memory space). + HloInstruction* fusion_root = instr->fused_expression_root(); + const Layout& operand_layout = fusion_root->operand(0)->shape().layout(); + const Layout& output_layout = fusion_root->shape().layout(); + absl::Span operand_tiles = operand_layout.tiles(); + absl::Span output_tiles = output_layout.tiles(); + return fusion_root->opcode() == HloOpcode::kCopy && + Layout::Equal().IgnoreTiles().IgnoreMemorySpace()(operand_layout, + output_layout) && + operand_tiles != output_tiles; +} + bool HloDataflowAnalysis::CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, @@ -1631,7 +1652,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( const Shape& user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); - auto shapes_equal = ShapeUtil::Equal(operand_subshape, user_subshape); + // During tiling assignment, we can add no-op instructions which appear to + // change tiling (and memory space) of the operand, but don't. + if (IsChangeTilingCopyFusion(user) || IsChangeTilingCopyFusion(operand)) { + return true; + } + const bool shapes_equal = ShapeUtil::Equal(operand_subshape, user_subshape); // Check that operand and user emit the same shape and layout. if (shapes_equal) { // Must-alias relationship returns true for in-place operations (DUS and DUS From 10344d0c57913c6abffe86c6dbc5bac8322b19f2 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Tue, 9 Dec 2025 04:32:33 -0800 Subject: [PATCH 075/753] [XLA:GPU] Move default Triton configs to text proto format. 
This is to make default configuration consistent to what new `--xla_gpu_gemm_autotuner_override_file` flag takes. PiperOrigin-RevId: 842176141 --- .../xla/xla/backends/gpu/autotuner/triton.cc | 8 +- .../xla/xla/service/gpu/autotuning/BUILD | 12 +- .../autotuning/gemm_fusion_autotuner_cuda.cc | 6 +- .../autotuning/gemm_fusion_autotuner_rocm.cc | 2 +- .../service/gpu/autotuning/triton_configs.cc | 207 ++++++++++++++++++ .../service/gpu/autotuning/triton_configs.h | 77 +------ 6 files changed, 235 insertions(+), 77 deletions(-) create mode 100644 third_party/xla/xla/service/gpu/autotuning/triton_configs.cc diff --git a/third_party/xla/xla/backends/gpu/autotuner/triton.cc b/third_party/xla/xla/backends/gpu/autotuner/triton.cc index a6cc696a80ef91..6a6246ee386be9 100644 --- a/third_party/xla/xla/backends/gpu/autotuner/triton.cc +++ b/third_party/xla/xla/backends/gpu/autotuner/triton.cc @@ -60,7 +60,7 @@ namespace { std::vector GetDefaultTritonConfigs( se::GpuComputeCapability compute_capability, bool autotune_tma) { if (compute_capability.IsRocm()) { - return *kDefaultRocmConfigs; + return GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultRocm); } CHECK(compute_capability.IsCuda()); @@ -68,12 +68,12 @@ std::vector GetDefaultTritonConfigs( std::vector configs; if (cuda_compute_capability->IsAtLeastBlackwell()) { - configs = *kBlackwellConfigs; + configs = GetTritonConfigsForPlatform(TritonConfigsPlatform::kBlackwell); } else if (cuda_compute_capability->IsHopper() || cuda_compute_capability->IsAmpere()) { - configs = *kHopperAmpereConfigs; + configs = GetTritonConfigsForPlatform(TritonConfigsPlatform::kHopperAmpere); } else { - configs = *kDefaultCudaConfigs; + configs = GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultCuda); } if (!autotune_tma) { diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index 56d8c4cb2f64ee..5ca2354d355add 100644 --- 
a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -758,6 +758,16 @@ cc_library( cc_library( name = "triton_configs", + srcs = ["triton_configs.cc"], hdrs = ["triton_configs.h"], - deps = ["//xla/service/gpu:matmul_utils"], + deps = [ + "//xla:autotuning_proto_cc", + "//xla/service/gpu:matmul_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", + ], ) diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc index 7dc86e8a9c2fde..336b668d4b3160 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_cuda.cc @@ -118,11 +118,11 @@ std::vector GemmFusionAutotunerImpl::GetDefaultTritonConfigs() std::vector configs; if (compute_capability.IsAtLeastBlackwell()) { - configs = *kBlackwellConfigs; + configs = GetTritonConfigsForPlatform(TritonConfigsPlatform::kBlackwell); } else if (compute_capability.IsHopper() || compute_capability.IsAmpere()) { - configs = *kHopperAmpereConfigs; + configs = GetTritonConfigsForPlatform(TritonConfigsPlatform::kHopperAmpere); } else { - configs = *kDefaultCudaConfigs; + configs = GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultCuda); } if (!debug_options_.xla_gpu_experimental_enable_triton_tma() || diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_rocm.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_rocm.cc index 83232e68d4e126..e7d072f1f0d96e 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_rocm.cc +++ 
b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_rocm.cc @@ -49,7 +49,7 @@ GemmFusionAutotuner::GetPlatformCodegenBackends( std::vector GemmFusionAutotunerImpl::GetDefaultTritonConfigs() const { - return *kDefaultRocmConfigs; + return GetTritonConfigsForPlatform(TritonConfigsPlatform::kDefaultRocm); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/triton_configs.cc b/third_party/xla/xla/service/gpu/autotuning/triton_configs.cc new file mode 100644 index 00000000000000..e57bb34bf71e97 --- /dev/null +++ b/third_party/xla/xla/service/gpu/autotuning/triton_configs.cc @@ -0,0 +1,207 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/autotuning/triton_configs.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/text_format.h" +#include "xla/autotuning.pb.h" +#include "xla/service/gpu/matmul_utils.h" + +namespace xla { +namespace gpu { +namespace { + +// TODO(b/467265599): Replace string constants with cc_embed_data when +// https://github.com/bazelbuild/rules_cc/issues/41 is fixed. 
+ +constexpr absl::string_view kBlackwellTritonConfigs = R"( +config { block_m: 128 block_n: 128 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 1 num_stages: 1 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 8 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 16 split_k: 512 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 32 split_k: 16 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 64 split_k: 1 num_stages: 5 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 64 split_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 64 split_k: 64 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 1 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 2 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 4 num_stages: 3 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 64 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 64 split_k: 16 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 64 split_k: 8 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 64 split_k: 8 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 32 block_k: 64 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 64 split_k: 1 num_stages: 3 num_warps: 8 num_ctas: 1 } +config { block_m: 256 block_n: 16 block_k: 16 
split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 256 block_n: 32 block_k: 32 split_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 32 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 512 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 64 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 64 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 16 split_k: 1 num_stages: 1 num_warps: 16 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 16 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 64 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 64 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 128 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 64 split_k: 64 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 128 split_k: 8 num_stages: 1 num_warps: 8 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 16 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +)"; + +constexpr absl::string_view kDefaultCudaTritonConfigs = R"( +config { block_m: 32 block_n: 32 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 64 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } 
+config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 128 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 64 block_k: 128 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 128 block_k: 32 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 512 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 16 block_k: 512 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 32 split_k: 1 num_stages: 3 num_warps: 8 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 32 split_k: 1 num_stages: 3 num_warps: 8 num_ctas: 1 } +config { block_m: 256 block_n: 64 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 256 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 128 block_k: 128 split_k: 1 num_stages: 3 num_warps: 8 num_ctas: 1 } +config { block_m: 256 block_n: 64 block_k: 128 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 256 block_k: 128 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 128 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 64 block_k: 64 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 64 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 64 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { 
block_m: 64 block_n: 32 block_k: 64 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 128 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 2 num_stages: 1 num_warps: 8 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 64 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 64 block_k: 256 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 256 block_n: 256 block_k: 128 split_k: 1 num_stages: 3 num_warps: 8 num_ctas: 1 } +)"; + +constexpr absl::string_view kDefaultRocmTritonConfigs = R"( +config { block_m: 32 block_n: 32 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 64 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 128 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +)"; + +constexpr absl::string_view kHopperAmpereTritonConfigs = R"( +config { block_m: 16 block_n: 16 block_k: 64 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 128 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 16 block_k: 128 split_k: 16 num_stages: 1 num_warps: 2 num_ctas: 1 } +config { block_m: 16 block_n: 256 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 
num_ctas: 1 } +config { block_m: 32 block_n: 32 block_k: 128 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 256 block_k: 32 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 32 block_n: 256 block_k: 32 split_k: 16 num_stages: 3 num_warps: 8 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 32 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 32 split_k: 16 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 64 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 64 split_k: 4 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 64 split_k: 16 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 128 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 16 block_k: 128 split_k: 16 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 64 split_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 128 split_k: 1 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 64 block_n: 32 block_k: 128 split_k: 128 num_stages: 2 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 64 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 64 split_k: 4 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 128 split_k: 16 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 64 block_k: 256 split_k: 16 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 16 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 } 
+config { block_m: 64 block_n: 128 block_k: 64 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 128 block_k: 128 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 64 block_n: 256 block_k: 32 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 32 split_k: 8 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 64 split_k: 16 num_stages: 3 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 16 block_k: 64 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 } +config { block_m: 128 block_n: 32 block_k: 32 split_k: 8 num_stages: 4 num_warps: 2 num_ctas: 1 } +config { block_m: 128 block_n: 128 block_k: 32 split_k: 8 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 32 split_k: 1 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 128 block_n: 256 block_k: 64 split_k: 1 num_stages: 4 num_warps: 8 num_ctas: 1 } +config { block_m: 64 block_n: 8 block_k: 128 split_k: 2 num_stages: 3 num_warps: 4 num_ctas: 1 } +)"; + +absl::flat_hash_map> +LoadTritonConfigs() { + absl::flat_hash_map> + result; + + auto parse_config = + [](absl::string_view config_str) -> std::vector { + TritonGemmConfigsProto proto; + CHECK(tsl::protobuf::TextFormat::ParseFromString(config_str, &proto)) + << config_str; + std::vector configs; + absl::c_transform(proto.config(), std::back_inserter(configs), + [](const AutotuneResult::TritonGemmKey& config_proto) { + absl::StatusOr config = + TritonGemmConfig::FromProto(config_proto); + CHECK_OK(config); + return *config; + }); + return configs; + }; + + const std::initializer_list< + std::pair> + kConfigsMap = { + {TritonConfigsPlatform::kBlackwell, kBlackwellTritonConfigs}, + {TritonConfigsPlatform::kDefaultCuda, kDefaultCudaTritonConfigs}, + {TritonConfigsPlatform::kDefaultRocm, kDefaultRocmTritonConfigs}, + {TritonConfigsPlatform::kHopperAmpere, kHopperAmpereTritonConfigs}, + }; 
+ for (const auto& [platform, config_str] : kConfigsMap) { + result[platform] = parse_config(config_str); + } + + return result; +} + +} // namespace + +const std::vector& GetTritonConfigsForPlatform( + TritonConfigsPlatform platform) { + static const absl::NoDestructor< + absl::flat_hash_map>> + kConfigs(LoadTritonConfigs()); + return kConfigs->at(platform); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/autotuning/triton_configs.h b/third_party/xla/xla/service/gpu/autotuning/triton_configs.h index 7cb0896477b419..252b4be2b1b692 100644 --- a/third_party/xla/xla/service/gpu/autotuning/triton_configs.h +++ b/third_party/xla/xla/service/gpu/autotuning/triton_configs.h @@ -23,74 +23,15 @@ limitations under the License. namespace xla { namespace gpu { -using Config = TritonGemmConfig; - -static const std::vector* const kBlackwellConfigs = - new std::vector( - {Config(128, 128, 32, 1, 4, 4), Config(128, 128, 64, 1, 1, 8), - Config(128, 128, 64, 8, 3, 4), Config(128, 16, 16, 512, 4, 2), - Config(128, 16, 32, 16, 3, 2), Config(128, 16, 64, 1, 5, 4), - Config(128, 16, 64, 16, 3, 4), Config(128, 16, 64, 64, 1, 2), - Config(128, 256, 64, 1, 4, 8), Config(128, 256, 64, 2, 4, 8), - Config(128, 256, 64, 4, 3, 8), Config(128, 64, 64, 1, 3, 4), - Config(128, 64, 64, 16, 4, 8), Config(128, 64, 64, 8, 4, 4), - Config(16, 16, 128, 1, 3, 2), Config(16, 16, 16, 1, 1, 2), - Config(16, 16, 64, 8, 3, 2), Config(16, 32, 64, 1, 3, 2), - Config(256, 128, 64, 1, 3, 8), Config(256, 16, 16, 1, 1, 2), - Config(256, 32, 32, 16, 3, 4), Config(32, 16, 32, 1, 4, 2), - Config(32, 16, 512, 1, 1, 4), Config(32, 16, 64, 1, 1, 2), - Config(32, 16, 64, 1, 4, 2), Config(64, 128, 16, 1, 1, 16), - Config(64, 128, 16, 1, 3, 2), Config(64, 128, 64, 1, 4, 4), - Config(64, 16, 64, 1, 2, 2), Config(64, 32, 128, 1, 3, 2), - Config(64, 32, 32, 1, 4, 2), Config(64, 32, 64, 64, 3, 2), - Config(64, 64, 128, 8, 1, 8), Config(64, 64, 16, 1, 1, 2), - Config(64, 64, 16, 1, 3, 
2)}); - -static const std::vector* const kHopperAmpereConfigs = - new std::vector( - {Config(16, 16, 64, 1, 4, 2), Config(16, 16, 128, 1, 4, 4), - Config(16, 16, 128, 128, 4, 2), Config(16, 16, 128, 16, 1, 2), - Config(16, 256, 16, 1, 1, 2), Config(32, 32, 128, 16, 1, 4), - Config(32, 256, 32, 1, 3, 4), Config(32, 256, 32, 16, 3, 8), - Config(64, 16, 32, 1, 4, 2), Config(64, 16, 32, 16, 4, 2), - Config(64, 16, 64, 1, 1, 4), Config(64, 16, 64, 4, 3, 2), - Config(64, 16, 64, 16, 4, 4), Config(64, 16, 128, 1, 4, 2), - Config(64, 16, 128, 16, 4, 4), Config(64, 32, 32, 1, 4, 4), - Config(64, 32, 64, 16, 3, 4), Config(64, 32, 128, 1, 3, 2), - Config(64, 32, 128, 128, 2, 4), Config(64, 64, 32, 1, 4, 4), - Config(64, 64, 64, 1, 4, 4), Config(64, 64, 64, 4, 4, 4), - Config(64, 64, 128, 16, 3, 4), Config(64, 64, 256, 16, 4, 8), - Config(64, 128, 16, 1, 4, 2), Config(64, 128, 64, 1, 3, 4), - Config(64, 128, 128, 8, 1, 4), Config(64, 256, 32, 1, 4, 4), - Config(128, 16, 32, 8, 4, 2), Config(128, 16, 64, 16, 3, 2), - Config(128, 16, 64, 16, 1, 4), Config(128, 32, 32, 8, 4, 2), - Config(128, 128, 32, 8, 4, 8), Config(128, 256, 32, 1, 4, 8), - Config(128, 256, 64, 1, 4, 8), Config(64, 8, 128, 2, 3, 4, 1)}); - -static const std::vector* const kDefaultCudaConfigs = - new std::vector( - {Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), - Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), - Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), - Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), - Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), - Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8), - Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4), - Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4), - Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8), - Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4), - Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4), - Config(64, 128, 
64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4), - Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4), - Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4), - Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4), - Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)}); - -static const std::vector* const kDefaultRocmConfigs = - new std::vector( - {Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), - Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), - Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4)}); +enum class TritonConfigsPlatform { + kBlackwell, + kDefaultCuda, + kDefaultRocm, + kHopperAmpere, +}; + +const std::vector& GetTritonConfigsForPlatform( + TritonConfigsPlatform); } // namespace gpu } // namespace xla From 5d3686aeb72d6b9db120f516419663d62cd28e1c Mon Sep 17 00:00:00 2001 From: spiao Date: Tue, 9 Dec 2025 05:00:28 -0800 Subject: [PATCH 076/753] PR #34812: [ROCm] Add register spilling detection support AMD MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imported from GitHub PR https://github.com/openxla/xla/pull/34812 ✨ New Feature Added register spilling detection support. 
🧪 Execution Test ./bazel-7.4.1-linux-x86_64 build //xla/service/gpu/transforms:triton_fusion_numerics_verifier_test bazel-bin/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test_amdgpu_any --gtest_filter=TritonFusionNumericsVerifierTest.CompilationSucceedsEvenIfKernelWillSpillRegisters ``` I0000 00:00:1764849271.079538 2923925 amdgpu_backend.cc:447] ====== REGISTER SPILLING DETECTED ====== I0000 00:00:1764849271.079561 2923925 amdgpu_backend.cc:448] Module: triton_softmax_consts I0000 00:00:1764849271.079565 2923925 amdgpu_backend.cc:449] SGPR spill count: 0 I0000 00:00:1764849271.079569 2923925 amdgpu_backend.cc:450] VGPR spill count: 194 I0000 00:00:1764849271.079572 2923925 amdgpu_backend.cc:451] Private segment size: 780 bytes I0000 00:00:1764849271.079574 2923925 amdgpu_backend.cc:452] Performance may be degraded due to register pressure I0000 00:00:1764849271.079576 2923925 amdgpu_backend.cc:453] ======================================== I0000 00:00:1764849271.390972 2923925 amdgpu_backend.cc:447] ====== REGISTER SPILLING DETECTED ====== I0000 00:00:1764849271.390996 2923925 amdgpu_backend.cc:448] Module: triton_softmax_consts I0000 00:00:1764849271.391000 2923925 amdgpu_backend.cc:449] SGPR spill count: 0 I0000 00:00:1764849271.391005 2923925 amdgpu_backend.cc:450] VGPR spill count: 194 I0000 00:00:1764849271.391007 2923925 amdgpu_backend.cc:451] Private segment size: 780 bytes I0000 00:00:1764849271.391009 2923925 amdgpu_backend.cc:452] Performance may be degraded due to register pressure I0000 00:00:1764849271.391012 2923925 amdgpu_backend.cc:453] ======================================== I0000 00:00:1764849271.397868 2923925 tfrt_gpu_client.cc:197] TfrtGpuClient destroyed. [ OK ] TritonFusionNumericsVerifierTest.CompilationSucceedsEvenIfKernelWillSpillRegisters (8019 ms) [----------] 1 test from TritonFusionNumericsVerifierTest (8019 ms total) [----------] Global test environment tear-down [==========] 1 test from 1 test suite ran. 
(8019 ms total) [ PASSED ] 1 test. ``` This PR is on top of another bugfix PR (https://github.com/openxla/xla/pull/34806). @xla-rotation could you review my PR, please? Copybara import of the project: -- ebd6e1fa03033bc9f6913351323fce26e1a8e4d2 by Songlin Piao : replace the manual calling convention fix with AnnotateFunctionAsGpuKernel -- fafc7f1f6ad5a47204a32d433eab2bc5ec44dbd3 by Songlin Piao : register spilling by disassembling object file -- f6b86f6fc96fd3398608c0078233db2efa74fce7 by Songlin Piao : added time measurement to the spilling check -- 8e5ea8455fc730b73b3768cbdde07079c8c53c29 by Songlin Piao : adapt the num_warps so that the hlo could be compiled on both amd and nvidia -- 22ef808416e6d339356c3a901ce1f5d03a396a60 by Songlin Piao : pass through is_autotuning_compilation flag to the function CompileToHsaco -- b1d5e976c8051332ca1fc45e5f3b91fcd15a3da8 by Songlin Piao : implementation of register spilling by reading meta data of hsaco file using llvm-readobj -- d74ae83731a0a56a7285c1ac57689678d21e42d4 by Songlin Piao : adapted function calls as is_autotuning_compilation is removed in upstream -- 07ed74d49361fb1945092cac459a3bb70262265b by Songlin Piao : utilize amd code object manager library for parsing HSACO metadata -- 11e83bcb502ee341ddf7db9044b05b4b757ca5e9 by Songlin Piao : Revert "replace the manual calling convention fix with AnnotateFunctionAsGpuKernel" This reverts commit ebd6e1fa03033bc9f6913351323fce26e1a8e4d2.
Merging this change closes #34812 PiperOrigin-RevId: 842183737 --- .../xla/third_party/gpus/rocm/BUILD.tpl | 9 +- .../xla/xla/backends/gpu/codegen/triton/BUILD | 1 - .../xla/service/gpu/llvm_gpu_backend/BUILD | 4 + .../gpu/llvm_gpu_backend/amdgpu_backend.cc | 175 +++++++++++++++++- .../triton_fusion_numerics_verifier_test.cc | 2 +- 5 files changed, 185 insertions(+), 6 deletions(-) diff --git a/third_party/xla/third_party/gpus/rocm/BUILD.tpl b/third_party/xla/third_party/gpus/rocm/BUILD.tpl index c95f9a95933fbc..4eba66c971da72 100644 --- a/third_party/xla/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/xla/third_party/gpus/rocm/BUILD.tpl @@ -150,9 +150,11 @@ cc_library( ], ":multiple_rocm_paths": [ "-Wl,-rpath=%{rocm_lib_paths}", + "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", ], "//conditions:default": [ "-Wl,-rpath,/opt/rocm/lib", + "-Lexternal/local_config_rocm/rocm/%{rocm_root}/lib", ], }), visibility = ["//visibility:public"], @@ -535,7 +537,7 @@ cc_library( cc_library( name = "amd_comgr", hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), - data = glob([ + srcs = glob([ "%{rocm_root}/lib/libamd_comgr_loader.so*", "%{rocm_root}/lib/libamd_comgr.so*", "%{rocm_root}/lib/llvm/lib/libLLVM.so*", @@ -548,9 +550,12 @@ cc_library( ":build_hermetic": [ "-lamd_comgr_loader", ], - "//conditions:default": [], + "//conditions:default": [ + "-lamd_comgr", + ], }), strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], deps = [ ":rocm_config", ":rocm_rpath", diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index ce2212b64621e9..edd7570f1ed64c 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -413,7 +413,6 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/llvm_gpu_backend:amdgpu_backend", 
"//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:triton_fusion_analysis", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index 668b495f7626ac..f70483614dd8c9 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -171,7 +171,9 @@ cc_library( "HAS_SUPPORT_FOR_EMBEDDED_LIB_DEVICE=1", ]), tags = [ + "gpu", "nofixdeps", # This target crashes build_cleaner ¯\_(ツ)_/¯ + "rocm-only", ], deps = [ ":llvm_gpu_backend", @@ -210,6 +212,8 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:Target", "@llvm-project//llvm:TargetParser", + "@local_config_rocm//rocm:amd_comgr", + "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/profiler/lib:traceme", diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc index c927c1eb627731..536216735c9adf 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/amdgpu_backend.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/llvm_gpu_backend/amdgpu_backend.h" +#include #include #include #include @@ -39,6 +40,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "amd_comgr/amd_comgr.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/Analysis/CGSCCPassManager.h" @@ -318,14 +320,151 @@ absl::StatusOr> EmitModuleToHsaco( } } - // Read HSACO. 
+ // Read HSACO file into memory (used for both metadata extraction and return) std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate); + if (!hsaco_file) { + return xla::Internal("Failed to open HSACO file: %s", hsaco_path); + } std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); - std::vector hsaco(hsaco_file_size); hsaco_file.seekg(0, std::ios::beg); hsaco_file.read(reinterpret_cast(hsaco.data()), hsaco_file_size); hsaco_file.close(); + + // Check for register spilling using HSACO metadata + // Use amd_comgr library for fast in-process metadata extraction + VLOG(2) << "Checking for register spilling in: " + << module->getModuleIdentifier(); + + bool has_spilling = false; + int sgpr_spill_count = 0; + int vgpr_spill_count = 0; + int private_segment_size = 0; + + // Use already-loaded HSACO data for amd_comgr parsing + { + // Create amd_comgr data object from HSACO + amd_comgr_data_t comgr_data; + amd_comgr_status_t status = + amd_comgr_create_data(AMD_COMGR_DATA_KIND_EXECUTABLE, &comgr_data); + + if (status == AMD_COMGR_STATUS_SUCCESS) { + status = amd_comgr_set_data(comgr_data, hsaco.size(), + reinterpret_cast(hsaco.data())); + + if (status == AMD_COMGR_STATUS_SUCCESS) { + // Get metadata from the executable + amd_comgr_metadata_node_t metadata; + status = amd_comgr_get_data_metadata(comgr_data, &metadata); + + if (status == AMD_COMGR_STATUS_SUCCESS) { + // Helper lambda to lookup integer value from metadata map + auto lookup_int_value = [](amd_comgr_metadata_node_t root, + const char* key) -> int { + amd_comgr_metadata_node_t value_node; + amd_comgr_status_t s = + amd_comgr_metadata_lookup(root, key, &value_node); + if (s != AMD_COMGR_STATUS_SUCCESS) { + return 0; + } + + size_t size = 0; + s = amd_comgr_get_metadata_string(value_node, &size, nullptr); + if (s != AMD_COMGR_STATUS_SUCCESS || size == 0) { + amd_comgr_destroy_metadata(value_node); + return 0; + } + + std::string str_value(size, '\0'); + s = 
amd_comgr_get_metadata_string(value_node, &size, + str_value.data()); + amd_comgr_destroy_metadata(value_node); + + if (s != AMD_COMGR_STATUS_SUCCESS) { + return 0; + } + + // Parse the integer value + try { + return std::stoi(str_value); + } catch (...) { + return 0; + } + }; + + // Navigate to amdhsa.kernels array and check each kernel + amd_comgr_metadata_node_t kernels_node; + if (amd_comgr_metadata_lookup(metadata, "amdhsa.kernels", + &kernels_node) == + AMD_COMGR_STATUS_SUCCESS) { + size_t kernel_count = 0; + amd_comgr_get_metadata_list_size(kernels_node, &kernel_count); + + for (size_t i = 0; i < kernel_count; ++i) { + amd_comgr_metadata_node_t kernel_node; + if (amd_comgr_index_list_metadata(kernels_node, i, + &kernel_node) == + AMD_COMGR_STATUS_SUCCESS) { + // Get spill counts for this kernel + int kernel_sgpr_spill = + lookup_int_value(kernel_node, ".sgpr_spill_count"); + int kernel_vgpr_spill = + lookup_int_value(kernel_node, ".vgpr_spill_count"); + int kernel_private_size = lookup_int_value( + kernel_node, ".private_segment_fixed_size"); + + // Aggregate max values across all kernels + sgpr_spill_count = + std::max(sgpr_spill_count, kernel_sgpr_spill); + vgpr_spill_count = + std::max(vgpr_spill_count, kernel_vgpr_spill); + private_segment_size = + std::max(private_segment_size, kernel_private_size); + + amd_comgr_destroy_metadata(kernel_node); + } + } + amd_comgr_destroy_metadata(kernels_node); + } + + amd_comgr_destroy_metadata(metadata); + } else { + VLOG(2) << "Could not get HSACO metadata via amd_comgr"; + } + } + amd_comgr_release_data(comgr_data); + } else { + VLOG(2) << "Could not create amd_comgr data object"; + } + + if (sgpr_spill_count > 0 || vgpr_spill_count > 0 || + private_segment_size > 0) { + has_spilling = true; + } + } + + if (has_spilling) { + VLOG(0) << "====== REGISTER SPILLING DETECTED ======"; + VLOG(0) << "Module: " << module->getModuleIdentifier(); + VLOG(0) << "SGPR spill count: " << sgpr_spill_count; + VLOG(0) << "VGPR spill 
count: " << vgpr_spill_count; + VLOG(0) << "Private segment size: " << private_segment_size << " bytes"; + VLOG(0) << "Performance may be degraded due to register pressure"; + VLOG(0) << "========================================"; + + // Filter out kernels with register spilling during autotuning + // This matches NVIDIA's behavior in ptx_compiler_impl.cc + // TODO: remove ptx from xla_gpu_fail_ptx_compilation_on_register_spilling + // to make the flag more general + if (debug_options.xla_gpu_fail_ptx_compilation_on_register_spilling()) { + return xla::Cancelled( + "Compilation result discarded due to register spilling"); + } + } else { + VLOG(2) << "No register spilling detected"; + } + + // Clean up temp files if (!keep_tempfiles) { remove(ir_path.c_str()); remove(isabin_path.c_str()); @@ -562,6 +701,34 @@ std::vector GetAMDGPUBackendOptions( backend_extra_llvm_opts.cbegin(), backend_extra_llvm_opts.cend()); + // Manually add LLVM debug options for register usage analysis + // Note: The disassembly-based spilling detection is now the primary method. + // These options are mainly useful for debugging the compiler itself. 
+ + // Uncomment if you want to see LLVM compilation details: + + // Option 1: Enable LLVM statistics (aggregate stats, not per-kernel) + // backend_llvm_opts.push_back("-stats"); + + // Option 2: Print final machine code (very verbose) + // backend_llvm_opts.push_back("-print-after-all"); + + // Option 3: Print after register allocation (shows register assignments) + // backend_llvm_opts.push_back("-print-after=regallocfast"); + // backend_llvm_opts.push_back("-print-after=regallocgreedy"); + + // Option 4: Enable pass timing (shows compilation time breakdown) + // backend_llvm_opts.push_back("-time-passes"); + + // Log the final LLVM options + if (!backend_llvm_opts.empty()) { + LOG(INFO) << "AMDGPU backend LLVM options (" << backend_llvm_opts.size() + << "):"; + for (const auto& opt : backend_llvm_opts) { + LOG(INFO) << " " << opt; + } + } + return backend_llvm_opts; } @@ -576,6 +743,10 @@ absl::StatusOr> CompileToHsaco( absl::call_once(backend_init_flag, AMDGPUBackendInit, debug_options, rocdl_dir_path); auto llvm_opts = GetAMDGPUBackendOptions(debug_options); + + VLOG(2) << "CompileToHsaco called for module: " + << module->getModuleIdentifier(); + llvm_ir::LLVMCommandLineOptionsLock llvm_lock(llvm_opts); std::vector hsaco; diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc index a2494c5d7d6b60..dae6b73eba50a6 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc @@ -393,7 +393,7 @@ ENTRY main { "kind":"__triton", "block_level_fusion_config":{ "output_tiles":[{"sizes":["1","256000"]}], - "num_warps":"32", + "num_warps":"16", "num_ctas":"1", "num_stages":"1"}}} })", From 29c25e0d6bbcec9c3067054680e962dd4a0765ca Mon Sep 17 00:00:00 2001 From: Marcin Radomski Date: Tue, 9 Dec 2025 05:36:53 -0800 Subject: 
[PATCH 077/753] [XLA] Use absl::StrCat instead of strings::StrCat PiperOrigin-RevId: 842193625 --- third_party/xla/xla/tsl/platform/cloud/BUILD | 1 - .../platform/cloud/gcs_file_system_test.cc | 36 +++++++++---------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/third_party/xla/xla/tsl/platform/cloud/BUILD b/third_party/xla/xla/tsl/platform/cloud/BUILD index 4fbc7b0633da6f..cd3ca465817554 100644 --- a/third_party/xla/xla/tsl/platform/cloud/BUILD +++ b/third_party/xla/xla/tsl/platform/cloud/BUILD @@ -431,7 +431,6 @@ tsl_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/platform:strcat", ], ) diff --git a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc index 58db30d178d149..646ee9f1a8a68b 100644 --- a/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc +++ b/third_party/xla/xla/tsl/platform/cloud/gcs_file_system_test.cc @@ -44,7 +44,6 @@ limitations under the License. #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/types.h" #include "tsl/platform/retrying_utils.h" -#include "tsl/platform/strcat.h" // Undef DeleteFile macro defined in wndows.h. 
#ifdef PLATFORM_WINDOWS @@ -1497,9 +1496,9 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { "path%2Frandom_access.txt?fields=size%2Cgeneration%2Cupdated\n" "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", - strings::StrCat("{\"size\": \"", content.size(), "\"", - ", \"generation\": \"1\"", - ", \"updated\": \"2016-04-29T23:15:24.896Z\"}")), + absl::StrCat("{\"size\": \"", content.size(), "\"", + ", \"generation\": \"1\"", + ", \"updated\": \"2016-04-29T23:15:24.896Z\"}")), new FakeHttpRequest( absl::StrCat("Uri: https://storage.googleapis.com/bucket/" "path%2Frandom_access.txt\n" @@ -4383,12 +4382,12 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { "location"}}), // Uploads entire file again. new FakeHttpRequest( - strings::StrCat("Uri: https://custom/upload/location\n" - "Auth Token: fake_token\n" - "Header Content-Range: bytes 0-26/27\n" - "Timeouts: 5 1 30\n" - "Put body: ", - contents[0], contents[1], contents[2], "\n"), + absl::StrCat("Uri: https://custom/upload/location\n" + "Auth Token: fake_token\n" + "Header Content-Range: bytes 0-26/27\n" + "Timeouts: 5 1 30\n" + "Put body: ", + contents[0], contents[1], contents[2], "\n"), ""), new FakeHttpRequest( "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" @@ -4399,15 +4398,14 @@ TEST(GcsFileSystemTest, NewAppendableFile_MultipleFlushesWithoutCompose) { "Timeouts: 5 1 10\n", "", {{"Location", "https://custom/upload/location"}}), // Uploads entire file again. 
- new FakeHttpRequest( - strings::StrCat("Uri: https://custom/upload/location\n" - "Auth Token: fake_token\n" - "Header Content-Range: bytes 0-35/36\n" - "Timeouts: 5 1 30\n" - "Put body: ", - contents[0], contents[1], contents[2], contents[3], - "\n"), - ""), + new FakeHttpRequest(absl::StrCat("Uri: https://custom/upload/location\n" + "Auth Token: fake_token\n" + "Header Content-Range: bytes 0-35/36\n" + "Timeouts: 5 1 30\n" + "Put body: ", + contents[0], contents[1], contents[2], + contents[3], "\n"), + ""), }); GcsFileSystem fs( std::unique_ptr(new FakeAuthProvider), From e3a3c04334fc280d096246f12e074404eb419376 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Dec 2025 06:29:51 -0800 Subject: [PATCH 078/753] Automated Code Change PiperOrigin-RevId: 842209201 --- third_party/xla/xla/tsl/concurrency/async_value_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/tsl/concurrency/async_value_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_test.cc index 57f968d6824057..75005391c3ef4b 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_test.cc @@ -178,7 +178,8 @@ TEST(AsyncValueTest, StackAllocatedAsyncValue) { EXPECT_TRUE(ptr.IsAvailable()); // Check that when owner is destructed it calls the payload destructor. - std::make_unique>(std::move(owner)); + static_cast( + std::make_unique>(std::move(owner))); EXPECT_EQ(2, counter); } From 80246cf0975f03be9b5f315fcfebe7029d8ff207 Mon Sep 17 00:00:00 2001 From: Dirk Hornung Date: Tue, 9 Dec 2025 07:57:12 -0800 Subject: [PATCH 079/753] [Autotuner] Log the per backend supported config count. 
PiperOrigin-RevId: 842238520 --- third_party/xla/xla/backends/autotuner/autotuner.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/third_party/xla/xla/backends/autotuner/autotuner.cc b/third_party/xla/xla/backends/autotuner/autotuner.cc index 2c2dcb14750ef7..2d98a517270eb0 100644 --- a/third_party/xla/xla/backends/autotuner/autotuner.cc +++ b/third_party/xla/xla/backends/autotuner/autotuner.cc @@ -314,7 +314,8 @@ absl::StatusOr Autotuner::TuneBestConfig( absl::StrCat("Autotuner could not find any supported configs for HLO: ", instr->ToString())); } - VLOG(1) << "Found " << supported_configs.size() << " supported configs."; + VLOG(1) << "Found total of " << supported_configs.size() + << " supported configs."; std::vector>> executables = CompileAll(instr, supported_configs); @@ -411,8 +412,13 @@ absl::StatusOr> Autotuner::GetSupportedConfigs( absl::StatusOr>> per_backend_configs = codegen_backend->GetSupportedConfigs(*instr); if (!per_backend_configs.ok()) { + VLOG(3) << "Failed to get supported configs for backend " + << codegen_backend->name() << ": " + << per_backend_configs.status(); continue; } + VLOG(3) << "Found of " << per_backend_configs->size() + << " supported configs for backend " << codegen_backend->name(); for (auto& config : *per_backend_configs) { configs.push_back({codegen_backend.get(), std::move(config)}); } From 6f72793d606a8d8d680d9c286b198b6c533b5ade Mon Sep 17 00:00:00 2001 From: Theotime Combes Date: Tue, 9 Dec 2025 08:28:35 -0800 Subject: [PATCH 080/753] [XLA:GPU]Extract transpose normalization logic to utils Pre-requisite to performing normalization OTF and remove the pass No-op in terms of behavior PiperOrigin-RevId: 842250914 --- third_party/xla/xla/BUILD | 1 + .../transforms/transpose_dimension_grouper.cc | 121 +---------------- third_party/xla/xla/shape_util.cc | 95 ++++++++++++++ third_party/xla/xla/shape_util.h | 32 +++++ third_party/xla/xla/shape_util_test.cc | 124 ++++++++++++++++++ 5 files changed, 254 
insertions(+), 119 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index d07d329e7c2dca..e653fc2cc8b7e3 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -565,6 +565,7 @@ xla_cc_test( "//xla/tsl/platform:test_benchmark", "//xla/tsl/platform:test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc index d6dfc08863b5a2..26d9a8f87049c0 100644 --- a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_set.h" @@ -42,126 +41,10 @@ namespace xla { namespace gpu { namespace { -// Returns the indices of the first elements of all consecutive subarrays of the -// given array. For example: -// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -absl::InlinedVector ConsecutiveSegments( - absl::Span xs) { - absl::InlinedVector is = {0}; - for (size_t i = 1; i < xs.size(); ++i) { - if (1 != xs[i] - xs[i - 1]) { - is.push_back(i); - } - } - return is; -} - -// Merges the sequences of dimensions of the given shape which start at the -// given indices `segs`. -Shape MergeDimensions(absl::Span segs, const Shape &shape) { - std::vector dimensions; - const auto size = segs.size(); - dimensions.reserve(size); - for (size_t i = 1; i <= size; ++i) { - dimensions.push_back(std::accumulate( - shape.dimensions().begin() + segs[i - 1], - shape.dimensions().begin() + - (segs.size() == i ? 
shape.dimensions().size() : segs[i]), - int64_t{1}, std::multiplies())); - } - return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), - dimensions); -} - -absl::InlinedVector GetNormalizedTransposeShapeHelper( - const Shape &output_shape, absl::Span output_to_input, - absl::InlinedVector &permutation) { - absl::InlinedVector segments = - ConsecutiveSegments(output_to_input); - Shape normalized_shape = MergeDimensions(segments, output_shape); - absl::InlinedVector normalized_dims( - normalized_shape.dimensions().begin(), - normalized_shape.dimensions().end()); - if (segments.size() == 1) { - return normalized_dims; - } - // Derive the permutation from the segments. - std::vector segment_to_normalized_dim( - output_shape.dimensions().size(), -1); - for (size_t segment : segments) { - segment_to_normalized_dim[output_to_input[segment]] = 0; - } - int64_t normalized_dim = 0; - for (int64_t i = 0; i < segment_to_normalized_dim.size(); ++i) { - if (segment_to_normalized_dim[i] >= 0) { - segment_to_normalized_dim[i] = normalized_dim++; - } - } - permutation.reserve(segments.size()); - for (int64_t i = 0; i < segments.size(); ++i) { - permutation.push_back( - segment_to_normalized_dim[output_to_input[segments[i]]]); - } - return normalized_dims; -} - -// In this case, we care about transposes that permute dimensions of a shape -// that can be viewed as several logical components in the order of major to -// minor. As an example, let's consider a 0-2-1 transpose: -// -// If a shape can be viewed as three logical components 0-1-2 in the order of -// major to minor, a 0-2-1-transpose changes the order of such logical -// components to 0-2-1. We call the shape being transposed the input shape and -// the transposed shape the output shape. The logical view of the input/output -// shapes for the transpose are called the 0-1-2/0-2-1 shapes or the normalized -// shapes. The original input/output shapes are called unnormalized shapes. 
-// -// 'output_shape' should have the default layout (enforced by the caller). -// -// 'dimensions' specifies the kind of the unnormalized transpose and defines the -// permutation of the input shape that will result in the provided output shape. -// So to compute the input shape, we need to apply the inverse permutation of -// 'dimensions'. -// -// 'permutation' is an output parameter and specifies the kind of the normalized -// transpose. -// -// The method returns the dimensions for the normalized transpose shape. -// -// Example: Suppose the unnormalized output shape is [32, 1, 10, 11], and -// 'dimensions' is set to {3, 1, 0, 2}. This means the corresponding input shape -// is [10, 1, 11, 32]. The normalized output shape is [32, 110] with -// 'permutation' set to {1,0}. -absl::InlinedVector GetNormalizedLogicalTransposeShape( - const Shape &output_shape, absl::Span dimensions, - absl::InlinedVector &permutation) { - permutation.clear(); - // Drop degenerate dimensions. - absl::InlinedVector delta(output_shape.dimensions().size() + 1, - 0); - auto input_dimensions = - Permute(output_shape.dimensions(), InversePermutation(dimensions)); - for (int i = 0; i < output_shape.dimensions().size(); ++i) { - delta[i + 1] = delta[i]; - if (input_dimensions[i] == static_cast(1)) { - ++delta[i + 1]; - } - } - absl::InlinedVector new_dimensions; - for (int i = 0; i < dimensions.size(); i++) { - if (output_shape.dimensions(i) != 1) { - new_dimensions.push_back(dimensions[i] - delta[dimensions[i]]); - } - } - - return GetNormalizedTransposeShapeHelper( - ShapeUtil::DropDegenerateDimensions(output_shape), new_dimensions, - permutation); -} class TransposeDimensionGroupVisitor : public DfsHloRewriteVisitor { public: - absl::Status HandleTranspose(HloInstruction *transpose) override { + absl::Status HandleTranspose(HloInstruction* transpose) override { VLOG(4) << "Input: " << transpose->ToString(); if (!LayoutUtil::IsMonotonicWithDim0Major(transpose->shape().layout()) || 
!LayoutUtil::IsMonotonicWithDim0Major( @@ -174,7 +57,7 @@ class TransposeDimensionGroupVisitor : public DfsHloRewriteVisitor { "transpose and its operand"); } absl::InlinedVector permutation; - auto normalized_dims = GetNormalizedLogicalTransposeShape( + auto normalized_dims = ShapeUtil::GetNormalizedLogicalTransposeShape( transpose->shape(), transpose->dimensions(), permutation); if (normalized_dims.size() == 1 || normalized_dims == transpose->shape().dimensions()) { diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 321f9dfffab4cd..1fbea10079413b 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -2336,6 +2336,101 @@ int64_t ShapeUtil::ForEachState::CalculateNumSteps() const { }); } +namespace { + +// Returns the indices of the first elements of all consecutive subarrays of the +// given array. For example: +// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} +absl::InlinedVector ConsecutiveSegments( + absl::Span xs) { + absl::InlinedVector is = {0}; + for (size_t i = 1; i < xs.size(); ++i) { + if (1 != xs[i] - xs[i - 1]) { + is.push_back(i); + } + } + return is; +} + +// Merges the sequences of dimensions of the given shape which start at the +// given indices `segs`. +Shape MergeDimensions(absl::Span segs, const Shape& shape) { + std::vector dimensions; + const auto size = segs.size(); + dimensions.reserve(size); + for (size_t i = 1; i <= size; ++i) { + dimensions.push_back(std::accumulate( + shape.dimensions().begin() + segs[i - 1], + shape.dimensions().begin() + + (segs.size() == i ? 
shape.dimensions().size() : segs[i]), + int64_t{1}, std::multiplies())); + } + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + dimensions); +} + +absl::InlinedVector GetNormalizedTransposeShapeHelper( + const Shape& output_shape, absl::Span output_to_input, + absl::InlinedVector& permutation) { + absl::InlinedVector segments = + ConsecutiveSegments(output_to_input); + Shape normalized_shape = MergeDimensions(segments, output_shape); + absl::InlinedVector normalized_dims( + normalized_shape.dimensions().begin(), + normalized_shape.dimensions().end()); + if (segments.size() == 1) { + return normalized_dims; + } + // Derive the permutation from the segments. + std::vector segment_to_normalized_dim( + output_shape.dimensions().size(), -1); + for (size_t segment : segments) { + segment_to_normalized_dim[output_to_input[segment]] = 0; + } + int64_t normalized_dim = 0; + for (int64_t i = 0; i < segment_to_normalized_dim.size(); ++i) { + if (segment_to_normalized_dim[i] >= 0) { + segment_to_normalized_dim[i] = normalized_dim++; + } + } + permutation.reserve(segments.size()); + for (int64_t i = 0; i < segments.size(); ++i) { + permutation.push_back( + segment_to_normalized_dim[output_to_input[segments[i]]]); + } + return normalized_dims; +} + +} // namespace + +/*static*/ absl::InlinedVector +ShapeUtil::GetNormalizedLogicalTransposeShape( + const Shape& output_shape, absl::Span dimensions, + absl::InlinedVector& permutation) { + permutation.clear(); + // Drop degenerate dimensions. 
+ absl::InlinedVector delta(output_shape.dimensions().size() + 1, + 0); + auto input_dimensions = + Permute(output_shape.dimensions(), InversePermutation(dimensions)); + for (int i = 0; i < output_shape.dimensions().size(); ++i) { + delta[i + 1] = delta[i]; + if (input_dimensions[i] == static_cast(1)) { + ++delta[i + 1]; + } + } + absl::InlinedVector new_dimensions; + for (int i = 0; i < dimensions.size(); i++) { + if (output_shape.dimensions(i) != 1) { + new_dimensions.push_back(dimensions[i] - delta[dimensions[i]]); + } + } + + return GetNormalizedTransposeShapeHelper( + ShapeUtil::DropDegenerateDimensions(output_shape), new_dimensions, + permutation); +} + /*static*/ void ShapeUtil::FlattenTupleShape( const Shape& shape, std::vector& flattened) { if (shape.IsTuple()) { diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index fde70d0dd22ef5..12cd8e59bd58c7 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -435,6 +435,38 @@ class ShapeUtil { static bool IsEffectivelyMostMajorDimension(const Shape& shape, int64_t dimension); + // In this case, we care about transposes that permute dimensions of a shape + // that can be viewed as several logical components in the order of major to + // minor. As an example, let's consider a 0-2-1 transpose: + // + // If a shape can be viewed as three logical components 0-1-2 in the order of + // major to minor, a 0-2-1-transpose changes the order of such logical + // components to 0-2-1. We call the shape being transposed the input shape and + // the transposed shape the output shape. The logical view of the input/output + // shapes for the transpose are called the 0-1-2/0-2-1 shapes or the + // normalized shapes. The original input/output shapes are called unnormalized + // shapes. + // + // 'output_shape' should have the default layout (enforced by the caller). 
+ // + // 'dimensions' specifies the kind of the unnormalized transpose and defines + // the permutation of the input shape that will result in the provided output + // shape. So to compute the input shape, we need to apply the inverse + // permutation of 'dimensions'. + // + // 'permutation' is an output parameter and specifies the kind of the + // normalized transpose. + // + // The method returns the dimensions for the normalized transpose shape. + // + // Example: Suppose the unnormalized output shape is [32, 1, 10, 11], and + // 'dimensions' is set to {3, 1, 0, 2}. This means the corresponding input + // shape is [10, 1, 11, 32]. The normalized output shape is [32, 110] with + // 'permutation' set to {1,0}. + static absl::InlinedVector GetNormalizedLogicalTransposeShape( + const Shape& output_shape, absl::Span dimensions, + absl::InlinedVector& permutation); + // Returns an empty tuple shape. Can be used as a sentinel Shape value. static Shape MakeNil() { return Shape(std::vector{}); } diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index e1015ca2dc7778..265e36b839b289 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -22,8 +22,10 @@ limitations under the License. 
#include #include +#include #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -45,6 +47,7 @@ namespace xla { namespace { using ::testing::ElementsAre; +using ::testing::IsEmpty; TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) { Shape matrix = ShapeUtil::MakeShape(F32, {2, 3}); @@ -1776,5 +1779,126 @@ void BM_ForEachIndexNoStatus(::testing::benchmark::State& state) { BENCHMARK(BM_ForEachIndexNoStatus)->Arg(0)->Arg(1)->Arg(2); +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape) { + Shape output_shape = ShapeUtil::MakeShape(F32, {32, 1, 10, 11}); + absl::InlinedVector dimensions = {3, 1, 0, 2}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(32, 110)); + EXPECT_THAT(permutation, ElementsAre(1, 0)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape2) { + Shape output_shape = ShapeUtil::MakeShape(F32, {20, 30, 50}); + absl::InlinedVector dimensions = {1, 2, 0}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(600, 50)); + EXPECT_THAT(permutation, ElementsAre(1, 0)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_NoTranspose) { + Shape output_shape = ShapeUtil::MakeShape(F32, {64, 1, 128}); + absl::InlinedVector dimensions = {0, 2, 1}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(8192)); + EXPECT_THAT(permutation, IsEmpty()); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_Simple2D) { + Shape output_shape = ShapeUtil::MakeShape(F32, {64, 
128}); + absl::InlinedVector dimensions = {1, 0}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(64, 128)); + EXPECT_THAT(permutation, ElementsAre(1, 0)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_Simple3D_021) { + Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 32768}); + absl::InlinedVector dimensions = {0, 2, 1}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(8, 16, 32768)); + EXPECT_THAT(permutation, ElementsAre(0, 2, 1)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_Simple3D_210) { + Shape output_shape = ShapeUtil::MakeShape(F32, {16, 32768, 8}); + absl::InlinedVector dimensions = {2, 1, 0}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(16, 32768, 8)); + EXPECT_THAT(permutation, ElementsAre(2, 1, 0)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_Simple4D) { + Shape output_shape = ShapeUtil::MakeShape(F32, {16, 32768, 8, 4}); + absl::InlinedVector dimensions = {2, 0, 3, 1}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(16, 32768, 8, 4)); + EXPECT_THAT(permutation, ElementsAre(2, 0, 3, 1)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_NormalizeTo3D) { + Shape output_shape = ShapeUtil::MakeShape(F32, {8, 16, 32, 32, 32}); + absl::InlinedVector dimensions = {0, 4, 1, 2, 3}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, 
dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(8, 16, 32768)); + EXPECT_THAT(permutation, ElementsAre(0, 2, 1)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_LargeShapeSizeOverflow) { + Shape output_shape = ShapeUtil::MakeShape(F32, {16, 4096, 4096, 128}); + absl::InlinedVector dimensions = {3, 0, 1, 2}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(16, 2147483648)); + EXPECT_THAT(permutation, ElementsAre(1, 0)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_DegenerateDims) { + Shape output_shape = ShapeUtil::MakeShape(F32, {1, 32, 1, 64, 1, 3, 1}); + absl::InlinedVector dimensions = {6, 1, 4, 5, 2, 3, 0}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(32, 64, 3)); + EXPECT_THAT(permutation, ElementsAre(0, 2, 1)); +} + +TEST(ShapeUtilTest, GetNormalizedLogicalTransposeShape_TransposeWithGrouping) { + Shape output_shape = ShapeUtil::MakeShape(F32, {10, 1, 32, 100, 2}); + absl::InlinedVector dimensions = {2, 1, 3, 0, 4}; + absl::InlinedVector permutation; + auto normalized_shape = ShapeUtil::GetNormalizedLogicalTransposeShape( + output_shape, dimensions, permutation); + + EXPECT_THAT(normalized_shape, ElementsAre(320, 100, 2)); + EXPECT_THAT(permutation, ElementsAre(1, 0, 2)); +} + } // namespace } // namespace xla From 37b13b82702bc818c1f0c83e9b45515890eed9d1 Mon Sep 17 00:00:00 2001 From: Kanish Anand Date: Tue, 9 Dec 2025 08:33:30 -0800 Subject: [PATCH 081/753] (2/N) Add support for `NamedSharding` in existing `HloShardingUtil` methods. Remaining methods will be updated in follow up cl's. 
PiperOrigin-RevId: 842252683 --- third_party/xla/xla/hlo/ir/hlo_sharding.h | 27 +++---- third_party/xla/xla/hlo/ir/named_sharding.h | 10 ++- third_party/xla/xla/hlo/utils/BUILD | 4 ++ .../xla/xla/hlo/utils/hlo_sharding_util.cc | 70 +++++++++++++++++++ .../xla/xla/hlo/utils/hlo_sharding_util.h | 26 +++++-- .../xla/hlo/utils/hlo_sharding_util_test.cc | 55 +++++++++++++++ .../xla/service/spmd/spmd_partitioner_util.cc | 2 +- 7 files changed, 171 insertions(+), 23 deletions(-) diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.h b/third_party/xla/xla/hlo/ir/hlo_sharding.h index d4a0515e931146..488dfdb2793421 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.h @@ -130,16 +130,6 @@ class HloSharding { metadata); } - explicit HloSharding(NamedSharding named_sharding) - : replicated_(false), - maximal_(false), - tuple_(false), - manual_(false), - unknown_(false), - unreduced_(false), - replicate_on_last_tile_dim_(false), - named_sharding_(std::move(named_sharding)) {} - // Creates a subgroup sharding with device-level tile assignment, the // sharding type of each subgroup is defined by subgroup_types. When creating // the HloSharding, subgroup dims of the same type will be merged. @@ -493,6 +483,11 @@ class HloSharding { // REQUIRES: !IsReplicated() && !IsTuple() const TileAssignment& tile_assignment() const { return tile_assignment_; } + const NamedSharding& named_sharding() const { + CHECK(UseNamedShardingLeaf()); + return named_sharding_.value(); + } + // Returns the number of dimensions. 
int64_t num_dimensions() const { return tile_assignment().num_dimensions(); } @@ -668,9 +663,15 @@ class HloSharding { const ShardGroup& GetShardGroup() const { return shard_group_; } - std::optional named_sharding() const { - return named_sharding_; - } + explicit HloSharding(NamedSharding named_sharding) + : replicated_(false), + maximal_(false), + tuple_(false), + manual_(false), + unknown_(false), + unreduced_(false), + replicate_on_last_tile_dim_(false), + named_sharding_(std::move(named_sharding)) {} private: explicit HloSharding(bool manual, bool replicated, bool unknown, diff --git a/third_party/xla/xla/hlo/ir/named_sharding.h b/third_party/xla/xla/hlo/ir/named_sharding.h index bfdc9966c0b15d..01ab052d24a22b 100644 --- a/third_party/xla/xla/hlo/ir/named_sharding.h +++ b/third_party/xla/xla/hlo/ir/named_sharding.h @@ -64,8 +64,6 @@ class NamedSharding { return !(*this == other); } - const Mesh& mesh() const { return mesh_; } - // TODO(b/456212087): Add validation checks explicit NamedSharding(Mesh mesh, absl::Span dim_shardings = {}, @@ -78,6 +76,14 @@ class NamedSharding { unreduced_axes_(unreduced_axes.begin(), unreduced_axes.end()), metadata_(metadata.begin(), metadata.end()) {} + const Mesh& mesh() const { return mesh_; } + absl::Span dim_shardings() const { + return dim_shardings_; + } + absl::Span replicated_axes() const { return replicated_axes_; } + absl::Span unreduced_axes() const { return unreduced_axes_; } + absl::Span metadata() const { return metadata_; } + private: friend class HloSharding; diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 534fdbcafb1d58..66f7bf731b8dad 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -151,6 +151,8 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:mesh_and_axis", + "//xla/hlo/ir:named_sharding", "//xla/hlo/ir:tile_assignment", "//xla/service:call_graph", 
"//xla/service:dot_as_convolution_util", @@ -184,6 +186,8 @@ xla_cc_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:mesh_and_axis", + "//xla/hlo/ir:named_sharding", "//xla/hlo/ir:tile_assignment", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:test", diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 1644a384510159..c8ac759cf20973 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -48,6 +48,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/mesh_and_axis.h" +#include "xla/hlo/ir/named_sharding.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/hlo/utils/hlo_container_util.h" #include "xla/layout.h" @@ -1080,6 +1082,21 @@ HloSharding PropagateShardingAlongDimsAndReplicateOthers( return source_sharding; } + if (source_sharding.UseNamedShardingLeaf()) { + std::vector target_dim_shardings( + target_shape_rank); + for (int i = 0; i < source_dims.size(); ++i) { + target_dim_shardings[target_dims[i]] = + source_sharding.named_sharding().dim_shardings()[source_dims[i]]; + } + + return HloSharding(NamedSharding( + source_sharding.named_sharding().mesh(), target_dim_shardings, + source_sharding.named_sharding().replicated_axes(), + source_sharding.named_sharding().unreduced_axes(), + source_sharding.named_sharding().metadata())); + } + HloSharding replicate_other_dims = PartiallyReplicateTiledShardingOnAllDimsExcept(source_sharding, source_dims); @@ -1493,6 +1510,22 @@ HloSharding PartiallyReplicateTiledShardingOnDims( if (sharding.IsTileMaximal() || sharding.IsManual()) { return sharding; } + + if (sharding.UseNamedShardingLeaf()) { + std::vector dim_shardings( + sharding.named_sharding().dim_shardings().begin(), + 
sharding.named_sharding().dim_shardings().end()); + for (int64_t dim : dims_to_replicate) { + if (dim < dim_shardings.size()) { + dim_shardings[dim] = NamedSharding::DimensionSharding(); + } + } + return HloSharding(NamedSharding( + sharding.named_sharding().mesh(), dim_shardings, + sharding.named_sharding().replicated_axes(), + sharding.named_sharding().unreduced_axes(), sharding.metadata())); + } + int64_t group_count = 1; DimensionVector valid_dims_to_replicate; for (int64_t dim : dims_to_replicate) { @@ -1555,6 +1588,15 @@ HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept( HloSharding ReplicateAllDataDims(const HloSharding& sharding, int64_t data_rank) { + if (sharding.UseNamedShardingLeaf()) { + std::vector dim_shardings( + data_rank >= 0 ? data_rank : sharding.num_dimensions()); + return HloSharding(NamedSharding( + sharding.named_sharding().mesh(), dim_shardings, + sharding.named_sharding().replicated_axes(), + sharding.named_sharding().unreduced_axes(), sharding.metadata())); + } + if (sharding.IsManual()) { return sharding; } @@ -1580,6 +1622,34 @@ HloSharding RemoveShapeDimensions(const HloSharding& sharding, if (sharding.IsTileMaximal() || dims_to_remove.empty()) { return sharding; } + + if (sharding.UseNamedShardingLeaf()) { + // Check to ensure subgroup dimensions are not passed in dims_to_remove as + // named sharding doesn't handle them as part of dim_shardings but separate + // replicated, unreduced axes as opposed to tile hlo sharding format which + // uses tile dimensions to represent subgroup dimensions as well. 
+ DCHECK( + std::all_of(dims_to_remove.begin(), dims_to_remove.end(), + [&](int64_t i) { return i < sharding.num_dimensions(); })); + + std::vector new_dim_shardings; + new_dim_shardings.reserve(sharding.num_dimensions() - + dims_to_remove.size()); + for (int64_t i = 0; i < sharding.num_dimensions(); ++i) { + if (absl::c_linear_search(dims_to_remove, i)) { + CHECK_EQ(sharding.dimension(i), 1); + } else { + new_dim_shardings.push_back( + sharding.named_sharding().dim_shardings()[i]); + } + } + + return HloSharding(NamedSharding( + sharding.named_sharding().mesh(), new_dim_shardings, + sharding.named_sharding().replicated_axes(), + sharding.named_sharding().unreduced_axes(), sharding.metadata())); + } + DimensionVector new_tile_shape; new_tile_shape.reserve(sharding.num_dimensions() - dims_to_remove.size()); for (int64_t i = 0; i < sharding.num_dimensions(); ++i) { diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index c5164f0be6e26f..1f521eedaa8006 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -251,9 +251,10 @@ HloSharding PartiallyReplicateTiledShardingOnAllDimsExcept( HloSharding ReplicateAllDataDims(const HloSharding& sharding, int64_t data_rank = -1); -// Returns a sharding the removes given tile dimensions. +// Returns a sharding that removes given sharding dimensions. // -// Precondition: if not tile maximal, the size of each tile dimension must be 1. +// Precondition: if not tile maximal, the size of each sharding dimension must +// be 1. HloSharding RemoveShapeDimensions(const HloSharding& sharding, absl::Span dims_to_remove); @@ -264,12 +265,13 @@ std::optional TransposeShardingWithCollapsedDims( const HloSharding& source, absl::Span src_to_tgt, absl::Span tgt_to_src); -// Given a `source_sharding`, preserve the tiles along the `source_dims` and -// replicate the rest. 
The `target_dims` are used to determine the order of the -// dimensions in the resulting sharding. If `source_dims` and `target_dims` are -// in the different order (i.e., different ArgSort results), we need to -// transpose the tile assignment. +// Given a `source_sharding`, preserve the dimensions along the `source_dims` +// and replicate the rest. The `target_dims` are used to determine the order of +// the dimensions in the resulting sharding. // +// [For tiled sharding format] If `source_dims` and `target_dims` are in the +// different order (i.e., different ArgSort results), we need to transpose the +// tile assignment. // Given the following input, // * source_sharding = {devices=[2,3,5,7,11]<=[2310]} // * source_dims = [2, 4, 1] @@ -277,6 +279,16 @@ std::optional TransposeShardingWithCollapsedDims( // * target_shape_rank = 5 // The result shoule be {devices=[1,11,5,3,1,14]<=[2,3,5,7,11]T(4,2,1,0,3) // last_tile_dim_replicate}. +// +// [For named sharding format] +// Given the following input, +// * mesh = Mesh({2, 3, 5, 7, 11}, {"a", "b", "c", "d", "e"}); +// * source_sharding = NamedSharding(mesh, {{"a"}, {"b"}, {"c"}, {"d"}, +// {"e"}}) +// * source_dims = [2, 4, 1] +// * target_dims = [2, 1, 3] +// * target_shape_rank = 5 +// The result shoule be NamedSharding(mesh, {{}, {"e"}, {"c"}, {"b"}, {}}) HloSharding PropagateShardingAlongDimsAndReplicateOthers( const HloSharding& source_sharding, absl::Span source_dims, absl::Span target_dims, int64_t target_shape_rank); diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index ecaad635b7a440..ab7a203e0d2ae1 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -27,6 +27,8 @@ limitations under the License. 
#include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/ir/mesh_and_axis.h" +#include "xla/hlo/ir/named_sharding.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/test.h" @@ -566,6 +568,18 @@ TEST(HloShardingUtilTest, PropagateShardingAlongDimsAndReplicateOthers1) { HloSharding expected = HloSharding::PartialTile( TileAssignment({1, 11, 5, 3, 1, 14}, {2, 3, 5, 7, 11}, {4, 2, 1, 0, 3})); EXPECT_EQ(target_sharding, expected); + + { + Mesh mesh({2, 3, 5, 7, 11}, {"a", "b", "c", "d", "e"}); + NamedSharding source_sharding = + test_utils::FromAxisNames(mesh, {{"a"}, {"b"}, {"c"}, {"d"}, {"e"}}); + HloSharding target_sharding = PropagateShardingAlongDimsAndReplicateOthers( + HloSharding(source_sharding), source_dims, target_dims, + target_shape_rank); + NamedSharding expected = + test_utils::FromAxisNames(mesh, {{}, {"e"}, {"c"}, {"b"}, {}}); + EXPECT_EQ(target_sharding.named_sharding(), expected); + } } TEST(HloShardingUtilTest, PropagateShardingAlongDimsAndReplicateOthers2) { @@ -578,6 +592,18 @@ TEST(HloShardingUtilTest, PropagateShardingAlongDimsAndReplicateOthers2) { HloSharding expected = HloSharding::PartialTile( TileAssignment({2, 5, 11, 21}, {2, 3, 5, 7, 11}, {0, 2, 4, 1, 3})); EXPECT_EQ(target_sharding, expected); + + { + Mesh mesh({2, 3, 5, 7, 11}, {"a", "b", "c", "d", "e"}); + NamedSharding source_sharding = + test_utils::FromAxisNames(mesh, {{"a"}, {"b"}, {"c"}, {"d"}, {"e"}}); + HloSharding target_sharding = PropagateShardingAlongDimsAndReplicateOthers( + HloSharding(source_sharding), source_dims, target_dims, + target_shape_rank); + NamedSharding expected = + test_utils::FromAxisNames(mesh, {{"a"}, {"c"}, {"e"}}); + EXPECT_EQ(target_sharding.named_sharding(), expected); + } } TEST(HloShardingUtilTest, PropagateShardingAlongDimsAndReplicateOthers3) { @@ -590,6 +616,35 @@ TEST(HloShardingUtilTest, 
PropagateShardingAlongDimsAndReplicateOthers3) { HloSharding expected = HloSharding::PartialTile( TileAssignment({11, 7, 1, 3, 10}, {2, 3, 5, 7, 11}, {4, 3, 1, 0, 2})); EXPECT_EQ(target_sharding, expected); + + { + Mesh mesh({2, 3, 5, 7, 11}, {"a", "b", "c", "d", "e"}); + NamedSharding source_sharding = + test_utils::FromAxisNames(mesh, {{"a"}, {"b"}, {"c"}, {"d"}, {"e"}}); + HloSharding target_sharding = PropagateShardingAlongDimsAndReplicateOthers( + HloSharding(source_sharding), source_dims, target_dims, + target_shape_rank); + NamedSharding expected = + test_utils::FromAxisNames(mesh, {{"e"}, {"d"}, {}, {"b"}}); + EXPECT_EQ(target_sharding.named_sharding(), expected); + } +} + +TEST(HloShardingUtilTest, PropagateShardingAlongDimsAndReplicateOthers4) { + Mesh mesh({2, 3, 5, 7, 11}, {"a", "b", "c", "d", "e"}); + NamedSharding source_sharding = + test_utils::FromAxisNames(mesh, {{"a"}, {"c", "b"}, {}, {"d"}, {}}, {}, + /*unreduced_axes=*/{"e"}); + std::vector source_dims = {2, 1, 3}; + std::vector target_dims = {0, 3, 1}; + int64_t target_shape_rank = 4; + HloSharding target_sharding = PropagateShardingAlongDimsAndReplicateOthers( + HloSharding(source_sharding), source_dims, target_dims, + target_shape_rank); + NamedSharding expected = + test_utils::FromAxisNames(mesh, {{}, {"d"}, {}, {"c", "b"}}, {}, + /*unreduced_axes=*/{"e"}); + EXPECT_EQ(target_sharding.named_sharding(), expected); } TEST(HloShardingUtilTest, MergeManualSubgroupSharding) { diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc index d87ef705b8185c..8e79e7c16d2e84 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc @@ -3053,7 +3053,7 @@ std::optional GetIotaPartitionGroupsForReplication( std::optional GetMeshFromSharding(const HloSharding& sharding) { // For V3 shardings, use the mesh associated with the named sharding. 
if (sharding.UseNamedShardingLeaf()) { - return sharding.named_sharding()->mesh(); + return sharding.named_sharding().mesh(); } // For V2 shardings, create the mesh from the tile assignment. From 2009f3930d65c32cf010cd685ce9a988b831ef67 Mon Sep 17 00:00:00 2001 From: Will Froom Date: Tue, 9 Dec 2025 08:33:49 -0800 Subject: [PATCH 082/753] [XLA:CPU] Add missing vectorization sizes from tanh and exp approximation. PiperOrigin-RevId: 842252816 --- .../cpu/codegen/polynomial_approximations.cc | 7 ++++--- third_party/xla/xla/codegen/intrinsic/tanh.h | 13 +++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc b/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc index 0c8084568e41c5..947566d3de8715 100644 --- a/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc +++ b/third_party/xla/xla/backends/cpu/codegen/polynomial_approximations.cc @@ -521,12 +521,13 @@ void RewriteToPolynomialApproximations(llvm::Module* module, rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1); rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1); - rewrite_calls(kExpV4F32Sym, GenerateVF32Exp, /*vector_width=*/4); + rewrite_calls("llvm.exp.v2f32", GenerateVF32Exp, /*vector_width=*/2); rewrite_calls("llvm.exp.v4f32", GenerateVF32Exp, /*vector_width=*/4); - rewrite_calls(kExpV8F32Sym, GenerateVF32Exp, /*vector_width=*/8); rewrite_calls("llvm.exp.v8f32", GenerateVF32Exp, /*vector_width=*/8); - rewrite_calls(kExpV16F32Sym, GenerateVF32Exp, /*vector_width=*/16); rewrite_calls("llvm.exp.v16f32", GenerateVF32Exp, /*vector_width=*/16); + rewrite_calls(kExpV4F32Sym, GenerateVF32Exp, /*vector_width=*/4); + rewrite_calls(kExpV8F32Sym, GenerateVF32Exp, /*vector_width=*/8); + rewrite_calls(kExpV16F32Sym, GenerateVF32Exp, /*vector_width=*/16); rewrite_calls("llvm.exp.f16", UpcastF16ToF32, /*vector_width=*/1); diff --git 
a/third_party/xla/xla/codegen/intrinsic/tanh.h b/third_party/xla/xla/codegen/intrinsic/tanh.h index 34d60229c29026..022a09951f3d3a 100644 --- a/third_party/xla/xla/codegen/intrinsic/tanh.h +++ b/third_party/xla/xla/codegen/intrinsic/tanh.h @@ -33,12 +33,13 @@ class Tanh : public Intrinsic { static std::vector> SupportedVectorTypes() { // F16 via upcast to F32. return { - {Type::S(xla::F16)}, {Type::V(xla::F16, 8)}, {Type::V(xla::F16, 16)}, - {Type::S(xla::F32)}, - - {Type::V(xla::F32, 4)}, {Type::V(xla::F32, 8)}, {Type::V(xla::F32, 16)}, - {Type::S(xla::F64)}, {Type::V(xla::F64, 2)}, {Type::V(xla::F64, 4)}, - {Type::V(xla::F64, 8)}, + {Type::S(xla::F16)}, {Type::V(xla::F16, 2)}, + {Type::V(xla::F16, 4)}, {Type::V(xla::F16, 8)}, + {Type::V(xla::F16, 16)}, {Type::S(xla::F32)}, + {Type::V(xla::F32, 2)}, {Type::V(xla::F32, 4)}, + {Type::V(xla::F32, 8)}, {Type::V(xla::F32, 16)}, + {Type::S(xla::F64)}, {Type::V(xla::F64, 2)}, + {Type::V(xla::F64, 4)}, {Type::V(xla::F64, 8)}, }; } static absl::StatusOr CreateDefinition(llvm::Module* module, From 13a525b5831ff70f781fe3dd361440c059017464 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 9 Dec 2025 08:36:32 -0800 Subject: [PATCH 083/753] [xla:pjrt] Migrate to se::DeviceMemoryAddress PiperOrigin-RevId: 842253793 --- third_party/xla/xla/pjrt/BUILD | 8 ++--- third_party/xla/xla/pjrt/cpu/BUILD | 4 +-- .../xla/xla/pjrt/cpu/abstract_cpu_buffer.cc | 2 +- third_party/xla/xla/pjrt/cpu/cpu_client.cc | 12 ++++---- third_party/xla/xla/pjrt/gpu/BUILD | 6 ++-- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 24 +++++++-------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.h | 4 +-- .../xla/pjrt/gpu/se_gpu_pjrt_client_test.cc | 6 ++-- third_party/xla/xla/pjrt/gpu/tfrt/BUILD | 12 ++++---- ...u_async_host_to_device_transfer_manager.cc | 6 ++-- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_buffer.cc | 10 +++---- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc | 10 +++---- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h | 8 ++--- 
.../xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc | 6 ++-- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.cc | 2 +- .../xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.h | 2 +- .../xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc | 16 +++++----- .../gpu/tfrt/tracked_gpu_device_buffer.cc | 14 ++++----- .../pjrt/gpu/tfrt/tracked_gpu_device_buffer.h | 18 +++++------ .../tfrt/tracked_gpu_device_buffer_test.cc | 6 ++-- third_party/xla/xla/pjrt/gpu/tfrt/utils.cc | 24 +++++++-------- third_party/xla/xla/pjrt/gpu/tfrt/utils.h | 4 +-- .../xla/xla/pjrt/local_device_state.cc | 4 +-- third_party/xla/xla/pjrt/local_device_state.h | 2 +- .../xla/pjrt/pjrt_stream_executor_client.cc | 30 +++++++++---------- .../xla/pjrt/pjrt_stream_executor_client.h | 10 +++---- third_party/xla/xla/pjrt/se_raw_buffer.cc | 8 ++--- .../xla/xla/pjrt/tracked_device_buffer.cc | 20 ++++++------- .../xla/xla/pjrt/tracked_device_buffer.h | 18 +++++------ .../xla/pjrt/tracked_device_buffer_test.cc | 4 +-- 30 files changed, 150 insertions(+), 150 deletions(-) diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 0f895ff6047f2e..954225f6a10310 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -194,7 +194,7 @@ xla_cc_test( "//xla/client:local_client", "//xla/hlo/testlib:test", "//xla/service:cpu_plugin", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address_allocator", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:statusor", @@ -225,7 +225,7 @@ cc_library( ":worker_thread", "//xla:util", "//xla/client:local_client", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/stream_executor:event", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -717,8 +717,8 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", - "//xla/stream_executor:device_memory", - 
"//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 6e7d2fad54a6dc..0756d7beac73bb 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -73,7 +73,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service/cpu:cpu_executable", "//xla/service/cpu:cpu_xfeed", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:env", @@ -208,7 +208,7 @@ cc_library( "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:executable_proto_cc", "//xla/service/llvm_ir:llvm_command_line_options", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", "//xla/tsl/platform:env", diff --git a/third_party/xla/xla/pjrt/cpu/abstract_cpu_buffer.cc b/third_party/xla/xla/pjrt/cpu/abstract_cpu_buffer.cc index 3a98e4200f496a..f330c75ca62e13 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_cpu_buffer.cc +++ b/third_party/xla/xla/pjrt/cpu/abstract_cpu_buffer.cc @@ -53,7 +53,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index 5e2f7aa65df9ef..55690711a5fc40 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -117,7 +117,7 @@ limitations under the License. #include "xla/service/maybe_owning_device_address.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" @@ -1275,7 +1275,7 @@ static absl::StatusOr MemoryForAllocation( } else if (allocation.is_constant() && allocation.index() < constants.size()) { - se::DeviceMemoryBase constant = + se::DeviceAddressBase constant = constants[allocation.index()].AsDeviceMemoryBase(); buffer_info.buffer = CpuDeviceMemory::CreateConstantMemory( constant.opaque(), constant.size()); @@ -1624,8 +1624,8 @@ absl::StatusOr PjRtCpuExecutable::ExecuteHelper( buffer_device_mem.reserve(buffer_table.size()); for (const auto& buffer_info : buffer_table) { buffer_device_mem.emplace_back( - se::DeviceMemoryBase(buffer_info.buffer->untyped_data(), - buffer_info.buffer->size_bytes())); + se::DeviceAddressBase(buffer_info.buffer->untyped_data(), + buffer_info.buffer->size_bytes())); } cpu::BufferAllocations allocations(buffer_device_mem); @@ -1768,8 +1768,8 @@ absl::StatusOr PjRtCpuExecutable::ExecuteHelper( buffer_device_mem.reserve(buffer_table.size()); for (const auto& buffer_info : buffer_table) { buffer_device_mem.emplace_back( - 
se::DeviceMemoryBase(buffer_info.buffer->untyped_data(), - buffer_info.buffer->size_bytes())); + se::DeviceAddressBase(buffer_info.buffer->untyped_data(), + buffer_info.buffer->size_bytes())); } cpu::BufferAllocations allocations(buffer_device_mem); diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 96c8ebbe3c39f3..1b082cc7a37f5f 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -116,9 +116,9 @@ cc_library( "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu:gpu_memory_space_assignment", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -246,7 +246,7 @@ xla_test( "//xla/pjrt/proto:compile_options_proto_cc", "//xla/service:platform_util", "//xla/service/gpu:gpu_memory_space_assignment", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/stream_executor/cuda:cuda_compute_capability", "//xla/tests:literal_test_util", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 82cbcf932e6dda..e210a480bc74dd 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -102,9 +102,9 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -206,7 +206,7 @@ static absl::flat_hash_map GetAttrsForDevices( StreamExecutorGpuClient::StreamExecutorGpuClient( std::string platform_name, LocalClient* client, std::vector> devices, - int process_index, std::unique_ptr allocator, + int process_index, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options, @@ -1414,7 +1414,7 @@ BuildLocalDeviceStates(LocalClient* xla_client) { // Constructs a GPU device memory allocator to use, according to the allocator // configuration the client requested. 
-absl::StatusOr> +absl::StatusOr> GetStreamExecutorGpuDeviceAllocator( se::Platform* platform, const GpuAllocatorConfig& allocator_config, const std::map>& @@ -1849,7 +1849,7 @@ std::vector> BuildLocalDevices( #if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) static absl::Status CheckAlignment(const BufferAllocation& allocation, - se::DeviceMemoryBase buffer, int arg_idx) { + se::DeviceAddressBase buffer, int arg_idx) { const int64_t expected_alignment = [&] { if (allocation.is_entry_computation_parameter()) { return gpu::kEntryParameterAlignBytes; @@ -1887,7 +1887,7 @@ StreamExecutorGpuClient::RunAsync( auto* gpu_exec = tensorflow::down_cast(exec.executable()); const ServiceExecutableRunOptions* run_options = &options_and_stream.first; - se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator(); + se::DeviceAddressAllocator* const memory_allocator = run_options->allocator(); se::StreamExecutor* executor = run_options->stream()->parent(); @@ -1932,7 +1932,7 @@ StreamExecutorGpuClient::RunAsync( absl::Span allocations = gpu_exec->GetAllocations(); - std::vector buffers(allocations.size()); + std::vector buffers(allocations.size()); { tsl::profiler::TraceMe hlo_module_activity( [&] { return std::string("Build buffer allocations"); }, @@ -1940,9 +1940,9 @@ StreamExecutorGpuClient::RunAsync( const int64_t num_buffers = allocations.size(); for (int64_t i = 0; i < num_buffers; ++i) { const BufferAllocation& allocation = *allocations[i]; - se::DeviceMemoryBase& buffer = buffers[i]; + se::DeviceAddressBase& buffer = buffers[i]; if (allocation.is_thread_local()) { - // buffer = se::DeviceMemoryBase{}; + // buffer = se::DeviceAddressBase{}; } else if (allocation.is_entry_computation_parameter()) { int64_t param_no = allocation.parameter_number(); buffer = [&] { @@ -1985,7 +1985,7 @@ StreamExecutorGpuClient::RunAsync( XLA_VLOG_DEVICE(3, device_ordinal) << "Buffer allocations: " << buffer_allocations.ToString(); - std::set buffers_in_result; + 
std::set buffers_in_result; xla::ShapeTree> results( gpu_exec->result_shape()); @@ -1999,7 +1999,7 @@ StreamExecutorGpuClient::RunAsync( gpu_exec->output_info().at(index); const BufferAllocation* allocation = allocations[output_info.allocation_index]; - se::DeviceMemoryBase result_buffer; + se::DeviceAddressBase result_buffer; XLA_VLOG_DEVICE(4, device_ordinal) << "Looking at: allocation " << output_info.allocation_index @@ -2043,7 +2043,7 @@ StreamExecutorGpuClient::RunAsync( return gpu_exec->VerboseAllocationError(allocated_buffer.status()); } result_buffer = allocated_buffer->Release(); - se::DeviceMemoryBase& aliased_buffer = + se::DeviceAddressBase& aliased_buffer = buffer_allocations.GetMutableDeviceAddress( output_info.allocation_index); CHECK_EQ(aliased_buffer.size(), result_buffer.size()); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index b43592589b9cf3..c56d5757a3c929 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -58,8 +58,8 @@ limitations under the License. 
#include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/shape.h" #include "xla/shape_tree.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/protobuf/coordination_service.pb.h" @@ -109,7 +109,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { StreamExecutorGpuClient( std::string platform_name, LocalClient* client, std::vector> devices, - int process_index, std::unique_ptr allocator, + int process_index, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options, diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index d6283dda89b5fe..787f43b0691a21 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -92,7 +92,7 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/stream.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -510,7 +510,7 @@ static absl::Status MemsetFromValue( uint32_t pattern; std::memcpy(&pattern, &memset_value->value, sizeof(pattern)); - se::DeviceMemoryBase base = result->device_memory(); + se::DeviceAddressBase base = result->device_memory(); return stream->Memset32(&base, pattern, base.size()); } @@ -559,7 +559,7 @@ static absl::Status MemsetFromAttr( uint32_t pattern; std::memcpy(&pattern, &attr, sizeof(pattern)); - se::DeviceMemoryBase base = result->device_memory(); + se::DeviceAddressBase base = result->device_memory(); return stream->Memset32(&base, pattern, base.size()); } diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD index b6d28c5e744e7f..7cb5b0892be377 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/BUILD +++ b/third_party/xla/xla/pjrt/gpu/tfrt/BUILD @@ -104,10 +104,10 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -217,8 +217,8 @@ xla_test( "//xla/pjrt/plugin/xla_gpu:xla_gpu_client_options", "//xla/pjrt/proto:compile_options_proto_cc", "//xla/service:platform_util", + "//xla/stream_executor:device_address", "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", 
"//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor/cuda:cuda_compute_capability", @@ -272,8 +272,8 @@ cc_library( "//xla:util", "//xla/pjrt:pjrt_client", "//xla/service:shaped_buffer", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:event", "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", @@ -317,8 +317,8 @@ xla_cc_test( "//xla/pjrt:pjrt_common", "//xla/service:gpu_plugin", "//xla/service:shaped_buffer", + "//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", - "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", # copybara:uncomment "//xla/tsl/framework:allocator", "//xla/tsl/platform:env", diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_async_host_to_device_transfer_manager.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_async_host_to_device_transfer_manager.cc index 7ec8d2dc198fee..44031d7249faa5 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_async_host_to_device_transfer_manager.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_async_host_to_device_transfer_manager.cc @@ -54,9 +54,9 @@ limitations under the License. 
#include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/statusor.h" @@ -275,7 +275,7 @@ TfrtGpuAsyncHostToDeviceTransferManager::TransferRawDataToSubBuffer( staging_buffer = host_memory_allocator->Allocate(transfer_size); } - se::DeviceMemoryBase sub_buffer; + se::DeviceAddressBase sub_buffer; { absl::MutexLock l(mu_); DCHECK_LT(buffer_index, buffer_ptrs_.size()); diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_buffer.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_buffer.cc index 65e7a6ace81d0b..6c043547ed5e60 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_buffer.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_buffer.cc @@ -57,9 +57,9 @@ limitations under the License. 
#include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/framework/allocator.h" @@ -583,7 +583,7 @@ Future<> TfrtGpuBuffer::CopyRawToHostFuture(Future dst_future, promise.Set(device_buffer->definition_event().GetError()); return; } - se::DeviceMemoryBase device_memory = device_buffer->buffer()->buffer(); + se::DeviceAddressBase device_memory = device_buffer->buffer()->buffer(); if (offset < 0 || offset > device_memory.size() || device_memory.size() - offset < transfer_size) { LOG(ERROR) << "Copy raw buffer called on buffer size " @@ -596,7 +596,7 @@ Future<> TfrtGpuBuffer::CopyRawToHostFuture(Future dst_future, return; } - se::DeviceMemoryBase sub_buffer; + se::DeviceAddressBase sub_buffer; if (transfer_size < device_memory.size()) { sub_buffer = device_memory.GetByteSlice(offset, transfer_size); } else { @@ -824,7 +824,7 @@ absl::StatusOr> TfrtGpuBuffer::CopyToMemorySpace( auto stream = dst_device->stream(); - se::DeviceMemoryBase dst(allocated_dst_buffer->buffer()); + se::DeviceAddressBase dst(allocated_dst_buffer->buffer()); VLOG(3) << "D2D copy: " << src_buffer->buffer().opaque() << " -> " << dst.opaque() << " (" << src_buffer->buffer().size() << " bytes)"; diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc index eec0d99679e068..f63c05f5aa8b77 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc @@ -94,9 +94,9 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -148,7 +148,7 @@ TfrtGpuClient::TfrtGpuClient( std::vector> devices, bool should_stage_host_to_device_transfers, bool abort_collectives_on_failure, - MaybeOwning allocator, + MaybeOwning allocator, std::unique_ptr host_memory_allocator, std::unique_ptr gpu_run_options, std::shared_ptr kv_store, @@ -437,7 +437,7 @@ TfrtGpuClient::CreateViewOfDeviceBuffer( CHECK_EQ(memory_space->devices().size(), 1); auto* device = memory_space->devices().front(); size_t byte_size = ShapeUtil::ByteSizeOf(shape); - se::DeviceMemoryBase device_memory(device_ptr, byte_size); + se::DeviceAddressBase device_memory(device_ptr, byte_size); auto non_owning_buffer = GpuDeviceMemory(device_memory); auto buffer_async_value_ref = tsl::MakeAvailableAsyncValueRef( @@ -972,7 +972,7 @@ absl::StatusOr> TfrtGpuClient::BufferFromHostBuffer( }); auto stream = device->stream(); - se::DeviceMemoryBase dest = gpu_buffer->buffer(); + se::DeviceAddressBase dest = gpu_buffer->buffer(); VLOG(3) << "H2D copy: " << src_buf << " -> " << dest.opaque() << " (" << packed_size << " bytes) on device " << device->DebugString(); diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h index 41e95484b084a7..88bed1881f355f 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client.h @@ -63,7 +63,7 @@ limitations under the License. 
#include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/platform/threadpool.h" #include "xla/xla.pb.h" @@ -119,7 +119,7 @@ class TfrtGpuClient final : public PjRtClient { std::vector> devices, bool should_stage_host_to_device_transfers, bool abort_collectives_on_failure, - MaybeOwning allocator, + MaybeOwning allocator, std::unique_ptr host_memory_allocator, std::unique_ptr gpu_run_options, std::shared_ptr kv_store, @@ -156,7 +156,7 @@ class TfrtGpuClient final : public PjRtClient { xla::LocalClient* xla_client() const { return xla_client_; } - se::DeviceMemoryAllocator* allocator() { return allocator_.get_mutable(); } + se::DeviceAddressAllocator* allocator() { return allocator_.get_mutable(); } bool should_stage_host_to_device_transfers() const { return should_stage_host_to_device_transfers_; @@ -337,7 +337,7 @@ class TfrtGpuClient final : public PjRtClient { // Device memory allocator. If owned, the allocator must outlive the devices, // because it is the device destructor that waits for any outstanding work to // complete. - MaybeOwning allocator_; + MaybeOwning allocator_; // Allocator to be used for staging memory transfers to devices. std::unique_ptr host_memory_allocator_; diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc index 3e5fcc20deb231..c078751882b00c 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_client_test.cc @@ -77,8 +77,8 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/tests/literal_test_util.h" @@ -386,7 +386,7 @@ static absl::Status MemsetFromValue( uint32_t pattern; std::memcpy(&pattern, &memset_value->value, sizeof(pattern)); - se::DeviceMemoryBase base = result->device_memory(); + se::DeviceAddressBase base = result->device_memory(); return stream->Memset32(&base, pattern, base.size()); } @@ -434,7 +434,7 @@ static absl::Status MemsetFromAttr( uint32_t pattern; std::memcpy(&pattern, &attr, sizeof(pattern)); - se::DeviceMemoryBase base = result->device_memory(); + se::DeviceAddressBase base = result->device_memory(); return stream->Memset32(&base, pattern, base.size()); } diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.cc index 588227432b216a..96204308bc2384 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.cc @@ -55,8 +55,8 @@ limitations under the License. 
#include "xla/service/hlo.pb.h" #include "xla/service/transfer_manager.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.h b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.h index 95961906b1ace8..97707c3690fb06 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_device.h @@ -46,7 +46,7 @@ limitations under the License. #include "xla/pjrt/semaphore.h" #include "xla/service/hlo.pb.h" #include "xla/service/transfer_manager.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc index 1c97ab898cbd21..5e84506057c524 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc @@ -76,9 +76,9 @@ limitations under the License. 
#include "xla/shape_layout.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" @@ -110,7 +110,7 @@ namespace xla { class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { public: TfrtGpuCopyToDeviceStream(int64_t channel_id, se::Stream* stream, - se::DeviceMemoryBase dst, + se::DeviceAddressBase dst, tsl::AsyncValueRef> done) : CopyToDeviceStream(dst.size(), /*granule_bytes=*/1), channel_id_(channel_id), @@ -146,7 +146,7 @@ class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { return Future<>(done_.GetError()); } - se::DeviceMemoryBase dst( + se::DeviceAddressBase dst( reinterpret_cast(dst_.opaque()) + current_bytes_, dst_.size() - current_bytes_); @@ -190,7 +190,7 @@ class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { private: int64_t channel_id_; se::Stream* stream_; - se::DeviceMemoryBase dst_; + se::DeviceAddressBase dst_; // Async value will become available after we'll submit the last memcpy // operation, and the event will be recorded on the stream. 
@@ -771,13 +771,13 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( if (result_is_tuple) { for (int i = 0; i < output_buffers.size(); ++i) { ScopedShapedBuffer tuple_buffer = output.TakeSubTree({i}); - stream_executor::DeviceMemoryBase* elem = + stream_executor::DeviceAddressBase* elem = tuple_buffer.buffers().mutable_element({}); VLOG(3) << "untuple: output_buffers[" << i << "].emplace: " << elem->opaque(); output_buffers[i].emplace(stream_executor::OwningDeviceMemory( *elem, device->local_device_id().value(), client->allocator())); - *elem = se::DeviceMemoryBase(); + *elem = se::DeviceAddressBase(); } } else { CHECK_EQ(output_buffers.size(), 1); @@ -785,7 +785,7 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( VLOG(3) << "output_buffers[0].emplace: " << elem->opaque(); output_buffers.front().emplace(stream_executor::OwningDeviceMemory( *elem, device->local_device_id().value(), client->allocator())); - *elem = se::DeviceMemoryBase(); + *elem = se::DeviceAddressBase(); } // Set the scheduled event to concrete to indicate that the scheduling diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.cc index 32543f080947b1..3d49f9d7a16823 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.cc @@ -28,8 +28,8 @@ limitations under the License. 
#include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_tree.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -44,7 +44,7 @@ ShapedBuffer GpuDeviceMemory::AsShapedBuffer(const Shape& on_device_shape, const PjRtDevice* device) const { ShapedBuffer shaped_buffer(on_device_shape, device->local_device_id().value(), device->local_hardware_id().value()); - ShapeTree::iterator iterator = + ShapeTree::iterator iterator = shaped_buffer.buffers().begin(); CHECK(iterator != shaped_buffer.buffers().end()); iterator->second = buffer_; @@ -60,19 +60,19 @@ void GpuDeviceMemory::SetUnOwned() { } absl::StatusOr GpuDeviceMemory::Allocate( - se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size) { + se::DeviceAddressAllocator* allocator, int device_ordinal, size_t size) { return Allocate(allocator, device_ordinal, size, static_cast(se::MemoryType::kDevice)); } absl::StatusOr GpuDeviceMemory::Allocate( - se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size, + se::DeviceAddressAllocator* allocator, int device_ordinal, size_t size, int64_t memory_space) { if (size == 0) { - return GpuDeviceMemory(se::DeviceMemoryBase()); + return GpuDeviceMemory(se::DeviceAddressBase()); } TF_ASSIGN_OR_RETURN( - stream_executor::OwningDeviceMemory memory, + stream_executor::ScopedDeviceAddress memory, allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/true, memory_space)); return GpuDeviceMemory(std::move(memory)); diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h index 19c949075f320d..71abf7139016dd 100644 --- 
a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer.h @@ -29,8 +29,8 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/event.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -47,11 +47,11 @@ class GpuDeviceMemory { GpuDeviceMemory& operator=(GpuDeviceMemory&& other) = default; // Creates non-owning GPU device memory from a raw data pointer. - explicit GpuDeviceMemory(stream_executor::DeviceMemoryBase buffer) + explicit GpuDeviceMemory(stream_executor::DeviceAddressBase buffer) : buffer_(buffer) {} // Creates owning GPU device memory from an owned data pointer. - explicit GpuDeviceMemory(stream_executor::OwningDeviceMemory buffer) + explicit GpuDeviceMemory(stream_executor::ScopedDeviceAddress buffer) : owning_buffer_(std::move(buffer)), buffer_(*owning_buffer_) {} ShapedBuffer AsShapedBuffer(const Shape& on_device_shape, @@ -62,19 +62,19 @@ class GpuDeviceMemory { // Allocates raw owning memory. 
static absl::StatusOr Allocate( - se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size); + se::DeviceAddressAllocator* allocator, int device_ordinal, size_t size); static absl::StatusOr Allocate( - se::DeviceMemoryAllocator* allocator, int device_ordinal, size_t size, + se::DeviceAddressAllocator* allocator, int device_ordinal, size_t size, int64_t memory_space); - stream_executor::DeviceMemoryBase buffer() const { return buffer_; } + stream_executor::DeviceAddressBase buffer() const { return buffer_; } size_t size_bytes() const { return buffer_.size(); } bool owns_data() const { return !owning_buffer_.is_null(); } private: - stream_executor::OwningDeviceMemory owning_buffer_; - se::DeviceMemoryBase buffer_; + stream_executor::ScopedDeviceAddress owning_buffer_; + se::DeviceAddressBase buffer_; }; // Class that represents a GPU buffer. It optionally owns the buffer. It also diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer_test.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer_test.cc index 4c0020c87b2329..7961f01d17b439 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tracked_gpu_device_buffer_test.cc @@ -39,8 +39,8 @@ limitations under the License. 
#include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_address_allocator.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/env.h" @@ -65,11 +65,11 @@ class TestAllocator : public se::DeviceAddressAllocator { absl::StatusOr> Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) override { - const se::DeviceMemoryBase base(kOpaque, size); + const se::DeviceAddressBase base(kOpaque, size); return stream_executor::ScopedDeviceAddress(base, 0, this); } absl::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) override { + se::DeviceAddressBase mem) override { return absl::OkStatus(); } absl::StatusOr GetStream(int device_ordinal) override { diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/utils.cc b/third_party/xla/xla/pjrt/gpu/tfrt/utils.cc index d6a067722f3c85..a59dd22155ddb6 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/utils.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/utils.cc @@ -83,10 +83,10 @@ limitations under the License. 
#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/integrations/tf_allocator_adapter.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" @@ -307,7 +307,7 @@ absl::flat_hash_map GetAttrsForDevices( class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { public: TfrtGpuCopyToDeviceStream(int64_t channel_id, se::Stream* stream, - se::DeviceMemoryBase dst, + se::DeviceAddressBase dst, tsl::AsyncValueRef> done) : CopyToDeviceStream(dst.size(), /*granule_bytes=*/1), channel_id_(channel_id), @@ -343,7 +343,7 @@ class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { return Future<>(done_.GetError()); } - se::DeviceMemoryBase dst( + se::DeviceAddressBase dst( reinterpret_cast(dst_.opaque()) + current_bytes_, dst_.size() - current_bytes_); @@ -387,7 +387,7 @@ class TfrtGpuCopyToDeviceStream : public CopyToDeviceStream { private: int64_t channel_id_; se::Stream* stream_; - se::DeviceMemoryBase dst_; + se::DeviceAddressBase dst_; // Async value will become available after we'll submit the last memcpy // operation, and the event will be recorded on the stream. @@ -401,7 +401,7 @@ SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( // Check if we have callbacks registered for the given replica. 
if (replica >= options.send_callbacks.size()) { return [replica](int64_t channel_id, se::Stream*, const Shape&, - const se::DeviceMemoryBase&, + const se::DeviceAddressBase&, const absl::flat_hash_map&) { return Internal( "Don't send a buffer to the channel_id=%d, there was no send " @@ -415,7 +415,7 @@ SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( return [callbacks, thread_pool]( int64_t channel_id, se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& src, + const se::DeviceAddressBase& src, const absl::flat_hash_map&) -> absl::StatusOr>> { VLOG(4) << "Send " << src.size() << " bytes to channel #" << channel_id @@ -490,7 +490,7 @@ RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( // Check if we have callbacks registered for the given replica. if (replica >= options.send_callbacks.size()) { return [replica](int64_t channel_id, se::Stream*, const Shape&, - se::DeviceMemoryBase*, + se::DeviceAddressBase*, const absl::flat_hash_map&) { return InvalidArgument( "Failed to receive a buffer from the channel_id=%d, there was no " @@ -503,7 +503,7 @@ RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( absl::Span callbacks = options.recv_callbacks[replica]; return [callbacks](int64_t channel_id, se::Stream* stream, const Shape& shape, - se::DeviceMemoryBase* dst, + se::DeviceAddressBase* dst, const absl::flat_hash_map&) -> absl::StatusOr>> { VLOG(4) << "Recv from channel #" << channel_id @@ -650,7 +650,7 @@ absl::StatusOr> CreateAllocatorForDevice( } } -absl::StatusOr> CreateDeviceAllocator( +absl::StatusOr> CreateDeviceAllocator( LocalClient* xla_client, const GpuAllocatorConfig& allocator_config, const std::vector>& devices) { if (allocator_config.kind == GpuAllocatorConfig::Kind::kPlatform) { @@ -660,7 +660,7 @@ absl::StatusOr> CreateDeviceAllocator( << "collective_memory_size is non-zero, but allocator kind is set " "to \"platform\". 
Collective memory will not be allocated."; } - return MaybeOwning( + return MaybeOwning( xla_client->backend().memory_allocator()); } @@ -697,7 +697,7 @@ absl::StatusOr> CreateDeviceAllocator( /*memory_space=*/static_cast(se::MemoryType::kHost), executor->device_ordinal(), executor->GetPlatform()); } - return MaybeOwning( + return MaybeOwning( std::make_unique(xla_client->platform(), std::move(allocators))); } diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/utils.h b/third_party/xla/xla/pjrt/gpu/tfrt/utils.h index 9fdf52226cecba..c7599bd4967d97 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/utils.h +++ b/third_party/xla/xla/pjrt/gpu/tfrt/utils.h @@ -53,7 +53,7 @@ limitations under the License. #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -154,7 +154,7 @@ std::vector> InitializeMemorySpaces( absl::StatusOr> CreateAllocatorForDevice( se::StreamExecutor* executor, const GpuAllocatorConfig& allocator_config); -absl::StatusOr> CreateDeviceAllocator( +absl::StatusOr> CreateDeviceAllocator( LocalClient* xla_client, const GpuAllocatorConfig& allocator_config, const std::vector>& devices); diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index b3e16c8e8f20ab..a1812c63ec19a0 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include "xla/client/local_client.h" #include "xla/pjrt/buffer_sequencing_event.h" #include "xla/pjrt/worker_thread.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -177,7 +177,7 @@ absl::Status LocalDeviceState::SynchronizeAllActivity() { absl::Status LocalDeviceState::ThenMemcpyDeviceToDevice( se::Stream* transfer_stream, se::Stream* dst_stream, - se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) { + se::DeviceAddressBase src_buffer, se::DeviceAddressBase dst_buffer) { // The default implementation simply calls MemcpyD2D, and assumes that // the buffer addresses identify the devices. This does not work // on all platforms; this method is virtual so it can be overridden. diff --git a/third_party/xla/xla/pjrt/local_device_state.h b/third_party/xla/xla/pjrt/local_device_state.h index 675b6b81459f05..38ca812e589c74 100644 --- a/third_party/xla/xla/pjrt/local_device_state.h +++ b/third_party/xla/xla/pjrt/local_device_state.h @@ -168,7 +168,7 @@ class LocalDeviceState { // Enqueues a copy of `src_buffer` to `dst_buffer` onto `transfer_stream`. virtual absl::Status ThenMemcpyDeviceToDevice( se::Stream* transfer_stream, se::Stream* dst_stream, - se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer); + se::DeviceAddressBase src_buffer, se::DeviceAddressBase dst_buffer); WorkerThread* execute_thread() const { return execute_thread_.get(); } diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 4c175b7390e14c..e342a586863001 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -147,8 +147,8 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" @@ -275,7 +275,7 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient( std::vector> devices, int process_index, std::vector> memory_spaces, - std::unique_ptr allocator, + std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options) @@ -730,7 +730,7 @@ PjRtStreamExecutorClient::LinearizeHostBufferInto( // memory that has already been allocated, and a possible Event // allocation. - se::DeviceMemoryBase device_memory = + se::DeviceAddressBase device_memory = tensorflow::down_cast( raw_buffer.get()) ->device_buffer() @@ -904,7 +904,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( auto* device = memory_space->devices().front(); auto buffer = RawSEDeviceMemory::CreateForeign( - se::DeviceMemoryBase(device_ptr, ShapeUtil::ByteSizeOf(shape)), + se::DeviceAddressBase(device_ptr, ShapeUtil::ByteSizeOf(shape)), std::move(on_delete_callback)); TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, @@ -1139,7 +1139,7 @@ MakeTupleHelper(PjRtStreamExecutorClient* client, absl::Span py_buffers, absl::Span device_buffers, int device_ordinal) { - se::DeviceMemoryAllocator* allocator = client->allocator(); + se::DeviceAddressAllocator* allocator = client->allocator(); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); @@ -1190,7 +1190,7 @@ MakeTupleHelper(PjRtStreamExecutorClient* client, } CHECK(input_iterator == iterator_end); - std::vector elements; + std::vector elements; size_t num_elements = 
ShapeUtil::TupleElementCount(tupled_parameter_shape); elements.reserve(num_elements); for (int64_t i = 0; i < num_elements; ++i) { @@ -1442,7 +1442,7 @@ static SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( // Check if we have callbacks registered for the given replica. if (replica >= options.send_callbacks.size()) { return [replica](int64_t channel_id, se::Stream*, const Shape&, - const se::DeviceMemoryBase&, + const se::DeviceAddressBase&, const absl::flat_hash_map&) { return Internal( "Don't send a buffer to the channel_id=%d, there was no send " @@ -1456,7 +1456,7 @@ static SendDeviceMemoryFunction ConvertSendCallbacksToSendFunction( return [callbacks, thread_pool]( int64_t channel_id, se::Stream* stream, const Shape& shape, - const se::DeviceMemoryBase& src, + const se::DeviceAddressBase& src, const absl::flat_hash_map&) -> absl::StatusOr>> { VLOG(3) << "Send " << src.size() << " bytes to channel #" << channel_id @@ -1525,7 +1525,7 @@ namespace { class StreamExecutorCopyToDeviceStream : public CopyToDeviceStream { public: StreamExecutorCopyToDeviceStream( - int64_t channel_id, se::Stream* stream, se::DeviceMemoryBase dst, + int64_t channel_id, se::Stream* stream, se::DeviceAddressBase dst, AsyncValueRef> done) : CopyToDeviceStream(dst.size(), /*granule_bytes=*/1), channel_id_(channel_id), @@ -1562,7 +1562,7 @@ class StreamExecutorCopyToDeviceStream : public CopyToDeviceStream { return Future<>(done_.GetError()); } - se::DeviceMemoryBase dst( + se::DeviceAddressBase dst( reinterpret_cast(dst_.opaque()) + current_bytes_, dst_.size() - current_bytes_); @@ -1602,7 +1602,7 @@ class StreamExecutorCopyToDeviceStream : public CopyToDeviceStream { private: int64_t channel_id_; se::Stream* stream_; - se::DeviceMemoryBase dst_; + se::DeviceAddressBase dst_; // Async value will become available after we'll submit the last memcpy // operation, and the event will be recorded on the stream. 
@@ -1615,7 +1615,7 @@ static RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( // Check if we have callbacks registered for the given replica. if (replica >= options.send_callbacks.size()) { return [replica](int64_t channel_id, se::Stream*, const Shape&, - se::DeviceMemoryBase*, + se::DeviceAddressBase*, const absl::flat_hash_map&) { return InvalidArgument( "Failed to receive a buffer from the channel_id=%d, there was no " @@ -1628,7 +1628,7 @@ static RecvDeviceMemoryFunction ConvertRecvCallbacksToRecvFunction( absl::Span callbacks = options.recv_callbacks[replica]; return [callbacks](int64_t channel_id, se::Stream* stream, const Shape& shape, - se::DeviceMemoryBase* dst, + se::DeviceAddressBase* dst, const absl::flat_hash_map&) -> absl::StatusOr>> { VLOG(3) << "Recv from channel #" << channel_id @@ -1691,7 +1691,7 @@ PjRtStreamExecutorClient::RunAsync( xla::ShapeTree> results( ssb.on_device_shape()); auto it = results.begin(); - se::DeviceMemoryAllocator* allocator = ssb.memory_allocator(); + se::DeviceAddressAllocator* allocator = ssb.memory_allocator(); ShapedBuffer released_ssb = ssb.release(); for (auto& buf : released_ssb.buffers()) { CHECK(it != results.end()); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 3e543724c182aa..4b656c48fc2517 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -66,7 +66,7 @@ limitations under the License. 
#include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/shape_tree.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/framework/allocator.h" @@ -237,7 +237,7 @@ class PjRtStreamExecutorClient : public CommonPjRtClient { std::vector> devices, int process_index, std::vector> memory_spaces, - std::unique_ptr allocator, + std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, std::unique_ptr gpu_run_options); @@ -340,7 +340,7 @@ class PjRtStreamExecutorClient : public CommonPjRtClient { ->local_device_state(); } LocalClient* client() const { return client_; } - se::DeviceMemoryAllocator* allocator() const { return allocator_; } + se::DeviceAddressAllocator* allocator() const { return allocator_; } tsl::Allocator* host_memory_allocator() const { return host_memory_allocator_.get(); } @@ -488,8 +488,8 @@ class PjRtStreamExecutorClient : public CommonPjRtClient { // Device memory allocator. If owned, the allocator must outlive the devices, // because it is the device destructor that waits for any outstanding work to // complete. - se::DeviceMemoryAllocator* allocator_; - std::unique_ptr owned_allocator_; + se::DeviceAddressAllocator* allocator_; + std::unique_ptr owned_allocator_; // Includes all devices, including non-local devices on multi-host platforms. std::vector> owned_devices_; diff --git a/third_party/xla/xla/pjrt/se_raw_buffer.cc b/third_party/xla/xla/pjrt/se_raw_buffer.cc index 1d5fc0516f7e10..4ba31cb16cb1d9 100644 --- a/third_party/xla/xla/pjrt/se_raw_buffer.cc +++ b/third_party/xla/xla/pjrt/se_raw_buffer.cc @@ -43,7 +43,7 @@ limitations under the License. 
#include "xla/primitive_util.h" #include "xla/service/generic_transfer_manager.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" @@ -136,7 +136,7 @@ PjRtStreamExecutorRawBuffer::CopyRawHostToDeviceAndReturnEvent( local_device = local_device_, stream, src, offset, transfer_size, buf = tsl::FormRef(this)]() mutable { - se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem(); + se::DeviceAddressBase sub_buffer = buf->device_buffer_->mem(); if (transfer_size < sub_buffer.size()) { sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size); } @@ -196,7 +196,7 @@ PjRtStreamExecutorRawBuffer::CopyRawDeviceToHostAndReturnEvent( local_device = local_device_, stream, dst, offset, transfer_size, buf = tsl::FormRef(this)]() mutable { - se::DeviceMemoryBase sub_buffer = buf->device_buffer_->mem(); + se::DeviceAddressBase sub_buffer = buf->device_buffer_->mem(); if (transfer_size < sub_buffer.size()) { sub_buffer = sub_buffer.GetByteSlice(offset, transfer_size); } @@ -248,7 +248,7 @@ ShapedBuffer PjRtStreamExecutorRawBuffer::AsShapedBuffer( auto* device = memory_space()->devices()[0]; ShapedBuffer shaped_buffer(shape, device->local_device_id().value(), device->local_hardware_id().value()); - ShapeTree::iterator iterator = + ShapeTree::iterator iterator = shaped_buffer.buffers().begin(); if (device_buffer_) { CHECK(iterator != shaped_buffer.buffers().end()); diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index f5595d2ea39040..31f668f7baf51d 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -42,8 +42,8 @@ limitations under the License. 
#include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_tree.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" @@ -57,7 +57,7 @@ ShapedBuffer RawSEDeviceMemory::AsShapedBuffer( PjRtDevice* device, const Shape& on_device_shape) const { ShapedBuffer shaped_buffer(on_device_shape, device->local_device_id().value(), device->local_hardware_id().value()); - ShapeTree::iterator iterator = + ShapeTree::iterator iterator = shaped_buffer.buffers().begin(); CHECK(iterator != shaped_buffer.buffers().end()); iterator->second = mem(); @@ -68,9 +68,9 @@ ShapedBuffer RawSEDeviceMemory::AsShapedBuffer( class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory { public: - AllocatedRawSEDeviceMemory(se::DeviceMemoryBase value, + AllocatedRawSEDeviceMemory(se::DeviceAddressBase value, LocalDeviceState* local_device, - se::DeviceMemoryAllocator* allocator) + se::DeviceAddressAllocator* allocator) : RawSEDeviceMemory(value), allocator_(allocator), local_device_(local_device) { @@ -103,21 +103,21 @@ class AllocatedRawSEDeviceMemory : public RawSEDeviceMemory { } private: - se::DeviceMemoryAllocator* allocator_; + se::DeviceAddressAllocator* allocator_; LocalDeviceState* local_device_; size_t sync_point_ = std::numeric_limits::max(); }; tsl::AsyncValueRef RawSEDeviceMemory::Create( - se::DeviceMemoryBase value, LocalDeviceState* local_device, - se::DeviceMemoryAllocator* allocator) { + se::DeviceAddressBase value, LocalDeviceState* local_device, + se::DeviceAddressAllocator* allocator) { return tsl::MakeAvailableAsyncValueRef( value, local_device, allocator); } class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { public: - 
ForeignRawSEDeviceMemory(se::DeviceMemoryBase value, + ForeignRawSEDeviceMemory(se::DeviceAddressBase value, absl::AnyInvocable on_delete_callback) : RawSEDeviceMemory(value), on_delete_callback_(std::move(on_delete_callback)) {} @@ -133,7 +133,7 @@ class ForeignRawSEDeviceMemory : public RawSEDeviceMemory { }; tsl::AsyncValueRef RawSEDeviceMemory::CreateForeign( - se::DeviceMemoryBase value, + se::DeviceAddressBase value, absl::AnyInvocable on_delete_callback) { return tsl::MakeAvailableAsyncValueRef( value, std::move(on_delete_callback)); diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.h b/third_party/xla/xla/pjrt/tracked_device_buffer.h index 62b36de4923881..7bce98bf6fa0a8 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.h +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.h @@ -43,8 +43,8 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/shape_tree.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/threadpool.h" @@ -53,11 +53,11 @@ namespace xla { class RawSEDeviceMemory { public: - explicit RawSEDeviceMemory(se::DeviceMemoryBase value) : value_(value) {} + explicit RawSEDeviceMemory(se::DeviceAddressBase value) : value_(value) {} virtual ~RawSEDeviceMemory() = default; - const se::DeviceMemoryBase& mem() const { return value_; } + const se::DeviceAddressBase& mem() const { return value_; } void* opaque() const { return value_.opaque(); } @@ -70,10 +70,10 @@ class RawSEDeviceMemory { const Shape& on_device_shape) const; static tsl::AsyncValueRef Create( - se::DeviceMemoryBase value, LocalDeviceState* local_device, - se::DeviceMemoryAllocator* allocator); + se::DeviceAddressBase value, 
LocalDeviceState* local_device, + se::DeviceAddressAllocator* allocator); static tsl::AsyncValueRef CreateForeign( - se::DeviceMemoryBase value, + se::DeviceAddressBase value, absl::AnyInvocable on_delete_callback); // Returns a definition event (or nullptr if the definition is known to be in @@ -84,7 +84,7 @@ class RawSEDeviceMemory { } private: - se::DeviceMemoryBase value_; + se::DeviceAddressBase value_; }; // Class that represents a tuple of device buffers. Like a ScopedShapedBuffer it @@ -124,7 +124,7 @@ class TrackedDeviceBuffer : public AbstractTrackedDeviceBuffer { ShapeTree::iterator* iterator, const ShapeTree::iterator& end, ExecutionInput* execution_input, - se::DeviceMemoryAllocator* allocator) const; + se::DeviceAddressAllocator* allocator) const; const absl::InlinedVector& definition_events() const { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index f4d2b8664df143..d5bec6ba286977 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/platform/statusor.h" @@ -114,7 +114,7 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto b_buffer, MakeArray(b_shape, client)); TF_ASSERT_OK_AND_ASSIGN(auto c_buffer, MakeArray(c_shape, client)); - std::vector expected_buffer_sequence = { + std::vector expected_buffer_sequence = { a_buffer->mem(), b_buffer->mem(), c_buffer->mem()}; ShapedBuffer shaped_a = a_buffer->AsShapedBuffer( &device, From 33c6f22e73c7290bc43d013033bfd539d98b56ba Mon Sep 17 00:00:00 2001 From: Haibo Huang Date: Tue, 9 Dec 2025 09:22:36 -0800 Subject: [PATCH 084/753] Fix use-after-free in PjRtCApiClient `BufferMemoryLayoutData` and `device_layout_list` should be alive until the return of the api call. 
PiperOrigin-RevId: 842271572 --- .../pjrt/c_api_client/pjrt_c_api_client.cc | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/pjrt/c_api_client/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/c_api_client/pjrt_c_api_client.cc index 3d579a45509817..a173d3c87e0674 100644 --- a/third_party/xla/xla/pjrt/c_api_client/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/c_api_client/pjrt_c_api_client.cc @@ -1381,30 +1381,33 @@ PjRtCApiClient::CreateBuffersForAsyncHostToDevice( PJRT_Client_CreateBuffersForAsyncHostToDevice_Args_STRUCT_SIZE; args.extension_start = nullptr; args.client = c_client_.get(); + args.num_shape_specs = shape_specs.size(); - args.shape_specs = new PJRT_ShapeSpec[shape_specs.size()]; - absl::Cleanup cleanup = - absl::MakeCleanup([&args] { delete[] args.shape_specs; }); - const ShapeSpec* iterator = shape_specs.begin(); - for (int i = 0; i < shape_specs.size(); ++i) { - args.shape_specs[i] = pjrt::ConvertToPjRtShapeSpec(*(iterator++)); + absl::InlinedVector c_shape_specs; + c_shape_specs.reserve(shape_specs.size()); + for (const ShapeSpec& shape_spec : shape_specs) { + c_shape_specs.push_back(pjrt::ConvertToPjRtShapeSpec(shape_spec)); } + args.shape_specs = c_shape_specs.data(); + + absl::InlinedVector layout_data_list; + absl::InlinedVector device_layout_list; if (device_layouts.has_value()) { args.num_device_layouts = device_layouts->size(); - auto device_layout_list = - std::make_unique>( - device_layouts->size()); + device_layout_list.reserve(device_layouts->size()); + layout_data_list.reserve(device_layouts->size()); for (int i = 0; i < device_layouts->size(); ++i) { if (device_layouts.has_value() && (*device_layouts)[i].has_value()) { const Layout& layout = (*device_layouts)[i].value(); TF_ASSIGN_OR_RETURN(pjrt::BufferMemoryLayoutData c_layout_data, pjrt::ConvertToBufferMemoryLayoutData(layout)); - device_layout_list->emplace_back(&(c_layout_data.c_layout)); + 
layout_data_list.push_back(std::move(c_layout_data)); + device_layout_list.emplace_back(&(layout_data_list.back().c_layout)); } else { - device_layout_list->emplace_back(nullptr); + device_layout_list.emplace_back(nullptr); } } - args.device_layouts = device_layout_list->data(); + args.device_layouts = device_layout_list.data(); } else { args.num_device_layouts = 0; args.device_layouts = nullptr; From 8d5f52bc3a386e1acc9f1d2753a291cf7199aade Mon Sep 17 00:00:00 2001 From: Mikhail Goncharov Date: Tue, 9 Dec 2025 09:53:15 -0800 Subject: [PATCH 085/753] [XLA:GPU] split dimensions greedly including ones in CalculateBitcastOfTransposeImpl We have not hanlded the case when bitcast introcduces a new 1-size dimension in case like a = f32[6,7] transpose(f32[7,6]), dims={1,0} b = f32[6,1,7] bitcast(a) as this 1-size dimension could teoretically be moved anywhere in the hosted expression c = f32[1,7,6] bitcase(f32[7,6]) # or c = f32[7,1,6] bitcase(f32[7,6]) # or c = f32[7,6,1] bitcase(f32[7,6]) b = f32[6,1,7] transpose(c), dims=... by using a "greedy" version of CommonFactors that does not produce grouping like [] -> [1] or [1] -> [] we now handle this case (picking group [6,1] -> [6] mapping). 
PiperOrigin-RevId: 842282725 --- .../xla/xla/service/gpu/transforms/BUILD | 5 +- .../gpu/transforms/nest_gemm_fusion.cc | 35 ++- .../service/gpu/transforms/nest_gemm_fusion.h | 19 +- .../gpu/transforms/nest_gemm_fusion_test.cc | 249 ++++++++++++------ 4 files changed, 218 insertions(+), 90 deletions(-) diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index e4ed891c6391b9..1acb680d9e73d7 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -1857,7 +1857,6 @@ cc_library( "//xla/codegen/tiling:symbolic_tile_analysis", "//xla/codegen/tiling:symbolic_tiled_hlo_instruction", "//xla/codegen/tiling:tiling_specification", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/simplifiers:hlo_dce", @@ -1898,18 +1897,18 @@ xla_cc_test( deps = [ ":nest_gemm_fusion", "//xla:xla_proto_cc", - "//xla/hlo/analysis:symbolic_expr", "//xla/hlo/ir:hlo", "//xla/hlo/testlib:filecheck", "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/testlib:pattern_matcher_gmock", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/stream_executor:device_description", "//xla/stream_executor/cuda:cuda_compute_capability", - "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc index dc972406913fc3..f9812d96ca0747 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -15,6 +15,7 @@ 
limitations under the License. #include "xla/service/gpu/transforms/nest_gemm_fusion.h" +#include #include #include #include @@ -658,8 +659,11 @@ absl::StatusOr CalculateBitcastOfTransposeImpl( // Maps logical operand dimension index to the physical dimension index. llvm::SmallVector operand_inv_layout = GetInversePermutation(operand_shape.layout().minor_to_major()); - auto factors = CommonFactors(GetPhysicalDimensions(result_shape), - GetPhysicalDimensions(transpose_shape)); + + const absl::InlinedVector, 8> factors = + ::xla::gpu::detail::CommonFactorsMergingTrivialRanges( + GetPhysicalDimensions(result_shape), + GetPhysicalDimensions(transpose_shape)); for (int64_t i = 1; i < factors.size(); ++i) { auto [result_from, transpose_from] = factors[i - 1]; auto [result_to, transpose_to] = factors[i]; @@ -1348,5 +1352,32 @@ absl::StatusOr FindBlockLevelParameters( "Couldn't find output tile sizes that satisfy ", tiled_dot.ToString())); } +absl::InlinedVector, 8> +CommonFactorsMergingTrivialRanges(absl::Span a, + absl::Span b) { + // CommonFactors does what we need but it also creates empty groups with + // product of 1, e.g. `[1] -> []` or `[] -> [1]`. We remove the bounds of + // such ranges to merge them with neighbors. There are many different ways + // to do this, here we continously append ranges to the start of the next + // group unless it is the very last range. + absl::InlinedVector, 8> bounds = + CommonFactors(a, b); + for (size_t i = 0; i + 1 < bounds.size() && bounds.size() > 2;) { + auto [a_start, b_start] = bounds[i]; + auto [a_end, b_end] = bounds[i + 1]; + if (a_start != a_end && b_start != b_end) { + i++; + continue; + } + if (i + 2 == bounds.size()) { + // Very last range - append it to the previous one. 
+ bounds.erase(bounds.begin() + i); + } else { + bounds.erase(bounds.begin() + i + 1); + } + } + return bounds; +} + } // namespace detail } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h index 2d94ad1c4417f6..bc1a54cfadd09a 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h @@ -16,11 +16,16 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ #define XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ +#include +#include + #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/hlo/analysis/symbolic_expr.h" -#include "xla/hlo/ir/hlo_instructions.h" +#include "absl/types/span.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/gpu/matmul_utils.h" @@ -81,6 +86,16 @@ absl::StatusOr FindBlockLevelParameters( mlir::MLIRContext* mlir_context, const se::DeviceDescription& device_description); +// Returns the start indices of consecutive non-overlapping subsequences of `a` +// and `b` with the same product (see `CommonFactors` from `util.h`) grouping +// ranges having product of 1 with neighbors. +// +// For example, if a=[2, 5, 1, 3] and b=[1, 10, 3, 1], the result will be +// {{0, 0}, {2, 2}, {4, 4}}, grouping [2,5] with [1,10] and [1,3] with [3,1]. 
+absl::InlinedVector, 8> +CommonFactorsMergingTrivialRanges(absl::Span a, + absl::Span b); + } // namespace detail } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc index b353c68c8b2d29..fd81e8ef39895b 100644 --- a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -15,31 +15,38 @@ limitations under the License. #include "xla/service/gpu/transforms/nest_gemm_fusion.h" +#include +#include #include +#include +#include #include #include +#include "absl/container/inlined_vector.h" #include "absl/log/log.h" #include "absl/status/status_matchers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "mlir/IR/MLIRContext.h" -#include "xla/hlo/analysis/symbolic_expr.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/pattern_matcher.h" #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" -#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" +using ::absl_testing::IsOkAndHolds; using ::testing::ElementsAre; namespace xla { @@ -105,7 +112,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + 
IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); const HloInstruction* fusion = nullptr; @@ -162,9 +169,9 @@ ENTRY e { "num_ctas":1}}} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); HloComputation* fusion_computation = module->entry_computation() ->root_instruction() @@ -283,7 +290,7 @@ ENTRY entry { absl::Substitute(hlo, HloOpcodeString(opcode)))); ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); const HloInstruction* fusion = nullptr; @@ -330,7 +337,7 @@ ENTRY entry { absl::Substitute(hlo, HloOpcodeString(opcode)))); ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); const HloInstruction* fusion = nullptr; @@ -375,9 +382,9 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -410,9 +417,9 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -444,15 +451,15 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( 
absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK: f16[3,11]{1,0} convert( CHECK: f16[3,11]{1,0} fusion( )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -490,9 +497,9 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -526,9 +533,9 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK: ENTRY @@ -536,7 +543,7 @@ ENTRY entry { CHECK: [[fusion:[^ ]+]] = s8[3,11]{1,0:E(4)} fusion({{.*}}) CHECK: ROOT {{.*}} = s8[33]{0:E(4)} bitcast([[fusion]]) )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -570,9 +577,9 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -609,9 +616,9 @@ ENTRY entry_computation { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( 
absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -645,9 +652,9 @@ ENTRY entry_computation { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -680,9 +687,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -712,9 +719,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -727,7 +734,7 @@ CHECK: ENTRY {{.*}} { CHECK: [[entry_p0:[^ ]+]] = f32[11,1,24,1]{3,2,1,0} parameter(0) CHECK: {{.*}} = f32[264]{0} bitcast([[entry_p0]]) )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -758,7 +765,7 @@ ENTRY e { ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); // We can nest the fusion including the broadcast. 
- EXPECT_TRUE(NestGemmFusion(device_description_, &mlir_context_) + ASSERT_TRUE(NestGemmFusion(device_description_, &mlir_context_) .Run(module.get()) .ok()); ASSERT_OK(verifier().Run(module.get()).status()); @@ -769,7 +776,7 @@ CHECK: f32[3,4,16]{2,1,0} broadcast CHECK-NEXT: f32[3,64]{1,0} $0 )", HloOpcodeString(opcode))), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -800,7 +807,7 @@ ENTRY e { ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); // We can nest the fusion including the broadcast. - EXPECT_TRUE(NestGemmFusion(device_description_, &mlir_context_) + ASSERT_TRUE(NestGemmFusion(device_description_, &mlir_context_) .Run(module.get()) .ok()); ASSERT_OK(verifier().Run(module.get()).status()); @@ -811,7 +818,7 @@ CHECK: f32[2,3,5]{2,1,0} $0 CHECK-NEXT: f32[2,4,3,5]{3,2,1,0} broadcast )", HloOpcodeString(opcode))), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -843,9 +850,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -854,7 +861,7 @@ CHECK-DAG: {{.*}} = f32[15,77]{1,0} broadcast([[p0]]), dimensions={0} CHECK-DAG: [[br:[^ ]+]] = f32[15]{0} broadcast([[p0]]), dimensions={0} CHECK-DAG: {{.*}} = f32[15,77]{1,0} broadcast([[br]]), dimensions={0} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, BitcastsAreHoistedOverBroadcasts) { @@ -883,9 +890,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( 
NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT(RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -899,7 +906,7 @@ CHECK: [[entry_p0:[^ ]+]] = f32[11,1,24,1]{3,2,1,0} parameter(0) CHECK: {{.*}} = f32[264]{0} bitcast([[entry_p0]]) )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, BitcastsLayoutIsPreserved) { @@ -934,9 +941,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT(RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), absl::Substitute(R"( @@ -953,7 +960,7 @@ CHECK: ENTRY {{.*}} { CHECK: {{.*}} = pred[122,5]{0,1} bitcast({{.*}}) )", HloOpcodeString(opcode))), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -985,16 +992,16 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK: bf16[1,2,4,8]{{.*}} broadcast({{.*}}), dimensions={0,3} CHECK: bf16[1,2,4,8]{{.*}} broadcast({{.*}}), dimensions={0,3} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, BitcastsAreHoistedUpThroughTransposes) { @@ -1021,9 +1028,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, 
ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -1031,7 +1038,45 @@ CHECK: ROOT transpose CHECK-SAME: f32[2,3,7]{2,1,0} transpose CHECK-SAME: dimensions={1,2,0} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); +} + +TEST_P(NestGemmFusionReshapeTest, + BitcastsWithSize1DimensionsAreHoistedUpThroughTransposes) { + const HloOpcode opcode = GetParam(); + absl::string_view hlo = R"( +triton_dot { + p0 = f32[7,6] parameter(0) + transpose = f32[6,7] transpose(p0), dimensions={1,0} + bitcast = f32[1,6,7] $0(transpose) + p1 = f32[1,5,7] parameter(1) + ROOT result = f32[1,6,5] dot(bitcast, p1), + lhs_contracting_dims={2}, lhs_batch_dims={0}, + rhs_contracting_dims={2}, rhs_batch_dims={0} +} + +ENTRY e { + p0 = f32[7,6] parameter(0) + p1 = f32[1,5,7] parameter(1) + ROOT result = f32[1,6,5] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: {"block_m":16,"block_n":16,"block_k":8, + "split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}}} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule( + absl::Substitute(hlo, HloOpcodeString(opcode)))); + ASSERT_THAT( + NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), + IsOkAndHolds(true)); + ASSERT_OK(verifier().Run(module.get()).status()); + EXPECT_THAT( + RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( +CHECK: ROOT transpose +CHECK-SAME: f32[1,6,7]{2,1,0} transpose +CHECK-SAME: dimensions={1,2,0} +)"), + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -1058,9 +1103,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, 
ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -1068,7 +1113,7 @@ CHECK: transpose CHECK-SAME: f32[3,2,7]{2,1,0} transpose CHECK-SAME: dimensions={2,0,1} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -1095,9 +1140,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT(RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), absl::Substitute(R"( @@ -1105,7 +1150,7 @@ CHECK: f32[2,3,5]{2,1,0} $0 CHECK-NEXT: f32[2,5,3]{2,1,0} transpose )", HloOpcodeString(opcode))), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -1134,9 +1179,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); // Checks that transpose is on rank 3 tensor from hoisting bitcast1, not rank // 4 tensor from hoisting bitcast0 first and then failing to hoist bitcast1. 
@@ -1146,7 +1191,7 @@ CHECK: transpose CHECK-SAME: f16[3,1152,122]{2,1,0} transpose CHECK-SAME: dimensions={0,2,1} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, BitcastsAreHoistedDownThroughTransposes) { @@ -1173,9 +1218,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -1183,7 +1228,7 @@ CHECK: ROOT transpose CHECK-SAME: f32[5,2,3]{2,1,0} transpose CHECK-SAME: dimensions={2,0,1} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, BitcastsAreHoistedDownThroughBroadcasts) { @@ -1209,9 +1254,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( @@ -1219,7 +1264,7 @@ CHECK: ROOT broadcast CHECK-SAME: f32[3,5,6,2]{2,1,0,3} broadcast CHECK-SAME: dimensions={0,1} )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -1246,9 +1291,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); 
EXPECT_THAT(RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), absl::Substitute(R"( @@ -1256,7 +1301,7 @@ CHECK: f32[2,3,5]{2,1,0} $0(dot) CHECK-NEXT: f32[2,3,5]{2,0,1} broadcast )", HloOpcodeString(opcode))), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, BitcastRootsAreHoistedDown) { @@ -1281,15 +1326,15 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK: ROOT dot )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -1318,15 +1363,15 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK: ROOT add = f32[3,5]{1,0} add )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); } TEST_P(NestGemmFusionReshapeTest, @@ -1359,9 +1404,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK-NOT: bitcast @@ -1376,7 +1421,7 @@ CHECK: f32[2,7]{1,0} bitcast(p0 CHECK: result = f32[2,7,15,11]{2,1,0,3} 
fusion CHECK: ROOT {{.*}} = f32[15,11,14]{0,2,1} bitcast(result) )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -1410,9 +1455,9 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK-NOT: bitcast @@ -1427,7 +1472,7 @@ CHECK: f32[7,3,2]{2,0,1} bitcast(p0 CHECK: result = f32[3,5,2]{2,1,0} fusion CHECK: ROOT {{.*}} = f32[2,3,5]{0,2,1} bitcast(result) )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -1458,16 +1503,16 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( absl::Substitute(hlo, HloOpcodeString(opcode)))); - EXPECT_THAT( + ASSERT_THAT( NestGemmFusion(device_description_, &mlir_context_).Run(module.get()), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); EXPECT_THAT( RunFileCheck(module->ToString(HloPrintOptions::ShortParsable()), R"( CHECK-NOT: bitcast CHECK-NOT: reshape CHECK: ENTRY )"), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -1509,7 +1554,7 @@ CHECK-NOT: bitcast CHECK-NOT: reshape )", HloOpcodeString(opcode))), - absl_testing::IsOkAndHolds(true)); + IsOkAndHolds(true)); ASSERT_OK(verifier().Run(module.get()).status()); } @@ -1521,6 +1566,44 @@ INSTANTIATE_TEST_SUITE_P(NestGemmFusionReshapeTestSuite, return std::string(HloOpcodeString(info.param)); }); +struct CommonFactorsTestCase { + std::vector from, to; + absl::InlinedVector, 8> expected; +}; + +class CommonFactorsMergingTrivialRangesTest + : public ::testing::TestWithParam {}; + +TEST_P(CommonFactorsMergingTrivialRangesTest, 
Example) { + const CommonFactorsTestCase& test_case = GetParam(); + EXPECT_EQ(test_case.expected, detail::CommonFactorsMergingTrivialRanges( + test_case.from, test_case.to)); +} + +INSTANTIATE_TEST_SUITE_P( + CommonFactorsMergingTrivialRangesTestSuite, + CommonFactorsMergingTrivialRangesTest, + ::testing::Values( + CommonFactorsTestCase{{1}, {}, {{0, 0}, {1, 0}}}, + CommonFactorsTestCase{{}, {1}, {{0, 0}, {0, 1}}}, + CommonFactorsTestCase{{}, {}, {{0, 0}}}, + CommonFactorsTestCase{{1, 2, 0}, {2, 0, 3}, {{0, 0}, {3, 3}}}, + CommonFactorsTestCase{{2, 3, 0}, {1, 0, 1000}, {{0, 0}, {3, 3}}}, + CommonFactorsTestCase{{1, 1, 1}, {1, 1}, {{0, 0}, {1, 1}, {3, 2}}}, + CommonFactorsTestCase{{1, 1, 3}, {3, 1, 1}, {{0, 0}, {3, 3}}}, + CommonFactorsTestCase{{2, 6}, {4, 3}, {{0, 0}, {2, 2}}}, + CommonFactorsTestCase{{1, 2, 6}, {4, 1, 3, 1}, {{0, 0}, {3, 4}}}, + CommonFactorsTestCase{{2, 3, 4, 5}, {6, 20}, {{0, 0}, {2, 1}, {4, 2}}}, + CommonFactorsTestCase{ + {2, 3, 4, 5, 6}, {6, 20, 6}, {{0, 0}, {2, 1}, {4, 2}, {5, 3}}}, + CommonFactorsTestCase{{2, 2, 2, 2}, {4, 4}, {{0, 0}, {2, 1}, {4, 2}}}, + CommonFactorsTestCase{ + {2, 5, 1, 3}, {1, 10, 3, 1}, {{0, 0}, {2, 2}, {4, 4}}}), + [](const ::testing::TestParamInfo& info) { + return absl::StrCat(absl::StrJoin(info.param.from, "_"), "_to_", + absl::StrJoin(info.param.to, "_")); + }); + } // namespace } // namespace gpu } // namespace xla From 19a899a5b8475fa9d72fbc16b09c5a90381d7c18 Mon Sep 17 00:00:00 2001 From: Marcin Radomski Date: Tue, 9 Dec 2025 10:07:23 -0800 Subject: [PATCH 086/753] [XLA:GPU] Make float check more parallel The goal is to count NaNs/infs/zeros in a buffer of floats, and append the results to a BufferDebugLog stored in device memory. This used to be done on a single thread block with poor performance. This CL changes it to a 2-step process: 1. Do partial accumulation into a temporary buffer. 2. Use a second kernel to reduce partial results down into scalars and append them to the log. 
This also includes some optimizations suggested by gflegar: * Use array-of-structs over struct-of-arrays for __shared__ memory in step 1 * Always use 1024 threads per block to avoid switching at kernel runtime * Read global memory 128bits a time PiperOrigin-RevId: 842289188 --- .../xla/xla/backends/gpu/runtime/BUILD | 7 + .../gpu/runtime/buffer_debug_log_structs.h | 19 + .../gpu/runtime/buffers_float_check_thunk.cc | 65 +++- .../gpu/runtime/buffers_float_check_thunk.h | 13 +- .../runtime/buffers_float_check_thunk_test.cc | 29 +- .../runtime/thunk_buffer_debug_float_check.cc | 56 ++- .../runtime/thunk_buffer_debug_pass_test.cc | 22 +- .../xla/xla/stream_executor/cuda/BUILD | 3 + ...buffer_debug_float_check_kernel_cuda.cu.cc | 342 ++++++++++-------- ...ffer_debug_float_check_kernel_cuda_test.cc | 135 +++++-- .../gpu/buffer_debug_float_check_kernel.h | 29 +- 11 files changed, 495 insertions(+), 225 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/runtime/BUILD b/third_party/xla/xla/backends/gpu/runtime/BUILD index 856ea1bc586623..27b4149b9e25bd 100644 --- a/third_party/xla/xla/backends/gpu/runtime/BUILD +++ b/third_party/xla/xla/backends/gpu/runtime/BUILD @@ -3259,6 +3259,7 @@ cc_library( ":thunk_id", ":thunk_pass_pipeline", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/gpu:ffi", @@ -3287,6 +3288,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_googlesource_code_re2//:re2", + "@eigen_archive//:eigen3", ], ) @@ -3463,9 +3465,13 @@ cc_library( ":buffer_debug_log_structs", ":thunk", "//xla:types", + "//xla:util", "//xla/service:buffer_assignment", "//xla/stream_executor:device_address", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_compute_capability", 
"//xla/stream_executor/cuda:cuda_platform_id", @@ -3477,6 +3483,7 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", diff --git a/third_party/xla/xla/backends/gpu/runtime/buffer_debug_log_structs.h b/third_party/xla/xla/backends/gpu/runtime/buffer_debug_log_structs.h index 9ff067c00b633d..9d3492ae964f33 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffer_debug_log_structs.h +++ b/third_party/xla/xla/backends/gpu/runtime/buffer_debug_log_structs.h @@ -54,6 +54,25 @@ static_assert(sizeof(BufferDebugLogEntry) == sizeof(uint32_t) * 2); static_assert(offsetof(BufferDebugLogEntry, entry_id) == 0); static_assert(offsetof(BufferDebugLogEntry, value) == sizeof(uint32_t)); +struct FloatCheckResult { + uint32_t nan_count; + uint32_t inf_count; + uint32_t zero_count; + + template + friend void AbslStringify(Sink& sink, const FloatCheckResult& result) { + absl::Format(&sink, "{nan_count: %u, inf_count: %u, zero_count: %u}", + result.nan_count, result.inf_count, result.zero_count); + } +}; + +// The struct layout must match on both host and device. +static_assert(_Alignof(FloatCheckResult) == _Alignof(uint32_t)); +static_assert(sizeof(FloatCheckResult) == sizeof(uint32_t) * 3); +static_assert(offsetof(FloatCheckResult, nan_count) == 0); +static_assert(offsetof(FloatCheckResult, inf_count) == sizeof(uint32_t)); +static_assert(offsetof(FloatCheckResult, zero_count) == sizeof(uint32_t) * 2); + struct BufferDebugFloatCheckEntry { // An ID that uniquely identifies a log entry within a HLO module execution. 
BufferDebugLogEntryId entry_id; diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc index 6ff174e2a418d2..d6b8b04c70c47a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.cc @@ -15,11 +15,15 @@ limitations under the License. #include "xla/backends/gpu/runtime/buffers_float_check_thunk.h" +#include +#include #include #include #include +#include #include +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -30,14 +34,18 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/buffer_debug_float_check_kernel.h" #include "xla/stream_executor/gpu/buffer_debug_log.h" #include "xla/stream_executor/gpu/gpu_kernel_registry.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" #include "xla/types.h" +#include "xla/util.h" namespace xla::gpu { @@ -73,15 +81,33 @@ absl::Status BuffersDebugFloatCheckThunk::Initialize( auto kernel_bf16, registry.LoadKernel( params.executor)); + TF_ASSIGN_OR_RETURN( + auto kernel_reduce, + registry.LoadKernel< + se::gpu::BufferDebugAppendReducedFloatCheckResultsKernel>( + params.executor)); kernels_[params.executor] = std::make_unique( - Kernels{std::move(kernel_f32), std::move(kernel_bf16)}); + Kernels{std::move(kernel_f32), std::move(kernel_bf16), + std::move(kernel_reduce)}); + VLOG(1) << "NanCount kernels loaded"; } } - VLOG(1) << 
"FloatCheck kernel loaded"; return absl::OkStatus(); } +template +se::BlockDim GetBlockDimForBuffer(se::Stream* stream, + se::DeviceMemory buffer, + int64_t max_blocks) { + const int64_t num_elements = buffer.size() / sizeof(T); + const se::DeviceDescription& desc = stream->parent()->GetDeviceDescription(); + const int64_t num_blocks = + std::min(xla::CeilOfRatio(num_elements, desc.threads_per_block_limit()), + max_blocks); + return se::BlockDim(num_blocks); +} + absl::Status BuffersDebugFloatCheckThunk::ExecuteOnStream( const ExecuteParams& params) { se::StreamExecutor* executor = params.stream->parent(); @@ -102,8 +128,13 @@ absl::Status BuffersDebugFloatCheckThunk::ExecuteOnStream( VLOG(1) << "BuffersDebugFloatCheckThunk::ExecuteOnStream"; - const se::ThreadDim thread_dim( - executor->GetDeviceDescription().threads_per_block_limit(), 1, 1); + se::DeviceAddress tmp_ptr( + params.buffer_allocations->GetDeviceAddress(tmp_slice_)); + const size_t tmp_size_elements = + tmp_slice_.size() / sizeof(xla::gpu::FloatCheckResult); + CHECK_GT(tmp_size_elements, 0) + << "tmp_slice_ is too small to hold any results, this should have been " + "caught during initialization"; se::DeviceAddress log_ptr( params.buffer_allocations->GetDeviceAddress(log_slice_)); @@ -111,6 +142,8 @@ absl::Status BuffersDebugFloatCheckThunk::ExecuteOnStream( se::gpu::BufferDebugLog< BufferDebugFloatCheckEntry>::FromDeviceAddressUnchecked(log_ptr); const uint32_t execution_id = execution_count_.fetch_add(1); + // The kernel assumes 1024 threads per block. 
+ const se::ThreadDim thread_dim(1024); for (const auto& [buffer_idx, buffer] : checked_thunk_buffers_) { BufferDebugLogEntryMetadataStore::Metadata metadata{ @@ -130,22 +163,32 @@ absl::Status BuffersDebugFloatCheckThunk::ExecuteOnStream( VLOG(1) << "F32 buffer detected with id: " << entry_id << " and size: " << device_buffer.size(); se::DeviceAddress f32_buffer(device_buffer); - TF_RETURN_IF_ERROR(kernels->f32.Launch( - thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id, - f32_buffer, f32_buffer.size(), buffer_debug_log.GetDeviceHeader(), - buffer_debug_log.GetDeviceEntries())); + const se::BlockDim block_dim = GetBlockDimForBuffer( + params.stream, f32_buffer, tmp_size_elements); + TF_RETURN_IF_ERROR( + kernels->f32.Launch(thread_dim, block_dim, params.stream, f32_buffer, + f32_buffer.size(), tmp_ptr, tmp_size_elements)); } else if (buffer_type == PrimitiveType::BF16) { VLOG(1) << "BF16 buffer detected with id: " << entry_id << " and size: " << device_buffer.size(); se::DeviceAddress bf16_buffer(device_buffer); + const se::BlockDim block_dim = GetBlockDimForBuffer( + params.stream, bf16_buffer, tmp_size_elements); TF_RETURN_IF_ERROR(kernels->bf16.Launch( - thread_dim, se::BlockDim(1, 1, 1), params.stream, entry_id, - bf16_buffer, bf16_buffer.size(), buffer_debug_log.GetDeviceHeader(), - buffer_debug_log.GetDeviceEntries())); + thread_dim, block_dim, params.stream, bf16_buffer, bf16_buffer.size(), + tmp_ptr, tmp_size_elements)); } else { VLOG(1) << "Unsupported primitive type for float checking: " << PrimitiveType_Name(buffer_type); + continue; } + + // Operations on the same stream perform in sequence, so at this point the + // results of the previous FloatCheck operation are available. 
+ TF_RETURN_IF_ERROR(kernels->reduce.Launch( + thread_dim, se::BlockDim(1, 1, 1), params.stream, tmp_ptr, + tmp_size_elements, entry_id, buffer_debug_log.GetDeviceHeader(), + buffer_debug_log.GetDeviceEntries())); } return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.h b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.h index 5d2f78e80edb99..f73c9ef305fde6 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.h +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -38,18 +39,21 @@ class BuffersDebugFloatCheckThunk : public Thunk { public: explicit BuffersDebugFloatCheckThunk( ThunkInfo info, const ThunkInfo& checked_thunk_info, - BufferAllocation::Slice log_slice, + BufferAllocation::Slice log_slice, BufferAllocation::Slice tmp_slice, absl::flat_hash_map checked_thunk_buffers, std::shared_ptr metadata_store) : Thunk(Thunk::Kind::kBuffersDebugFloatCheck, std::move(info)), log_slice_(log_slice), + tmp_slice_(tmp_slice), checked_thunk_info_(checked_thunk_info), checked_thunk_buffers_(std::move(checked_thunk_buffers)), metadata_store_(std::move(metadata_store)) {} - absl::Status Initialize(const InitializeParams& params) override; - absl::Status ExecuteOnStream(const ExecuteParams& params) override; + absl::Status Initialize(const InitializeParams& params) override + ABSL_LOCKS_EXCLUDED(kernels_mutex_); + absl::Status ExecuteOnStream(const ExecuteParams& params) override + ABSL_LOCKS_EXCLUDED(kernels_mutex_); std::string ToString(int indent) const override; @@ -67,6 +71,8 @@ class BuffersDebugFloatCheckThunk : public Thunk { struct Kernels { stream_executor::gpu::BufferDebugFloatCheckF32Kernel::KernelType f32; stream_executor::gpu::BufferDebugFloatCheckBf16Kernel::KernelType bf16; + 
stream_executor::gpu::BufferDebugAppendReducedFloatCheckResultsKernel:: + KernelType reduce; }; absl::Mutex kernels_mutex_; // Each loaded kernel is associated with a specific device (represented by its @@ -79,6 +85,7 @@ class BuffersDebugFloatCheckThunk : public Thunk { kernels_ ABSL_GUARDED_BY(kernels_mutex_); BufferAllocation::Slice log_slice_; + BufferAllocation::Slice tmp_slice_; ThunkInfo checked_thunk_info_; absl::flat_hash_map checked_thunk_buffers_; std::shared_ptr metadata_store_; diff --git a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc index dfb933bce2a4ee..c56538c15a6e39 100644 --- a/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/buffers_float_check_thunk_test.cc @@ -101,17 +101,20 @@ class BuffersDebugFloatCheckThunkTest : public ::testing::Test { TEST_F(BuffersDebugFloatCheckThunkTest, CalculatesNanCounts) { static constexpr size_t kLogSize = BufferDebugLog::RequiredSizeForEntries(10); + static constexpr size_t kTmpSizeElems = 1024; + static constexpr size_t kTmpSizeBytes = kTmpSizeElems * sizeof(uint32_t); static constexpr size_t kInputElems = 1024; static constexpr size_t kInputSizeInBytes = kInputElems * sizeof(float); static constexpr size_t kTotalDeviceMemoryBytes = - kLogSize + kInputSizeInBytes * 2; + kLogSize + kTmpSizeBytes + kInputSizeInBytes * 2; // Setup memory allocations for the log and inputs BufferAllocation alloc(/*index=*/0, /*size=*/kTotalDeviceMemoryBytes, /*color=*/0); int64_t input_offset = kLogSize; BufferAllocation::Slice log_slice(&alloc, /*offset=*/0, kLogSize); - input_offset += kLogSize; + BufferAllocation::Slice tmp_slice(&alloc, /*offset=*/kLogSize, kTmpSizeBytes); + input_offset += kLogSize + kTmpSizeBytes; BufferAllocation::Slice inputs[2]; int64_t input_size_bf16 = kInputElems * sizeof(Eigen::bfloat16); @@ -159,7 +162,7 @@ 
TEST_F(BuffersDebugFloatCheckThunkTest, CalculatesNanCounts) { Thunk::ThunkInfo checked_thunk_info; checked_thunk_info.thunk_id = ThunkId(123); BuffersDebugFloatCheckThunk thunk( - Thunk::ThunkInfo(), checked_thunk_info, log_slice, + Thunk::ThunkInfo(), checked_thunk_info, log_slice, tmp_slice, {{/*buffer_idx=*/0, inputs[0]}, {/*buffer_idx=*/1, inputs[1]}}, metadata_store); TF_ASSERT_OK(thunk.Initialize(init_params)); @@ -202,8 +205,13 @@ TEST_F(BuffersDebugFloatCheckThunkTest, GTEST_SKIP() << "need at least 2 devices for this test"; } + static constexpr size_t kLogOffset = 0; static constexpr size_t kLogSizeBytes = 1024; + static constexpr size_t kTmpOffset = kLogOffset + kLogSizeBytes; + static constexpr size_t kTmpSizeBytes = 1024 * sizeof(uint32_t); + static constexpr size_t kInputOffset = kTmpOffset + kTmpSizeBytes; static constexpr size_t kInputSizeBytes = 1024; + static constexpr size_t kTotalDeviceMemory = kInputOffset + kInputSizeBytes; struct TestDevice { se::StreamExecutor* executor; @@ -219,7 +227,7 @@ TEST_F(BuffersDebugFloatCheckThunkTest, auto allocator = std::make_unique(executor); BufferAllocations allocations( - {executor->AllocateArray(kLogSizeBytes + kInputSizeBytes)}, + {executor->AllocateArray(kTotalDeviceMemory)}, executor->device_ordinal(), allocator.get()); return TestDevice{std::move(executor), std::move(stream), @@ -227,16 +235,17 @@ TEST_F(BuffersDebugFloatCheckThunkTest, }; TF_ASSERT_OK_AND_ASSIGN(TestDevice device0, setup_device(0)); TF_ASSERT_OK_AND_ASSIGN(TestDevice device1, setup_device(1)); - BufferAllocation allocation(0, kLogSizeBytes + kInputSizeBytes, 0); - BufferAllocation::Slice log_slice(&allocation, 0, kLogSizeBytes); - BufferAllocation::Slice f32_slice(&allocation, kLogSizeBytes, kInputSizeBytes, + BufferAllocation allocation(/*index=*/0, kTotalDeviceMemory, /*color=*/0); + BufferAllocation::Slice log_slice(&allocation, kLogOffset, kLogSizeBytes); + BufferAllocation::Slice tmp_slice(&allocation, kTmpOffset, kTmpSizeBytes); 
+ BufferAllocation::Slice f32_slice(&allocation, kInputOffset, kInputSizeBytes, PrimitiveType::F32); - BufferAllocation::Slice bf16_slice(&allocation, kLogSizeBytes, - kInputSizeBytes, PrimitiveType::BF16); + BufferAllocation::Slice bf16_slice(&allocation, kInputOffset, kInputSizeBytes, + PrimitiveType::BF16); Thunk::ThunkInfo checked_thunk_info; checked_thunk_info.thunk_id = ThunkId(123); BuffersDebugFloatCheckThunk thunk( - Thunk::ThunkInfo(), checked_thunk_info, log_slice, + Thunk::ThunkInfo(), checked_thunk_info, log_slice, tmp_slice, {{/*buffer_idx=*/0, f32_slice}, {/*buffer_idx=*/1, bf16_slice}}, std::make_shared()); diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_float_check.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_float_check.cc index 8241084f52831c..808142a0e2a45c 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_float_check.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_float_check.cc @@ -15,7 +15,10 @@ limitations under the License. #include "xla/backends/gpu/runtime/thunk_buffer_debug_float_check.h" +#include +#include #include +#include #include #include #include @@ -32,6 +35,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "Eigen/Core" #include "xla/backends/gpu/ffi.h" #include "xla/backends/gpu/runtime/buffer_debug_log.pb.h" #include "xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h" @@ -59,6 +63,7 @@ limitations under the License. 
#include "xla/stream_executor/stream.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" +#include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -72,10 +77,41 @@ constexpr size_t kLogSizeBytes = 64 * 1024; namespace { -std::unique_ptr WrapWithFloatCheckThunk( +size_t CalculateTempBufferSize(const Thunk& thunk) { + size_t max_buffer_size_bytes = 0; + for (const BufferUse& use : thunk.buffer_uses()) { + if (use.HasDefinedContentsOnInput() || use.HasDefinedContentsOnOutput()) { + max_buffer_size_bytes = + std::max(max_buffer_size_bytes, use.slice().size()); + } + } + + // We're doing the float checks in 2 steps: + // - parallel aggregation: one thread block writes partial result into the + // temp buffer. The number of thread blocks used will be limited by the size + // calculated here. + // - reduction of the temp buffer on a single thread block + // To optimize for time, we want to do as much computation in parallel as we + // can, but also consider the overhead of single-block reduction step. + + // Avoid making the reduction step use less than a block's worth of data. We + // can't go any faster than that anyway. + static constexpr size_t kMinElements = 1024; + // Arbitrary limit of 1Mi elements. This should be enough to accommodate the + // max number of thread blocks available on any supported GPU. 
+ static constexpr size_t kMaxElements = 1024 * 1024; + const size_t size_elems = + xla::CeilOfRatio(max_buffer_size_bytes, sizeof(uint32_t)); + const size_t sqrt_size_elems = std::sqrt(size_elems); + return std::clamp(xla::CeilOfRatio(size_elems, sqrt_size_elems), kMinElements, + kMaxElements); +} + +absl::StatusOr> WrapWithFloatCheckThunk( std::unique_ptr thunk, BufferAllocation::Slice log_slice, const Thunk& predecessor_thunk, Thunk& successor_thunk, - std::shared_ptr metadata_store) { + std::shared_ptr metadata_store, + ThunkPassBufferAllocator& allocator) { const auto& thunk_buffers = thunk->buffer_uses(); if (thunk_buffers.empty()) { VLOG(1) << "No buffers in thunk " << thunk->thunk_info().thunk_id @@ -120,6 +156,12 @@ std::unique_ptr WrapWithFloatCheckThunk( return thunk; } + const size_t temp_buffer_size_bytes = + CalculateTempBufferSize(*thunk) * sizeof(xla::gpu::FloatCheckResult); + TF_ASSIGN_OR_RETURN(BufferAllocation * tmp_alloc, + allocator.NewEmptyAllocation(temp_buffer_size_bytes)); + BufferAllocation::Slice tmp_slice(tmp_alloc, 0, tmp_alloc->size()); + VLOG(1) << "Wrapping thunk " << thunk->thunk_info().thunk_id << " with float check thunk due to presence of buffers: " << buffers_to_check.size(); @@ -128,7 +170,7 @@ std::unique_ptr WrapWithFloatCheckThunk( thunk_and_checks.push_back(std::move(thunk)); auto buffer_debug_float_check_thunk = std::make_unique( - Thunk::ThunkInfo(), thunk_ptr->thunk_info(), log_slice, + Thunk::ThunkInfo(), thunk_ptr->thunk_info(), log_slice, tmp_slice, std::move(buffers_to_check), std::move(metadata_store)); buffer_debug_float_check_thunk->add_control_predecessor(thunk_ptr); thunk_and_checks.push_back(std::move(buffer_debug_float_check_thunk)); @@ -329,8 +371,9 @@ absl::Status RunFloatCheckPassInternal(SequentialThunk* root_thunk, CreateBufferDebugFloatCheckThunk(metadata_store, log_slice, hlo_module)); ThunkFilter thunk_filter = CreateThunkFilter(debug_options); - TF_RETURN_IF_ERROR( - 
root_thunk->TransformAllNestedThunks([&](std::unique_ptr thunk) { + TF_RETURN_IF_ERROR(root_thunk->TransformAllNestedThunks( + [&](std::unique_ptr thunk) + -> absl::StatusOr> { if (thunk_filter(*thunk) == InstrumentAction::kSkip) { return thunk; } @@ -338,7 +381,8 @@ absl::Status RunFloatCheckPassInternal(SequentialThunk* root_thunk, return WrapWithFloatCheckThunk( std::move(thunk), log_slice, /*predecessor_thunk=*/*buffer_debug_init_thunk, - /*successor_thunk=*/*buffer_debug_dump_thunk, metadata_store); + /*successor_thunk=*/*buffer_debug_dump_thunk, metadata_store, + allocator); })); ThunkSequence& thunks = root_thunk->thunks(); diff --git a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc index cb4e449e4a8bb1..62bc737997450a 100644 --- a/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/thunk_buffer_debug_pass_test.cc @@ -58,6 +58,7 @@ namespace { using testing::ElementsAre; using testing::Eq; +using testing::IsEmpty; using testing::Pair; using testing::Pointer; using testing::SizeIs; @@ -102,17 +103,16 @@ using SliceList = class FakeThunkPassBufferAllocator : public ThunkPassBufferAllocator { public: absl::StatusOr NewEmptyAllocation(int64_t size) override { - if (CreatedAlloc()) { - return absl::InvalidArgumentError("Expected only one allocation"); - } - alloc_ = std::make_unique(0, size, 0); - return alloc_.get(); + allocs_.push_back(std::make_unique(0, size, 0)); + return allocs_.back().get(); } - bool CreatedAlloc() { return alloc_ != nullptr; } + const std::vector>& allocs() const { + return allocs_; + } private: - std::unique_ptr alloc_; + std::vector> allocs_; }; class FakeThunk : public Thunk { @@ -188,6 +188,7 @@ TEST_F(ThunkBufferDebugPassTest, IsNoOpWhenHloModuleIsNull) { /*hlo_module=*/nullptr, device_info, allocator)); EXPECT_FALSE(changed); EXPECT_THAT(root_thunk->thunks(), 
ElementsAre(Pointer(fake_thunk_ptr))); + EXPECT_THAT(allocator.allocs(), IsEmpty()); } TEST_F(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) { @@ -256,6 +257,8 @@ TEST_F(ThunkBufferDebugPassTest, InsertsBuffersDebugChecksumThunks) { {2, slice_io}, }))), IsCustomCallThunkWithTargetName("xla_gpu_buffer_debug_log_dump"))); + + EXPECT_THAT(allocator.allocs(), SizeIs(1)); } TEST_F(ThunkBufferDebugPassTest, RecursivelyInsertsBuffersDebugChecksumThunks) { @@ -461,6 +464,8 @@ TEST_F(ThunkBufferDebugPassTest, RecursivelyInsertsBuffersDebugChecksumThunks) { Pointer(branch1_thunk_ptr), IsChecksumThunkChecking(SliceList{{0, slice_branch1}}))); } + + EXPECT_THAT(allocator.allocs(), SizeIs(1)); } TEST_F(ThunkBufferDebugPassTest, InsertsBuffersDebugFloatCheckThunks) { @@ -544,6 +549,9 @@ TEST_F(ThunkBufferDebugPassTest, InsertsBuffersDebugFloatCheckThunks) { static_cast(*sub_thunks[1]); EXPECT_THAT(buffer_debug_after_fake_thunk.buffer_slices(), UnorderedElementsAre(Pair(1, slice_o), Pair(2, slice_io))); + + // 1 for the log buffer, 1 per wrapped thunk for the temp buffer + EXPECT_THAT(allocator.allocs(), SizeIs(2)); } TEST_F(ThunkBufferDebugPassTest, BufferSaverInserter) { diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index ba0403fbed0832..1342a13898784c 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -453,7 +453,9 @@ cuda_library( "gpu", ], deps = [ + ":cuda_platform", ":cuda_platform_id", + "//xla:util", "//xla/backends/gpu/runtime:buffer_debug_log_structs", "//xla/stream_executor:kernel_spec", "//xla/stream_executor/gpu:buffer_debug_float_check_kernel", @@ -475,6 +477,7 @@ xla_test( "//xla/backends/gpu/runtime:buffer_debug_log_structs", "//xla/backends/gpu/runtime:thunk_id", "//xla/stream_executor:device_address", + "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", 
"//xla/stream_executor:platform_manager", diff --git a/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda.cu.cc index 2325478c1256fb..4f6e94ab7ccce2 100644 --- a/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda.cu.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include +#include #include +#include +#include #include "absl/base/casts.h" #include "third_party/gpus/cuda/include/cuda/atomic" @@ -24,11 +29,29 @@ limitations under the License. #include "xla/stream_executor/gpu/buffer_debug_float_check_kernel.h" #include "xla/stream_executor/gpu/gpu_kernel_registry.h" #include "xla/stream_executor/kernel_spec.h" +#include "xla/util.h" namespace se = stream_executor; namespace { +using xla::gpu::FloatCheckResult; + +// https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/: +// > CUDA architecture limits the numbers of threads per block (1024 threads +// > per block limit). +static constexpr uint64_t kBlockSize = 1024; +// warpSize is not a compile time constant on all OSS CI builds, but we need it +// to be one for static array initialization. We assert this value matches +// warpSize at runtime. 
+static constexpr uint64_t kWarpSize = 32; +static constexpr uint64_t kMaxWarpsPerBlock = kBlockSize / kWarpSize; +template +static constexpr uint64_t kElementsPerMemoryAccess = + std::max(16 / sizeof(T), 1); +template +using Chunk = std::array>; + __device__ unsigned int ThreadIdx() { return threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x; @@ -39,16 +62,57 @@ __device__ unsigned int BlockIdx() { blockIdx.x; } -// Based on -// https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf -template -__device__ void WarpReduceSum(unsigned int tid, volatile uint32_t* data) { - if (BLOCK_SIZE >= 64) data[tid] += data[tid + 32]; - if (BLOCK_SIZE >= 32) data[tid] += data[tid + 16]; - if (BLOCK_SIZE >= 16) data[tid] += data[tid + 8]; - if (BLOCK_SIZE >= 8) data[tid] += data[tid + 4]; - if (BLOCK_SIZE >= 4) data[tid] += data[tid + 2]; - if (BLOCK_SIZE >= 2) data[tid] += data[tid + 1]; +// Reduce a warp worth of values into a single one and have the 0th thread in +// the warp return it. +__device__ uint32_t WarpReduceSum(uint32_t value) { + static constexpr uint32_t kFullMask = ~0; + for (unsigned int offset = 1; offset < kWarpSize; offset <<= 1) { + value += __shfl_down_sync(kFullMask, value, offset); + } + return value; +} + +// Sum up a block worth of FloatCheckResults into a single one and have the 0th +// thread in the block return it. +__device__ FloatCheckResult BlockReduceSum(uint32_t tid, + FloatCheckResult value) { + assert(kWarpSize == warpSize); + static_assert(kBlockSize == kWarpSize * kMaxWarpsPerBlock); + // Required to do the second warp reduction. 
+ static_assert(kMaxWarpsPerBlock == kWarpSize); + + const size_t warp_idx = tid / kWarpSize; + const size_t lane_idx = tid % kWarpSize; + + value.nan_count = WarpReduceSum(value.nan_count); + value.inf_count = WarpReduceSum(value.inf_count); + value.zero_count = WarpReduceSum(value.zero_count); + + __shared__ uint32_t scratch_nan[kMaxWarpsPerBlock]; + __shared__ uint32_t scratch_inf[kMaxWarpsPerBlock]; + __shared__ uint32_t scratch_zero[kMaxWarpsPerBlock]; + if (lane_idx == 0) { + scratch_nan[warp_idx] = value.nan_count; + scratch_inf[warp_idx] = value.inf_count; + scratch_zero[warp_idx] = value.zero_count; + } + + __syncthreads(); + // The first warp reduces the results from all warps. + if (warp_idx == 0) { + value.nan_count = scratch_nan[lane_idx]; + value.inf_count = scratch_inf[lane_idx]; + value.zero_count = scratch_zero[lane_idx]; + value.nan_count = WarpReduceSum(value.nan_count); + value.inf_count = WarpReduceSum(value.inf_count); + value.zero_count = WarpReduceSum(value.zero_count); + } else { + value.nan_count = 0; + value.inf_count = 0; + value.zero_count = 0; + } + + return value; } __device__ inline bool IsNan(float v) { return isnan(v); } @@ -60,173 +124,126 @@ __device__ inline bool IsZero(__nv_bfloat16 v) { return v == __nv_bfloat16(0.0f); } -// Calculates count of NaNs of all elements of `input` and puts result in -// `output`. -// -// Optimized implementation based on -// https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf -// that takes advantage of `BLOCK_SIZE` threads. -// -// `BLOCK_SIZE` must be a power of 2 no larger than 1024. 
-template -__device__ void ReduceSum(const T* input, uint64_t input_size, - uint32_t* nan_counter, uint32_t* inf_counter, - uint32_t* zero_counter) { - __shared__ uint32_t nan_count[BLOCK_SIZE]; - __shared__ uint32_t inf_count[BLOCK_SIZE]; - __shared__ uint32_t zero_count[BLOCK_SIZE]; +// Get a part of the input buffer current thread block is responsible for +// processing, assuming the load is spread up to max_blocks across the entire +// grid. If max_blocks is not provided, the entire grid is used. +template +__device__ inline std::tuple GetBlockInput( + const T* input, uint64_t input_size, + std::optional max_blocks = std::nullopt) { + size_t grid_size = gridDim.x * gridDim.y * gridDim.z; + if (max_blocks.has_value()) { + grid_size = std::min(grid_size, *max_blocks); + } + const uint64_t max_block_input_size = xla::RoundUpTo( + xla::CeilOfRatio(input_size, grid_size), kElementsPerMemoryAccess); + const uint64_t block_input_offset = BlockIdx() * max_block_input_size; + const uint64_t block_input_size = + std::min(max_block_input_size, input_size - block_input_offset); + return {input + block_input_offset, block_input_size}; +} - assert(BlockIdx() == 0); +template +__device__ FloatCheckResult CheckFloats(const T* input, uint64_t input_size, + uint64_t max_blocks) { const unsigned int tid = ThreadIdx(); + const auto [block_input, block_input_size] = + GetBlockInput(input, input_size, max_blocks); - nan_count[tid] = 0; - inf_count[tid] = 0; - zero_count[tid] = 0; - for (unsigned int i = tid; i < input_size; i += BLOCK_SIZE) { - if (IsNan(input[i])) { - nan_count[tid]++; - } - if (IsInf(input[i])) { - inf_count[tid]++; - } - if (IsZero(input[i])) { - zero_count[tid]++; - } - } - - __syncthreads(); + const Chunk* chunked_input = + reinterpret_cast*>(block_input); + const uint64_t input_chunks = + xla::FloorOfRatio(block_input_size, kElementsPerMemoryAccess); + // This may be less than block_input_size only for the last block. 
+ const uint64_t chunked_input_size = + xla::RoundDownTo(block_input_size, kElementsPerMemoryAccess); - if (BLOCK_SIZE >= 1024) { - if (tid < 512) { - nan_count[tid] += nan_count[tid + 512]; - inf_count[tid] += inf_count[tid + 512]; - zero_count[tid] += zero_count[tid + 512]; + FloatCheckResult result{}; + for (uint64_t i = tid; i < input_chunks; i += kBlockSize) { + Chunk values = chunked_input[i]; + for (const T value : values) { + result.nan_count += IsNan(value); + result.inf_count += IsInf(value); + result.zero_count += IsZero(value); } - __syncthreads(); } - if (BLOCK_SIZE >= 512) { - if (tid < 256) { - nan_count[tid] += nan_count[tid + 256]; - inf_count[tid] += inf_count[tid + 256]; - zero_count[tid] += zero_count[tid + 256]; - } - __syncthreads(); - } - if (BLOCK_SIZE >= 256) { - if (tid < 128) { - nan_count[tid] += nan_count[tid + 128]; - inf_count[tid] += inf_count[tid + 128]; - zero_count[tid] += zero_count[tid + 128]; - } - __syncthreads(); - } - if (BLOCK_SIZE >= 128) { - if (tid < 64) { - nan_count[tid] += nan_count[tid + 64]; - inf_count[tid] += inf_count[tid + 64]; - zero_count[tid] += zero_count[tid + 64]; + + if (tid == 0 && chunked_input_size < block_input_size) { + const size_t rest = block_input_size - chunked_input_size; + for (uint64_t j = 0; j < rest; ++j) { + const T value = block_input[input_chunks + j]; + result.nan_count += IsNan(value); + result.inf_count += IsInf(value); + result.zero_count += IsZero(value); } - __syncthreads(); - } - if (tid < 32) { - WarpReduceSum(tid, nan_count); - WarpReduceSum(tid, inf_count); - WarpReduceSum(tid, zero_count); } - if (tid == 0) { - *nan_counter = nan_count[0]; - *inf_counter = inf_count[0]; - *zero_counter = zero_count[0]; + + return BlockReduceSum(tid, result); +} + +__device__ FloatCheckResult ReduceResults(const FloatCheckResult* input, + uint64_t input_size) { + const unsigned int tid = ThreadIdx(); + const auto [block_input, block_input_size] = GetBlockInput(input, input_size); + + 
FloatCheckResult result{}; + for (uint64_t i = tid; i < input_size; i += kBlockSize) { + const FloatCheckResult value = block_input[i]; + result.nan_count += value.nan_count; + result.inf_count += value.inf_count; + result.zero_count += value.zero_count; } + + // Now reduce a block worth of values into a single one. + return BlockReduceSum(tid, result); } -// Attempts to append the NaN count of the `input` buffer to the -// `float_check_entries`, using `log_header` to track available capacity and -// used space. -// -// The log entry is tagged with `entry_id`. The NaN count is parallelized as -// much as block dimensions allow it. -// -// If the log does not have enough space for the new entry, the entry is -// discarded. -// -// `input_size_in_bytes` is the size of the input buffer in bytes. -// -// LIMITATIONS: -// - Only a single thread block is supported. -// - Block dimensions must be a power of 2. +// Count the number of floats for NaNs, Infs and zeros in input buffer and store +// partially accumulated results in the tmp array. 
template -__global__ void AppendFloatCheck( - xla::gpu::BufferDebugLogEntryId entry_id, const T* input, - uint64_t input_size_in_bytes, xla::gpu::BufferDebugLogHeader* log_header, - xla::gpu::BufferDebugFloatCheckEntry* float_check_entries) { - const uint32_t block_size = blockDim.x * blockDim.y * blockDim.z; - const uint64_t input_size = input_size_in_bytes / sizeof(T); - uint32_t nan_count = 0; - uint32_t inf_count = 0; - uint32_t zero_count = 0; - - assert(gridDim.x == 1 && gridDim.y == 1 && gridDim.z == 1); - if (BlockIdx() != 0) { +__global__ void FloatCheck(const T* input, uint64_t input_size, + xla::gpu::FloatCheckResult* tmp, uint64_t tmp_size) { + assert(blockDim.x * blockDim.y * blockDim.z == kBlockSize); + assert(BlockIdx() < tmp_size); + if (BlockIdx() >= tmp_size) { return; } - // https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/: - // > CUDA architecture limits the numbers of threads per block (1024 threads - // > per block limit). - switch (block_size) { - case 1024: - ReduceSum(input, input_size, &nan_count, &inf_count, - &zero_count); - break; - case 512: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 256: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 128: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 64: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 32: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 16: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 8: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 4: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 2: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - case 1: - ReduceSum(input, input_size, &nan_count, &inf_count, &zero_count); - break; - default: 
- // Unsupported block size. - assert(false); - return; + const FloatCheckResult result = CheckFloats(input, input_size, tmp_size); + if (ThreadIdx() == 0) { + tmp[BlockIdx()] = result; } +} - if (ThreadIdx() == 0) { - cuda::atomic_ref - nan_count_log_write_idx(log_header->write_idx); +// Reduce the partially accumulated results from `FloatCheck` invocations and +// append the result to the buffer debug log. +__global__ void ReduceFloatCheckResults( + xla::gpu::FloatCheckResult* tmp, uint64_t tmp_size, + xla::gpu::BufferDebugLogEntryId entry_id, + xla::gpu::BufferDebugLogHeader* log_header, + xla::gpu::BufferDebugFloatCheckEntry* log_entries) { + assert(blockDim.x * blockDim.y * blockDim.z == kBlockSize); + assert(BlockIdx() == 0); + if (BlockIdx() >= 1) { + return; + } + + assert(tmp_size > 0); + FloatCheckResult total = ReduceResults(tmp, tmp_size); + + if (BlockIdx() == 0 && ThreadIdx() == 0) { + cuda::atomic_ref log_write_idx( + log_header->write_idx); #if __CUDA_ARCH__ >= 600 - const uint32_t write_idx = nan_count_log_write_idx.fetch_add(1); - if (nan_count_log_write_idx.load() < log_header->capacity) { - float_check_entries[write_idx] = xla::gpu::BufferDebugFloatCheckEntry{ - entry_id, nan_count, inf_count, zero_count}; + const uint32_t write_idx = log_write_idx.fetch_add(1); + if (write_idx < log_header->capacity) { + log_entries[write_idx] = xla::gpu::BufferDebugFloatCheckEntry{ + entry_id, total.nan_count, total.inf_count, total.zero_count}; } #else // Our toolchains generate a fetch_add PTX instructions with system scope, // which is not supported on pre-Pascal architectures. 
+ (void)total; assert(false); #endif } @@ -234,16 +251,22 @@ __global__ void AppendFloatCheck( se::KernelLoaderSpec GetFloatCheckF32KernelSpec(int arity) { return se::KernelLoaderSpec::CreateInProcessSymbolSpec( - absl::bit_cast(&AppendFloatCheck), + absl::bit_cast(&FloatCheck), "BufferDebugFloatCheckF32Kernel", arity); } se::KernelLoaderSpec GetFloatCheckBf16KernelSpec(int arity) { return se::KernelLoaderSpec::CreateInProcessSymbolSpec( - absl::bit_cast(&AppendFloatCheck<__nv_bfloat16>), + absl::bit_cast(&FloatCheck<__nv_bfloat16>), "BufferDebugFloatCheckBf16Kernel", arity); } +se::KernelLoaderSpec GetReduceFloatCheckResultsKernelSpec(int arity) { + return se::KernelLoaderSpec::CreateInProcessSymbolSpec( + absl::bit_cast(&ReduceFloatCheckResults), + "BufferDebugReduceFloatCheckResultsKernel", arity); +} + } // namespace GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( @@ -253,3 +276,8 @@ GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( BufferDebugFloatCheckBf16Kernel, se::gpu::BufferDebugFloatCheckBf16Kernel, se::cuda::kCudaPlatformId, GetFloatCheckBf16KernelSpec); + +GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY( + BufferDebugReduceFloatCheckResultsKernel, + se::gpu::BufferDebugAppendReducedFloatCheckResultsKernel, + se::cuda::kCudaPlatformId, GetReduceFloatCheckResultsKernelSpec); diff --git a/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc b/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc index 56ec5d18289bed..a1ab9cbb610482 100644 --- a/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/buffer_debug_float_check_kernel_cuda_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include +#include #include #include #include @@ -29,6 +31,7 @@ limitations under the License. #include "xla/backends/gpu/runtime/buffer_debug_log_structs.h" #include "xla/backends/gpu/runtime/thunk_id.h" #include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/buffer_debug_float_check_kernel.h" #include "xla/stream_executor/gpu/buffer_debug_log.h" #include "xla/stream_executor/gpu/gpu_kernel_registry.h" @@ -86,11 +89,17 @@ class FloatCheckKernelTest : public ::testing::Test { absl::Status AppendFloatCheckOnDevice( BufferDebugLogEntryId entry_id, const std::vector& input, se::gpu::BufferDebugLog& buffer_debug_log, - stream_executor::ThreadDim dim = stream_executor::ThreadDim(1, 1, 1)) { + stream_executor::BlockDim block_dim = stream_executor::BlockDim(1, 1, 1), + size_t temp_buffer_size_elements = 1024) { // Load kernel gpu::GpuKernelRegistry registry = gpu::GpuKernelRegistry::GetGlobalRegistry(); TF_ASSIGN_OR_RETURN(auto kernel, registry.LoadKernel(executor_)); + TF_ASSIGN_OR_RETURN( + auto reduce_kernel, + registry + .LoadKernel( + executor_)); // Setup device buffers TF_ASSIGN_OR_RETURN( @@ -100,13 +109,27 @@ class FloatCheckKernelTest : public ::testing::Test { auto cleanup_input = absl::MakeCleanup([&]() { executor_->Deallocate(&device_input); }); + TF_ASSIGN_OR_RETURN( + se::DeviceAddress device_tmp, + CheckNotNull(executor_->AllocateArray( + temp_buffer_size_elements), + "tmp")); + auto cleanup_tmp = + absl::MakeCleanup([&]() { executor_->Deallocate(&device_tmp); }); + + const se::ThreadDim thread_dim(1024, 1, 1); + // Call kernel TF_RETURN_IF_ERROR(stream_->Memcpy(&device_input, input.data(), input.size() * sizeof(input[0]))); - TF_RETURN_IF_ERROR(kernel.Launch( - dim, stream_executor::BlockDim(1, 1, 1), stream_.get(), entry_id, - device_input, device_input.ElementCount() * 
sizeof(InputType), - buffer_debug_log.GetDeviceHeader(), + TF_RETURN_IF_ERROR(kernel.Launch(thread_dim, block_dim, stream_.get(), + device_input, device_input.ElementCount(), + device_tmp, device_tmp.ElementCount())); + TF_RETURN_IF_ERROR(reduce_kernel.Launch( + thread_dim, se::BlockDim(1, 1, 1), stream_.get(), device_tmp, + std::min(device_tmp.ElementCount(), + block_dim.x * block_dim.y * block_dim.z), + entry_id, buffer_debug_log.GetDeviceHeader(), buffer_debug_log.GetDeviceEntries())); TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); @@ -170,33 +193,101 @@ TEST_F(FloatCheckKernelTest, ChecksFloatsForBf16) { } TEST_F(FloatCheckKernelTest, ChecksFloatsInParallel) { - se::DeviceAddress mem = executor_->AllocateArray(1024); - std::vector input(1024, 1.0f); - input[100] = std::numeric_limits::quiet_NaN(); - input[200] = std::numeric_limits::quiet_NaN(); - input[300] = std::numeric_limits::quiet_NaN(); - input[400] = 0.0f; - input[600] = std::numeric_limits::infinity(); - input[700] = std::numeric_limits::infinity(); + static constexpr size_t kNumNaNs = 100; + static constexpr size_t kNumInfs = 200; + static constexpr size_t kNumZeros = 300; + static constexpr size_t kMaxTestValues = + std::max(std::max(kNumNaNs, kNumInfs), kNumZeros); + + const se::DeviceDescription& device_desc = executor_->GetDeviceDescription(); + const size_t threads_per_core = device_desc.threads_per_core_limit(); + const size_t num_cores = device_desc.core_count(); + const size_t input_size = num_cores * threads_per_core * 3 / 2; + const size_t test_value_stride = input_size / (kMaxTestValues + 1); + ASSERT_GT(input_size, kMaxTestValues); + ASSERT_GT(test_value_stride, 2); + + std::vector input(input_size, 1.0f); + for (size_t i = 0; i < kNumNaNs; ++i) { + input[i * test_value_stride] = std::numeric_limits::quiet_NaN(); + } + for (size_t i = 0; i < kNumInfs; ++i) { + input[i * test_value_stride + 1] = std::numeric_limits::infinity(); + } + for (size_t i = 0; i < kNumZeros; ++i) { + input[i * 
test_value_stride + 2] = 0.0f; + } + se::DeviceAddress log_mem = executor_->AllocateArray(1024); TF_ASSERT_OK_AND_ASSIGN( auto device_log, se::gpu::BufferDebugLog::CreateOnDevice( - *stream_, mem)); + *stream_, log_mem)); + int64_t threads_per_block; + int64_t num_blocks; + CalculateDimensionality(executor_->GetDeviceDescription(), input.size(), + &threads_per_block, &num_blocks); + const se::BlockDim block_dim(num_blocks); TF_EXPECT_OK(AppendFloatCheckOnDevice( - BufferDebugLogEntryId{0}, input, device_log, se::ThreadDim(2, 4, 8))); + BufferDebugLogEntryId{0}, input, device_log, block_dim)); TF_EXPECT_OK(AppendFloatCheckOnDevice( - BufferDebugLogEntryId{0}, input, device_log, se::ThreadDim(2, 4, 8))); + BufferDebugLogEntryId{0}, input, device_log, block_dim)); TF_ASSERT_OK_AND_ASSIGN(auto host_log, device_log.ReadFromDevice(*stream_)); ASSERT_GE(host_log.size(), 2); - EXPECT_EQ(host_log[0].nan_count, 3); - EXPECT_EQ(host_log[0].inf_count, 2); - EXPECT_EQ(host_log[0].zero_count, 1); - EXPECT_EQ(host_log[1].nan_count, 3); - EXPECT_EQ(host_log[1].inf_count, 2); - EXPECT_EQ(host_log[1].zero_count, 1); + EXPECT_EQ(host_log[0].nan_count, kNumNaNs); + EXPECT_EQ(host_log[0].inf_count, kNumInfs); + EXPECT_EQ(host_log[0].zero_count, kNumZeros); + EXPECT_EQ(host_log[1].nan_count, kNumNaNs); + EXPECT_EQ(host_log[1].inf_count, kNumInfs); + EXPECT_EQ(host_log[1].zero_count, kNumZeros); +} + +TEST_F(FloatCheckKernelTest, ReduceFloatCheckResults) { + static constexpr size_t kNumNaNs = 100; + static constexpr size_t kNumInfs = 200; + static constexpr size_t kNumZeros = 300; + static constexpr size_t kIntermediateResults = 16 * 1024; + + std::vector results(kIntermediateResults); + for (size_t i = 0; i < kIntermediateResults; ++i) { + results[i].nan_count = i < kNumNaNs ? 1 : 0; + results[i].inf_count = i < kNumInfs ? 1 : 0; + results[i].zero_count = i < kNumZeros ? 
1 : 0; + } + + gpu::GpuKernelRegistry registry = gpu::GpuKernelRegistry::GetGlobalRegistry(); + TF_ASSERT_OK_AND_ASSIGN( + auto reduce_kernel, + registry.LoadKernel( + executor_)); + + se::DeviceAddress log_mem = executor_->AllocateArray(1024); + TF_ASSERT_OK_AND_ASSIGN( + auto device_log, + se::gpu::BufferDebugLog::CreateOnDevice( + *stream_, log_mem)); + TF_ASSERT_OK_AND_ASSIGN( + se::DeviceAddress device_results, + CheckNotNull(executor_->AllocateArray( + kIntermediateResults), + "results")); + auto cleanup_results = + absl::MakeCleanup([&]() { executor_->Deallocate(&device_results); }); + + TF_ASSERT_OK(stream_->Memcpy(&device_results, results.data(), + results.size() * sizeof(results[0]))); + TF_ASSERT_OK(reduce_kernel.Launch( + se::ThreadDim(1024, 1, 1), se::BlockDim(1, 1, 1), stream_.get(), + device_results, device_results.ElementCount(), BufferDebugLogEntryId{0}, + device_log.GetDeviceHeader(), device_log.GetDeviceEntries())); + TF_ASSERT_OK_AND_ASSIGN(auto host_log, device_log.ReadFromDevice(*stream_)); + + ASSERT_GE(host_log.size(), 1); + EXPECT_EQ(host_log[0].nan_count, kNumNaNs); + EXPECT_EQ(host_log[0].inf_count, kNumInfs); + EXPECT_EQ(host_log[0].zero_count, kNumZeros); } } // namespace diff --git a/third_party/xla/xla/stream_executor/gpu/buffer_debug_float_check_kernel.h b/third_party/xla/xla/stream_executor/gpu/buffer_debug_float_check_kernel.h index 421a1a08b7d547..af0b687d6f9578 100644 --- a/third_party/xla/xla/stream_executor/gpu/buffer_debug_float_check_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/buffer_debug_float_check_kernel.h @@ -25,21 +25,32 @@ limitations under the License. namespace stream_executor::gpu { -// Trait for a kernel that computes the NaN count of given input buffer and -// appends it to the buffer debug log. -// -// This kernel MUST execute on a single thread block. 
+// Counts the number of NaNs, Infs and zeros in a buffer of floats in parallel, +// and stores partially accumulated results in the FloatCheckResult array. struct BufferDebugFloatCheckF32Kernel { using KernelType = - TypedKernel, - uint64_t, DeviceAddress, - DeviceAddress>; + TypedKernel, uint64_t, + DeviceAddress, uint64_t>; }; +// Counts the number of NaNs, Infs and zeros in a buffer of bfloat16s in +// parallel, and stores partially accumulated results in the FloatCheckResult +// array. struct BufferDebugFloatCheckBf16Kernel { using KernelType = - TypedKernel, uint64_t, + TypedKernel, uint64_t, + DeviceAddress, uint64_t>; +}; + +// Trait for a kernel that reduces the partially accumulated results from +// `BufferDebugFloatCheckF32Kernel` or `BufferDebugFloatCheckBf16Kernel` +// invocations and appends the result to the buffer debug log. +// +// This kernel MUST execute on a single thread block. +struct BufferDebugAppendReducedFloatCheckResultsKernel { + using KernelType = + TypedKernel, uint64_t, + xla::gpu::BufferDebugLogEntryId, DeviceAddress, DeviceAddress>; }; From 9f90786e64afa218470d36ce3394756970042895 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 9 Dec 2025 10:14:58 -0800 Subject: [PATCH 087/753] [stream_executor] Make sure that DeviceAddress behaves like a pointer wrt comparison to nullptr_t and casting to bool PiperOrigin-RevId: 842292205 --- third_party/xla/xla/stream_executor/BUILD | 9 ++++ .../xla/xla/stream_executor/device_address.h | 50 +++++++++---------- .../stream_executor/device_address_test.cc | 40 +++++++++++++++ 3 files changed, 74 insertions(+), 25 deletions(-) create mode 100644 third_party/xla/xla/stream_executor/device_address_test.cc diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 195ec82c4d2777..5c1481cc227da1 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -82,6 +82,15 @@ cc_library( ], ) +xla_cc_test( + 
name = "device_address_test", + srcs = ["device_address_test.cc"], + deps = [ + ":device_address", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "device_address_handle", srcs = ["device_address_handle.cc"], diff --git a/third_party/xla/xla/stream_executor/device_address.h b/third_party/xla/xla/stream_executor/device_address.h index a2ac3d8ac02095..9884ade4430b7d 100644 --- a/third_party/xla/xla/stream_executor/device_address.h +++ b/third_party/xla/xla/stream_executor/device_address.h @@ -41,7 +41,7 @@ namespace stream_executor { // check for `opaque` being null to determine if the device address is null. class DeviceAddressBase { public: - // Default constructor instantiates a null-pointed, zero-sized device memory + // Default constructor instantiates a null-pointed, zero-sized device address // region. An opaque pointer may be provided -- see header for details on the // opacity of that pointer. explicit DeviceAddressBase(void* opaque = nullptr, uint64_t size = 0) @@ -53,10 +53,12 @@ class DeviceAddressBase { // explicit DeviceAddressBase(void *opaque) = delete; } - // Returns whether the backing memory is the null pointer. + // Returns whether the backing address is the null pointer. // A `== nullptr` convenience method is also provided. bool is_null() const { return opaque_ == nullptr; } + explicit operator bool() const { return !is_null(); } + bool operator==(std::nullptr_t other) const { return is_null(); } bool operator!=(std::nullptr_t other) const { return !is_null(); } @@ -64,7 +66,7 @@ class DeviceAddressBase { return opaque_ == other.opaque_ && size_ == other.size_; } - // Provides a partial order between device memory values. + // Provides a partial order between device address values. // // This operator is provided so that this object can be used as a key in an // ordered map. 
@@ -72,14 +74,14 @@ class DeviceAddressBase { return std::tie(opaque_, size_) < std::tie(other.opaque_, other.size_); } - // Returns the size, in bytes, for the backing memory. + // Returns the size, in bytes, for the backing address range. uint64_t size() const { return size_; } // Warning: note that the pointer returned is not necessarily directly to // device virtual address space, but is platform-dependent. void* opaque() const { return opaque_; } - // Returns the payload of this memory region. + // Returns the payload of this address range. uint64_t payload() const { return payload_; } // Sets payload to given value. @@ -91,60 +93,58 @@ class DeviceAddressBase { return opaque() == other.opaque() && size() == other.size(); } - // Creates a memory region (slice) inside another allocated memory region. - // Offset and size are in bytes. + // Creates and address range slice at the given offset and size. Offset and + // size are in bytes. ABSL_ATTRIBUTE_ALWAYS_INLINE DeviceAddressBase GetByteSlice(uint64_t offset_bytes, uint64_t size_bytes) const { DCHECK(offset_bytes + size_bytes <= size_) - << "requested slice allocation (offset + size) is greater " - << "than parent allocation size: (" << offset_bytes << " + " - << size_bytes << ") vs. (" << size_ << ")"; + << "requested address slice (offset + size) is out of bounds " + << "of parent address: (" << offset_bytes << " + " << size_bytes + << ") vs. (" << size_ << ")"; return DeviceAddressBase( reinterpret_cast(opaque_) + offset_bytes, size_bytes); } private: - void* opaque_; // Platform-dependent value representing addressable memory. - uint64_t size_; // Size in bytes of this allocation. - uint64_t payload_ = 0; // Payload data associated with this allocation. + void* opaque_; // Platform-dependent value representing base address. + uint64_t size_; // Size in bytes of this address range. + uint64_t payload_ = 0; // Payload data associated with this address. 
}; // Typed wrapper around "void *"-like DeviceAddressBase. // // For example, DeviceAddress is a simple wrapper around -// DeviceAddressBase that represents one or more integers in Device memory. +// DeviceAddressBase that represents one or more integers on Device. template class DeviceAddress final : public DeviceAddressBase { public: - // Default constructor instantiates a null-pointed, zero-sized memory region. + // Default constructor instantiates a null-pointed, zero-sized addess range. DeviceAddress() : DeviceAddressBase(nullptr, 0) {} explicit DeviceAddress(std::nullptr_t) : DeviceAddress() {} - // Typed device memory regions may be constructed from untyped device memory - // regions, this effectively amounts to a cast from a void*. + // Typed device address range may be constructed from untyped device address + // range, this effectively amounts to a cast from a void*. explicit DeviceAddress(const DeviceAddressBase& other) - : DeviceAddressBase(const_cast(other).opaque(), - other.size()) { + : DeviceAddressBase(other.opaque(), other.size()) { SetPayload(other.payload()); } - // Returns the number of elements of type T that constitute this - // allocation. + // Returns the number of elements of type T that constitute this address. uint64_t ElementCount() const { return size() / sizeof(T); } - // Returns pointer to the allocated data + // Returns a base pointer to the data. T* base() const { return reinterpret_cast(opaque()); } // Creates a typed area of DeviceAddress with a given opaque pointer and the - // quantity of bytes in the allocation. This function is broken out to + // quantity of bytes in the address range. This function is broken out to // distinguish bytes from an element count. static DeviceAddress MakeFromByteSize(void* opaque, uint64_t bytes) { return DeviceAddress(opaque, bytes); } - // Creates a memory region (slice) inside another allocated memory region. - // Offset and size are specified in terms of T elements. 
+ // Creates and address range slice at the given offset and count. Offset and + // count are specified in terms of T elements. DeviceAddress GetSlice(uint64_t element_offset, uint64_t element_count) { return DeviceAddress( GetByteSlice(sizeof(T) * element_offset, sizeof(T) * element_count)); diff --git a/third_party/xla/xla/stream_executor/device_address_test.cc b/third_party/xla/xla/stream_executor/device_address_test.cc new file mode 100644 index 00000000000000..71acd21672215e --- /dev/null +++ b/third_party/xla/xla/stream_executor/device_address_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/stream_executor/device_address.h" + +#include + +#include + +namespace stream_executor { +namespace { + +TEST(DeviceAddressTest, NullptrComparisons) { + { + DeviceAddressBase null_ptr; + EXPECT_FALSE(null_ptr); + EXPECT_TRUE(null_ptr == nullptr); + } + + { + DeviceAddress null_ptr; + EXPECT_FALSE(null_ptr); + EXPECT_TRUE(null_ptr == nullptr); + } +} + +} // namespace +} // namespace stream_executor From 21040b5b652e63d81f77509a5ce8ab27a5b1d16c Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 9 Dec 2025 10:17:43 -0800 Subject: [PATCH 088/753] [xla] Migrate to se::DeviceMemoryAddress PiperOrigin-RevId: 842293237 --- tensorflow/compiler/jit/xla_tensor.cc | 2 +- .../runtime/gpublas_lt_matmul_thunk_test.cc | 3 +- third_party/xla/xla/client/BUILD | 4 +- third_party/xla/xla/client/client_library.h | 2 +- third_party/xla/xla/client/local_client.cc | 4 +- third_party/xla/xla/client/local_client.h | 4 +- third_party/xla/xla/core/collectives/BUILD | 2 +- .../xla/xla/core/collectives/communicator.h | 40 ++++++++-------- third_party/xla/xla/ffi/BUILD | 10 ++-- third_party/xla/xla/ffi/api/BUILD | 4 +- third_party/xla/xla/ffi/api/c_api_internal.h | 2 +- third_party/xla/xla/ffi/api/ffi_test.cc | 48 ++++++++++--------- third_party/xla/xla/ffi/call_frame.cc | 20 ++++---- third_party/xla/xla/ffi/call_frame.h | 14 +++--- third_party/xla/xla/ffi/call_frame_test.cc | 20 ++++---- third_party/xla/xla/ffi/ffi.h | 12 ++--- third_party/xla/xla/ffi/ffi_api.cc | 4 +- third_party/xla/xla/ffi/ffi_test.cc | 32 ++++++------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 4 +- .../xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc | 23 +++++---- .../xla/pjrt/pjrt_stream_executor_client.cc | 4 +- .../xla/pjrt/pjrt_stream_executor_client.h | 4 +- .../xla/pjrt/tracked_device_buffer_test.cc | 3 +- third_party/xla/xla/service/BUILD | 2 + .../xla/service/maybe_owning_device_memory.h | 2 + 
third_party/xla/xla/tests/BUILD | 16 +++---- .../xla/xla/tests/buffer_donation_test.cc | 12 ++--- .../xla/xla/tests/collective_ops_ffi_test.cc | 2 +- third_party/xla/xla/tests/hlo_test_base.cc | 4 +- third_party/xla/xla/tests/hlo_test_base.h | 6 +-- .../xla/tests/local_client_execute_test.cc | 2 +- .../xla/xla/tests/local_client_test_base.cc | 8 ++-- .../xla/xla/tests/local_client_test_base.h | 8 ++-- .../xla/xla/tests/transfer_manager_test.cc | 2 +- third_party/xla/xla/tools/BUILD | 2 +- 35 files changed, 173 insertions(+), 158 deletions(-) diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index e9cdad219dd28d..d6792cd7802d96 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -55,7 +55,7 @@ absl::Status XlaTensor::AllocateShapedBuffer(DataType dtype, xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer, + TF_ASSIGN_OR_RETURN(se::ScopedDeviceAddress buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false, subshape.layout().memory_space())); diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc index ccdf653ca1862e..77a6ac88f8ff70 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include +#include #include #include #include @@ -182,7 +183,7 @@ class GpuBlasLtThunkBuilder { se::StreamExecutorMemoryAllocator allocator_; se::GpuComputeCapability gpu_comp_; std::deque allocs_; - std::vector mem_buffers_; + std::vector> mem_buffers_; }; void GpuBlasLtMatmulThunkTest::CreateExecuteThunksFromHLO( diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index c2801fa3fa8410..fac2d9343ff1d0 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -128,7 +128,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:source_map_util", "//xla/service:stream_pool", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -181,7 +181,7 @@ cc_library( "//xla/service:compile_only_service", "//xla/service:local_service", "//xla/service:platform_util", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:logging", diff --git a/third_party/xla/xla/client/client_library.h b/third_party/xla/xla/client/client_library.h index 0e4f3a9a24dd22..42d0f34202e092 100644 --- a/third_party/xla/xla/client/client_library.h +++ b/third_party/xla/xla/client/client_library.h @@ -36,7 +36,7 @@ limitations under the License. 
#include "xla/client/local_client.h" #include "xla/service/compile_only_service.h" #include "xla/service/local_service.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index cc383a9aa81b34..e1f348a755521d 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -45,7 +45,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/platform/errors.h" @@ -512,7 +512,7 @@ absl::StatusOr> LocalClient::LoadInternal( absl::StatusOr LocalClient::LiteralToShapedBuffer( const LiteralSlice& literal, int device_ordinal, - se::DeviceMemoryAllocator* allocator) { + se::DeviceAddressAllocator* allocator) { if (allocator == nullptr) { allocator = backend().memory_allocator(); } diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 3ccda5d43f6794..3c237ef37a1973 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -39,7 +39,7 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/service/stream_pool.h" #include "xla/shape_tree.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -183,7 +183,7 @@ class LocalClient : public Client { // device is used. 
absl::StatusOr LiteralToShapedBuffer( const LiteralSlice& literal, int device_ordinal, - se::DeviceMemoryAllocator* allocator = nullptr); + se::DeviceAddressAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. absl::StatusOr TransferToLocalServer( diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD index 1b0398aaaf4801..06d3ef7f6c9aed 100644 --- a/third_party/xla/xla/core/collectives/BUILD +++ b/third_party/xla/xla/core/collectives/BUILD @@ -73,7 +73,7 @@ cc_library( "//xla:future", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/core/collectives/communicator.h b/third_party/xla/xla/core/collectives/communicator.h index 0f60a859db854d..4be35fb52163f7 100644 --- a/third_party/xla/xla/core/collectives/communicator.h +++ b/third_party/xla/xla/core/collectives/communicator.h @@ -28,7 +28,7 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/core/collectives/reduction_kind.h" #include "xla/future.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -65,7 +65,7 @@ class Communicator { // Register `buffer_range` once for efficient collective operations (i.e. on // NCCL backend it registers the buffer for zero-copy collective operations). 
// - virtual absl::Status RegisterBufferOnce(se::DeviceMemoryBase buffer_range, + virtual absl::Status RegisterBufferOnce(se::DeviceAddressBase buffer_range, int device_ordinal, bool use_symmetric_buffer) { return Unimplemented("User-managed buffer registration is not supported"); @@ -91,40 +91,40 @@ class Communicator { // Reduce buffers of length `count` in `send_buff` using `reduction_kind` // reduction and leaves identical copies of the result on each `recv_buff`. - virtual Future<> AllReduce(stream_executor::DeviceMemoryBase send_buffer, - stream_executor::DeviceMemoryBase recv_buffer, + virtual Future<> AllReduce(stream_executor::DeviceAddressBase send_buffer, + stream_executor::DeviceAddressBase recv_buffer, PrimitiveType dtype, size_t count, ReductionKind reduction_kind, const Executor& executor) = 0; // Copy data in `send_buff` from the root device to the `recv_buff` on // all other devices. - virtual Future<> Broadcast(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, + virtual Future<> Broadcast(se::DeviceAddressBase send_buffer, + se::DeviceAddressBase recv_buffer, PrimitiveType dtype, size_t count, RankId root, const Executor& executor) = 0; // Reduce data in `send_buff` from all devices using the `reduction_kind` // operation and leave the reduced result scattered over the devices so that // the `recv_buff` on rank `i` will contain the i-th block of the result. - virtual Future<> ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, + virtual Future<> ReduceScatter(se::DeviceAddressBase send_buffer, + se::DeviceAddressBase recv_buffer, PrimitiveType dtype, size_t count, ReductionKind reduction_kind, const Executor& executor) = 0; // Gather `count` values from all devices into `recv_buffer`, receiving data // from rank `i` at offset `i * sendcount`. 
- virtual Future<> AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, + virtual Future<> AllGather(se::DeviceAddressBase send_buffer, + se::DeviceAddressBase recv_buffer, PrimitiveType dtype, size_t count, const Executor& executor) = 0; // Sends data from `send_buffer` to `target_ranks` and receives data from // `source_rank` into `recv_buffer`. If `source_rank` is not specified, the // output is filled with zeros. - virtual Future<> CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, + virtual Future<> CollectivePermute(se::DeviceAddressBase send_buffer, + se::DeviceAddressBase recv_buffer, PrimitiveType dtype, size_t count, std::optional source_rank, absl::Span target_ranks, @@ -133,30 +133,30 @@ class Communicator { // Sends `count` values from `send_buffers` to other ranks and receives data // from other ranks into `recv_buffers`. virtual Future<> AllToAll( - absl::InlinedVector send_buffers, - absl::InlinedVector recv_buffers, + absl::InlinedVector send_buffers, + absl::InlinedVector recv_buffers, PrimitiveType dtype, size_t count, const Executor& executor) = 0; // Send data from `send_buff` to rank `peer`. - virtual Future<> Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + virtual Future<> Send(se::DeviceAddressBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; // Receive data from rank `peer` into `recv_buff`. - virtual Future<> Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, + virtual Future<> Recv(se::DeviceAddressBase recv_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; // Send data from `send_buff` to rank `recv_buff` (one-way send). 
- virtual Future<> Send(se::DeviceMemoryBase recv_buffer, - se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + virtual Future<> Send(se::DeviceAddressBase recv_buffer, + se::DeviceAddressBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) { return Unimplemented("One-way send is not implemented"); } // Receive data from rank `peer` into `recv_buff` (one-way recv). - virtual Future<> Recv(se::DeviceMemoryBase recv_buffer, - se::DeviceMemoryBase send_buffer, PrimitiveType dtype, + virtual Future<> Recv(se::DeviceAddressBase recv_buffer, + se::DeviceAddressBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) { return Unimplemented("One-way recv is not implemented"); } diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 41c825e3599ea2..f14764091594bc 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -37,7 +37,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -58,7 +58,7 @@ xla_cc_test( ":call_frame", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", @@ -149,7 +149,7 @@ cc_library( "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -182,8 +182,8 @@ cc_library( "//xla/ffi/api:c_api_internal", "//xla/hlo/ir:hlo", "//xla/service:platform_util", + 
"//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", - "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", @@ -299,7 +299,7 @@ xla_cc_test( "//xla/backends/cpu:ffi", "//xla/backends/gpu:ffi", "//xla/ffi/api:c_api", - "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index 41889027b9ddd3..dc4551d8e2fecc 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -91,8 +91,8 @@ xla_cc_test( "//xla/ffi:execution_state", "//xla/ffi:ffi_api", "//xla/ffi:type_registry", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", diff --git a/third_party/xla/xla/ffi/api/c_api_internal.h b/third_party/xla/xla/ffi/api/c_api_internal.h index d0baf4fc3b7bb0..d9070080f3a4a6 100644 --- a/third_party/xla/xla/ffi/api/c_api_internal.h +++ b/third_party/xla/xla/ffi/api/c_api_internal.h @@ -93,7 +93,7 @@ typedef XLA_FFI_Error* XLA_FFI_INTERNAL_IntraOpThreadPool_Get( typedef XLA_FFI_Error* XLA_FFI_INTERNAL_Stream_Get( XLA_FFI_ExecutionContext* ctx, void** stream); -// Returns a pointer to device memory allocator (`se::DeviceMemoryAllocator` +// Returns a pointer to device memory allocator (`se::DeviceAddressAllocator` // pointer) which allows to allocate memory inside a custom call from the same // allocator as XLA (i.e. it allows to construct scratch memory allocator). 
typedef XLA_FFI_Error* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get( diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index e3345ebe915146..81578f564956fd 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -45,8 +45,8 @@ limitations under the License. #include "xla/ffi/ffi_api.h" #include "xla/ffi/type_registry.h" #include "xla/primitive_util.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -522,7 +522,7 @@ TEST(FfiTest, DeviceOrdinal) { TEST(FfiTest, AnyBufferArgument) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -544,7 +544,7 @@ TEST(FfiTest, AnyBufferArgument) { TEST(FfiTest, BufferArgument) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -562,7 +562,7 @@ TEST(FfiTest, BufferArgument) { TEST(FfiTest, AnyBufferResult) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -594,7 +594,7 @@ TEST(FfiTest, MissingBufferArgument) { TEST(FfiTest, 
WrongRankBufferArgument) { std::vector storage(4, 0.0); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -611,7 +611,7 @@ TEST(FfiTest, WrongRankBufferArgument) { TEST(FfiTest, WrongTypeBufferArgument) { std::vector storage(4, 0.0); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); @@ -648,7 +648,7 @@ TEST(FfiTest, WrongNumberOfArguments) { TEST(FfiTest, TokenArgument) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); - builder.AddBufferArg(se::DeviceMemoryBase(), PrimitiveType::TOKEN, + builder.AddBufferArg(se::DeviceAddressBase(), PrimitiveType::TOKEN, /*dims=*/{}); auto call_frame = builder.Build(); @@ -665,7 +665,7 @@ TEST(FfiTest, TokenArgument) { TEST(FfiTest, RemainingArgs) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -694,7 +694,7 @@ TEST(FfiTest, RemainingArgs) { TEST(FfiTest, RemainingRets) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -724,7 +724,7 @@ TEST(FfiTest, RemainingRets) { TEST(FfiTest, OptionalArgs) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase 
memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -785,7 +785,7 @@ TEST(FfiTest, OptionalArgs) { TEST(FfiTest, OptionalRets) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -854,7 +854,7 @@ TEST(FfiTest, AutoBinding) { }); std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder::AttributesBuilder attrs; attrs.Insert(kI32, 42); @@ -873,7 +873,8 @@ TEST(FfiTest, AutoBindingResult) { Ffi::BindTo(+[](Result buffer) { return Error::Success(); }); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); - builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{}); + builder.AddBufferRet(se::DeviceAddressBase(), PrimitiveType::F32, + /*dims=*/{}); auto call_frame = builder.Build(); auto status = Call(*handler, call_frame); @@ -1409,19 +1410,22 @@ TEST(FfiTest, ScratchAllocator) { static void* kAddr = reinterpret_cast(0xDEADBEEF); // A test only memory allocator that returns a fixed memory address. 
- struct TestDeviceMemoryAllocator final : public se::DeviceMemoryAllocator { + struct TestDeviceMemoryAllocator final : public se::DeviceAddressAllocator { size_t count; TestDeviceMemoryAllocator() - : se::DeviceMemoryAllocator(nullptr), count(0) {} + : se::DeviceAddressAllocator(nullptr), count(0) {} - absl::StatusOr Allocate(int, uint64_t size, bool, - int64_t) final { + absl::StatusOr> Allocate(int, + uint64_t size, + bool, + int64_t) final { count++; - return se::OwningDeviceMemory(se::DeviceMemoryBase(kAddr, size), 0, this); + return se::ScopedDeviceAddress( + se::DeviceAddressBase(kAddr, size), 0, this); } - absl::Status Deallocate(int, se::DeviceMemoryBase mem) final { + absl::Status Deallocate(int, se::DeviceAddressBase mem) final { count--; EXPECT_EQ(mem.opaque(), kAddr); return absl::OkStatus(); @@ -1588,7 +1592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(BufferR2F32Handler, BufferR2F32Function); TEST(FfiTest, DefineAutoSymbol) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -1604,7 +1608,7 @@ TEST(FfiTest, DefineAutoSymbol) { //===----------------------------------------------------------------------===// static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { - se::DeviceMemoryBase memory; + se::DeviceAddressBase memory; std::vector dims(4, 1); CallFrameBuilder builder(/*num_args=*/num_args, /*num_rets=*/0); diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index ad7c71c98f8cd6..f0c17215c2dafd 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/attribute_map.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -47,7 +47,7 @@ namespace xla::ffi { //===----------------------------------------------------------------------===// struct CallFrameBuilder::Buffer { - se::DeviceMemoryBase memory; + se::DeviceAddressBase memory; PrimitiveType type; absl::InlinedVector dims; }; @@ -84,7 +84,7 @@ CallFrameBuilder::CallFrameBuilder(size_t num_args, size_t num_rets) { CallFrameBuilder::~CallFrameBuilder() = default; -void CallFrameBuilder::AddBufferArg(se::DeviceMemoryBase memory, +void CallFrameBuilder::AddBufferArg(se::DeviceAddressBase memory, PrimitiveType type, absl::Span dims) { DCHECK(args_.capacity() > args_.size()) @@ -95,10 +95,10 @@ void CallFrameBuilder::AddBufferArg(se::DeviceMemoryBase memory, void CallFrameBuilder::AddTokenArg() { DCHECK(args_.capacity() > args_.size()) << "CallFrame builder `num_args` argument was too small"; - args_.push_back(Buffer{se::DeviceMemoryBase(), PrimitiveType::TOKEN, {}}); + args_.push_back(Buffer{se::DeviceAddressBase(), PrimitiveType::TOKEN, {}}); } -void CallFrameBuilder::AddBufferRet(se::DeviceMemoryBase memory, +void CallFrameBuilder::AddBufferRet(se::DeviceAddressBase memory, PrimitiveType type, absl::Span dims) { DCHECK(rets_.capacity() > rets_.size()) @@ -109,7 +109,7 @@ void CallFrameBuilder::AddBufferRet(se::DeviceMemoryBase memory, void CallFrameBuilder::AddTokenRet() { DCHECK(rets_.capacity() > rets_.size()) << "CallFrame builder `num_rets` argument was too small"; - rets_.push_back(Buffer{se::DeviceMemoryBase(), PrimitiveType::TOKEN, {}}); + rets_.push_back(Buffer{se::DeviceAddressBase(), PrimitiveType::TOKEN, {}}); } void CallFrameBuilder::AddAttributes(AttributesMap attrs) { @@ -557,8 +557,8 @@ std::unique_ptr 
CallFrame::FixUpAttrs( //===----------------------------------------------------------------------===// absl::Status CallFrame::UpdateWithBuffers( - absl::Span args, - absl::Span rets) { + absl::Span args, + absl::Span rets) { if (ABSL_PREDICT_FALSE(args.size() != arguments_->args.size())) { return InvalidArgument("Invalid number of updated arguments: %d vs %d", args.size(), arguments_->args.size()); @@ -587,8 +587,8 @@ CallFrame CallFrame::Copy() const { } absl::StatusOr CallFrame::CopyWithBuffers( - absl::Span args, - absl::Span rets) const { + absl::Span args, + absl::Span rets) const { CallFrame clone(CopyArgs(*arguments_), CopyRets(*results_), attributes_); TF_RETURN_IF_ERROR(clone.UpdateWithBuffers(args, rets)); return clone; diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h index 32dceead1d9b4b..5433d4be990d42 100644 --- a/third_party/xla/xla/ffi/call_frame.h +++ b/third_party/xla/xla/ffi/call_frame.h @@ -30,7 +30,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" @@ -76,12 +76,12 @@ class CallFrameBuilder { CallFrame Build(); - void AddBufferArg(se::DeviceMemoryBase memory, PrimitiveType type, + void AddBufferArg(se::DeviceAddressBase memory, PrimitiveType type, absl::Span dims); void AddTokenArg(); - void AddBufferRet(se::DeviceMemoryBase memory, PrimitiveType type, + void AddBufferRet(se::DeviceAddressBase memory, PrimitiveType type, absl::Span dims); void AddTokenRet(); @@ -117,16 +117,16 @@ class CallFrame { // array (buffer) arguments and results are known at compile time. Instead of // rebuilding the call frame from scratch on every execution, we can just // update the arguments and results with new pointers to device memory. 
- absl::Status UpdateWithBuffers(absl::Span args, - absl::Span rets); + absl::Status UpdateWithBuffers(absl::Span args, + absl::Span rets); // Creates a copy of the call frame. CallFrame Copy() const; // Creates a copy of the call frame with updated arguments and results. absl::StatusOr CopyWithBuffers( - absl::Span args, - absl::Span rets) const; + absl::Span args, + absl::Span rets) const; // Builds an XLA_FFI_CallFrame from owned arguments and attributes. XLA_FFI_CallFrame Build( diff --git a/third_party/xla/xla/ffi/call_frame_test.cc b/third_party/xla/xla/ffi/call_frame_test.cc index f73461fc7d297f..b58e2d9a2537b6 100644 --- a/third_party/xla/xla/ffi/call_frame_test.cc +++ b/third_party/xla/xla/ffi/call_frame_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" @@ -34,8 +34,8 @@ limitations under the License. 
namespace xla::ffi { TEST(CallFrameTest, UpdateCallFrame) { - se::DeviceMemoryBase mem0(reinterpret_cast(0x12345678), 1024); - se::DeviceMemoryBase mem1(reinterpret_cast(0x87654321), 1024); + se::DeviceAddressBase mem0(reinterpret_cast(0x12345678), 1024); + se::DeviceAddressBase mem1(reinterpret_cast(0x87654321), 1024); std::vector dims = {1, 2, 3, 4}; @@ -116,7 +116,7 @@ TEST(CallFrameTest, UpdateCallFrame) { void BM_AddBufferArg(benchmark::State& state) { size_t num_args = state.range(0); - se::DeviceMemoryBase memory(reinterpret_cast(0x12345678), 1024); + se::DeviceAddressBase memory(reinterpret_cast(0x12345678), 1024); std::vector dims = {1, 2, 3, 4}; for (auto _ : state) { @@ -151,17 +151,17 @@ void BM_AddAttributes(benchmark::State& state) { void BM_UpdateCallFrame(benchmark::State& state) { size_t num_args = state.range(0); - se::DeviceMemoryBase memory(reinterpret_cast(0x12345678), 1024); + se::DeviceAddressBase memory(reinterpret_cast(0x12345678), 1024); std::vector dims = {1, 2, 3, 4}; CallFrameBuilder builder(num_args, /*num_rets=*/0); for (size_t i = 0; i < num_args; ++i) { - builder.AddBufferArg(se::DeviceMemoryBase(nullptr, 1024), + builder.AddBufferArg(se::DeviceAddressBase(nullptr, 1024), PrimitiveType::F32, dims); } CallFrame call_frame = builder.Build(); - std::vector updated_args(num_args, memory); + std::vector updated_args(num_args, memory); for (auto _ : state) { auto updated_call_frame = @@ -173,17 +173,17 @@ void BM_UpdateCallFrame(benchmark::State& state) { void BM_UpdateCallFrameInPlace(benchmark::State& state) { size_t num_args = state.range(0); - se::DeviceMemoryBase memory(reinterpret_cast(0x12345678), 1024); + se::DeviceAddressBase memory(reinterpret_cast(0x12345678), 1024); std::vector dims = {1, 2, 3, 4}; CallFrameBuilder builder(num_args, /*num_rets=*/0); for (size_t i = 0; i < num_args; ++i) { - builder.AddBufferArg(se::DeviceMemoryBase(nullptr, 1024), + builder.AddBufferArg(se::DeviceAddressBase(nullptr, 1024), PrimitiveType::F32, 
dims); } CallFrame call_frame = builder.Build(); - std::vector updated_args(num_args, memory); + std::vector updated_args(num_args, memory); for (auto _ : state) { benchmark::DoNotOptimize( diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index da6303e14faef7..4e1849a190d327 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -50,7 +50,7 @@ limitations under the License. #include "xla/ffi/type_registry.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/primitive_util.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/types.h" // IWYU pragma: keep @@ -137,8 +137,8 @@ class AnyBuffer { return reinterpret_cast(buf_->data); } - se::DeviceMemoryBase device_memory() const { - return se::DeviceMemoryBase(untyped_data(), size_bytes()); + se::DeviceAddressBase device_memory() const { + return se::DeviceAddressBase(untyped_data(), size_bytes()); } private: @@ -182,9 +182,9 @@ class Buffer { return reinterpret_cast*>(untyped_data()); } - se::DeviceMemory> device_memory() const { - return se::DeviceMemory>( - se::DeviceMemoryBase(untyped_data(), size_bytes())); + se::DeviceAddress> device_memory() const { + return se::DeviceAddress>( + se::DeviceAddressBase(untyped_data(), size_bytes())); } private: diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 31287ac7587ef4..3f0de64033061e 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -47,8 +47,8 @@ limitations under the License. 
#include "xla/ffi/ffi_structs.h" #include "xla/ffi/type_registry.h" #include "xla/service/platform_util.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_address_allocator.h" -#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/tsl/platform/logging.h" @@ -795,7 +795,7 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Free( absl::Status status = gpu->allocator->Deallocate( args->ctx->device_ordinal, - stream_executor::DeviceMemoryBase(args->data, args->size)); + stream_executor::DeviceAddressBase(args->data, args->size)); if (!status.ok()) { return new XLA_FFI_Error{std::move(status)}; } diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 8f0b00244c0a93..0369c8cc1946e5 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -43,7 +43,7 @@ limitations under the License. #include "xla/ffi/execution_state.h" #include "xla/ffi/ffi_api.h" #include "xla/ffi/type_registry.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" @@ -179,7 +179,7 @@ TEST(FfiTest, CatchExceptionExplicit) { TEST(FfiTest, WrongNumArgs) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); - builder.AddBufferArg(se::DeviceMemoryBase(nullptr), PrimitiveType::F32, {}); + builder.AddBufferArg(se::DeviceAddressBase(nullptr), PrimitiveType::F32, {}); auto call_frame = builder.Build(); auto handler = Ffi::Bind().Arg().Arg().To( @@ -579,7 +579,7 @@ TEST(FfiTest, DecodingErrors) { TEST(FfiTest, AnyBufferArgument) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, 
/*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -614,7 +614,7 @@ TEST(FfiTest, AnyBufferArgument) { TEST(FfiTest, TypedAndRankedBufferArgument) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), storage.size() * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), storage.size() * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -642,8 +642,8 @@ TEST(FfiTest, TypedAndRankedBufferArgument) { TEST(FfiTest, ComplexBufferArgument) { std::vector> storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), - storage.size() * sizeof(std::complex)); + se::DeviceAddressBase memory(storage.data(), + storage.size() * sizeof(std::complex)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::C64, /*dims=*/{2, 2}); @@ -662,7 +662,7 @@ TEST(FfiTest, ComplexBufferArgument) { TEST(FfiTest, TokenArgument) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); - builder.AddBufferArg(se::DeviceMemoryBase(), PrimitiveType::TOKEN, + builder.AddBufferArg(se::DeviceAddressBase(), PrimitiveType::TOKEN, /*dims=*/{}); auto call_frame = builder.Build(); @@ -679,7 +679,7 @@ TEST(FfiTest, TokenArgument) { TEST(FfiTest, WrongRankBufferArgument) { std::vector storage(4, 0.0); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -697,7 +697,7 @@ TEST(FfiTest, WrongRankBufferArgument) { TEST(FfiTest, WrongTypeBufferArgument) { std::vector storage(4, 0.0); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); 
builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); @@ -715,7 +715,7 @@ TEST(FfiTest, WrongTypeBufferArgument) { TEST(FfiTest, RemainingArgs) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -743,7 +743,7 @@ TEST(FfiTest, RemainingArgs) { TEST(FfiTest, RemainingRets) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -772,7 +772,7 @@ TEST(FfiTest, RemainingRets) { TEST(FfiTest, OptionalArgs) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -833,7 +833,7 @@ TEST(FfiTest, OptionalArgs) { TEST(FfiTest, OptionalRets) { std::vector storage(4, 0.0f); - se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -975,8 +975,8 @@ TEST(FfiTest, UpdateBufferArgumentsAndResults) { std::vector storage0(4, 0.0f); std::vector storage1(4, 0.0f); - se::DeviceMemoryBase memory0(storage0.data(), 4 * sizeof(float)); - se::DeviceMemoryBase memory1(storage1.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory0(storage0.data(), 4 * sizeof(float)); + se::DeviceAddressBase memory1(storage1.data(), 4 * sizeof(float)); std::vector dims = {2, 2}; @@ -1169,7 
+1169,7 @@ TEST(FfiTest, PlatformStream) { //===----------------------------------------------------------------------===// static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { - se::DeviceMemoryBase memory; + se::DeviceAddressBase memory; std::vector dims(4, 1); CallFrameBuilder builder(/*num_args=*/num_args, /*num_rets=*/0); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index e210a480bc74dd..205fc66b41fcc8 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1970,7 +1970,7 @@ StreamExecutorGpuClient::RunAsync( const int64_t buffer_size = allocation.size(); if (buffer_size > 0) { TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory owning_buffer, + se::ScopedDeviceAddress owning_buffer, memory_allocator->Allocate(device_ordinal, buffer_size, /*retry_on_failure=*/true, /*memory_space=*/allocation.color())); @@ -2035,7 +2035,7 @@ StreamExecutorGpuClient::RunAsync( "buffer is not donated; allocating a fresh buffer"; int64_t allocation_size = ShapeUtil::ByteSizeOf( ShapeUtil::GetSubshape(gpu_exec->result_shape(), index)); - absl::StatusOr allocated_buffer = + absl::StatusOr> allocated_buffer = memory_allocator->Allocate(device_ordinal, allocation_size, /*retry_on_failure=*/true, /*memory_space=*/allocation->color()); diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc index 5e84506057c524..88fce7477ce884 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc @@ -775,16 +775,20 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( tuple_buffer.buffers().mutable_element({}); VLOG(3) << "untuple: output_buffers[" << i << "].emplace: " << elem->opaque(); - output_buffers[i].emplace(stream_executor::OwningDeviceMemory( - *elem, device->local_device_id().value(), 
client->allocator())); + output_buffers[i].emplace( + stream_executor::ScopedDeviceAddress( + *elem, device->local_device_id().value(), + client->allocator())); *elem = se::DeviceAddressBase(); } } else { CHECK_EQ(output_buffers.size(), 1); auto* elem = output.buffers().mutable_element({}); VLOG(3) << "output_buffers[0].emplace: " << elem->opaque(); - output_buffers.front().emplace(stream_executor::OwningDeviceMemory( - *elem, device->local_device_id().value(), client->allocator())); + output_buffers.front().emplace( + stream_executor::ScopedDeviceAddress( + *elem, device->local_device_id().value(), + client->allocator())); *elem = se::DeviceAddressBase(); } @@ -909,10 +913,11 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( << "]: " << tracked_buffers[i]->buffer()->buffer().opaque(); if (buffer_is_donated[i]) { input.SetUnownedBuffer( - {i}, MaybeOwningDeviceAddress(se::OwningDeviceMemory( - tracked_buffers[i]->buffer()->buffer(), - device->local_hardware_id().value(), - client->allocator()))); + {i}, + MaybeOwningDeviceAddress(se::ScopedDeviceAddress( + tracked_buffers[i]->buffer()->buffer(), + device->local_hardware_id().value(), + client->allocator()))); } else { input.SetBuffer({i}, MaybeOwningDeviceAddress( tracked_buffers[i]->buffer()->buffer())); @@ -928,7 +933,7 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( ExecutionInput& input = inputs.back(); if (buffer_is_donated[i]) { input.SetUnownedBuffer( - {}, MaybeOwningDeviceAddress(se::OwningDeviceMemory( + {}, MaybeOwningDeviceAddress(se::ScopedDeviceAddress( tracked_buffers[i]->buffer()->buffer(), device->local_hardware_id().value(), client->allocator()))); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index e342a586863001..d11f6e966f5ec2 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -1156,7 +1156,7 @@ 
MakeTupleHelper(PjRtStreamExecutorClient* client, se::Stream* stream = local_device->host_to_device_stream(); TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory owned_root_table_memory, + se::ScopedDeviceAddress owned_root_table_memory, allocator->Allocate( device_ordinal, transfer_manager->GetByteSizeRequirement(tupled_parameter_shape))); @@ -1673,7 +1673,7 @@ PjRtStreamExecutorClient::RunAsync( auto it = tmp.MutableBuffers()->begin(); for (auto& v : input) { if (v.second.is_donated) { - it->second = MaybeOwningDeviceAddress(se::OwningDeviceMemory( + it->second = MaybeOwningDeviceAddress(se::ScopedDeviceAddress( v.second.buf->mem(), device->local_device_id().value(), run_options.allocator())); tmp.SetUnownedIndex(it->first); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 4b656c48fc2517..4220db893cb1dc 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -91,8 +91,8 @@ struct PjRtStreamExecutorExecutionOutput { // Donated inputs which must be freed. std::vector> to_be_released; // For PjRtStreamExecutorClient implementations that - // use OwningDeviceMemory for donated inputs. - std::vector se_to_be_released; + // use ScopedDeviceAddress for donated inputs. + std::vector> se_to_be_released; }; class PjRtStreamExecutorDevice : public PjRtDevice { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index d5bec6ba286977..2c1b89083b477d 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "xla/pjrt/tracked_device_buffer.h" +#include #include #include #include @@ -90,7 +91,7 @@ absl::StatusOr> MakeArray( client->backend().transfer_manager()->HostShapeToDeviceShape(shape), [&](const Shape& subshape, const ShapeIndex&) -> absl::Status { TF_ASSIGN_OR_RETURN( - se::OwningDeviceMemory device_memory, + se::ScopedDeviceAddress device_memory, client->backend().memory_allocator()->Allocate( /*device_ordinal=*/0, client->backend().transfer_manager()->GetByteSizeRequirement( diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index b5d097d79b4715..e5e8114809599e 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4113,6 +4113,8 @@ cc_library( hdrs = ["maybe_owning_device_memory.h"], deps = [ ":maybe_owning_device_address", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/service/maybe_owning_device_memory.h b/third_party/xla/xla/service/maybe_owning_device_memory.h index 897003ffb17429..40d05599971dcd 100644 --- a/third_party/xla/xla/service/maybe_owning_device_memory.h +++ b/third_party/xla/xla/service/maybe_owning_device_memory.h @@ -18,6 +18,8 @@ limitations under the License. 
#include "absl/base/macros.h" #include "xla/service/maybe_owning_device_address.h" +#include "xla/stream_executor/device_address.h" // IWYU pragma: keep +#include "xla/stream_executor/device_address_allocator.h" // IWYU pragma: keep #include "xla/stream_executor/device_memory.h" // IWYU pragma: keep #include "xla/stream_executor/device_memory_allocator.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 9f617478a6ea7b..4466fb094ab53d 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -184,7 +184,7 @@ cc_library( "//xla/service:hlo_runner_pjrt", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", @@ -451,8 +451,8 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:stream_pool", "//xla/service:transfer_manager", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -522,8 +522,8 @@ xla_test( "//xla/service:hlo_module_config", "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", @@ -2989,7 +2989,7 @@ xla_test( "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/service:collective_ops_utils", - 
"//xla/stream_executor:device_memory", + "//xla/stream_executor:device_address", "//xla/stream_executor:stream", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", @@ -3520,7 +3520,7 @@ xla_test( "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", @@ -3660,7 +3660,7 @@ xla_test( "//xla/service:generic_transfer_manager", "//xla/service:shaped_buffer", "//xla/service:stream_pool", - "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_address_allocator", "//xla/tests:xla_test_backend_predicates", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test_benchmark", diff --git a/third_party/xla/xla/tests/buffer_donation_test.cc b/third_party/xla/xla/tests/buffer_donation_test.cc index 324917cbd57df6..870a7b659bcb27 100644 --- a/third_party/xla/xla/tests/buffer_donation_test.cc +++ b/third_party/xla/xla/tests/buffer_donation_test.cc @@ -45,8 +45,8 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -113,7 +113,7 @@ class BufferDonationTest : public HloTestBase { run_options, backend_->StreamBorrowerWithPriority()); std::vector args; - std::vector> inputs_buffers; + std::vector> inputs_buffers; CHECK_EQ(argument_literals.size(), donate_arguments.size()); @@ -130,7 +130,7 @@ class BufferDonationTest : public HloTestBase { ShapedBuffer shaped_buffer = scoped_shaped_buffer.release(); CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice( stream.get(), argument_literal, shaped_buffer)); - ShapeTree input_buffers = shaped_buffer.buffers(); + ShapeTree input_buffers = shaped_buffer.buffers(); inputs_buffers.push_back(input_buffers); ShapeTree owned_buffers( argument_literal.shape()); @@ -138,7 +138,7 @@ class BufferDonationTest : public HloTestBase { [&](const ShapeIndex& index, MaybeOwningDeviceAddress* device_memory) { if (donate_argument) { - *device_memory = se::OwningDeviceMemory( + *device_memory = se::ScopedDeviceAddress( input_buffers.element(index), executor_->device_ordinal(), &memory_allocator); } else { @@ -162,7 +162,7 @@ class BufferDonationTest : public HloTestBase { } ExecutionOutput output = std::move(output_status).value(); - se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer(); + se::DeviceAddressBase result_root_buffer = output.Result().root_buffer(); LOG(INFO) << "result allocation = " << result_root_buffer.opaque() << " size = " << result_root_buffer.size(); diff --git a/third_party/xla/xla/tests/collective_ops_ffi_test.cc 
b/third_party/xla/xla/tests/collective_ops_ffi_test.cc index f56ef7045eca7b..21d423965efc0e 100644 --- a/third_party/xla/xla/tests/collective_ops_ffi_test.cc +++ b/third_party/xla/xla/tests/collective_ops_ffi_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/service/collective_ops_utils.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/stream.h" #include "xla/tests/collective_ops_e2e_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index 6421e9badcbec7..dce925c25e28d0 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/service/hlo_runner_pjrt.h" #include "xla/service/platform_util.h" #include "xla/shape.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/hlo_runner_agnostic_reference_mixin.h" @@ -174,7 +174,7 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( reference_preprocessor); } -se::DeviceMemoryAllocator* HloTestBase::GetAllocator() { +se::DeviceAddressAllocator* HloTestBase::GetAllocator() { if (allocator_ == nullptr) { allocator_ = std::make_unique( backend().default_stream_executor()); diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index 31efd1fc5ff2bb..c378860ec85a40 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -48,7 +48,7 @@ static_assert(false, #include "xla/service/computation_placer.h" #include "xla/service/hlo_runner.h" #include 
"xla/service/hlo_runner_interface.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_runner_agnostic_reference_mixin.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" @@ -210,7 +210,7 @@ class ABSL_DEPRECATED( static se::Platform* GetTestPlatform(); // Creates or retrieves the allocator. - se::DeviceMemoryAllocator* GetAllocator(); + se::DeviceAddressAllocator* GetAllocator(); ErrorSpec error_spec_{0.0001}; @@ -224,7 +224,7 @@ class ABSL_DEPRECATED( bool allow_mixed_precision_in_hlo_verifier, HloPredicate instruction_can_change_layout_func); - std::unique_ptr allocator_; + std::unique_ptr allocator_; }; } // namespace xla diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index ac4aec28517450..cb0675c889c052 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.cc b/third_party/xla/xla/tests/local_client_test_base.cc index 29563c202f26a2..957b24fc150f8e 100644 --- a/third_party/xla/xla/tests/local_client_test_base.cc +++ b/third_party/xla/xla/tests/local_client_test_base.cc @@ -43,8 +43,8 @@ limitations under the License. 
#include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -57,7 +57,7 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -absl::StatusOr TestAllocator::Allocate( +absl::StatusOr> TestAllocator::Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; @@ -71,7 +71,7 @@ absl::StatusOr TestAllocator::Allocate( } absl::Status TestAllocator::Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) { + se::DeviceAddressBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { absl::MutexLock lock(count_mutex_); diff --git a/third_party/xla/xla/tests/local_client_test_base.h b/third_party/xla/xla/tests/local_client_test_base.h index cb7de54135e8db..3afeae8c003d8c 100644 --- a/third_party/xla/xla/tests/local_client_test_base.h +++ b/third_party/xla/xla/tests/local_client_test_base.h @@ -37,8 +37,8 @@ limitations under the License. 
#include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -53,11 +53,11 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { : se::StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).value()) {} - absl::StatusOr Allocate( + absl::StatusOr> Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) override; absl::Status Deallocate(int device_ordinal, - se::DeviceMemoryBase mem) override; + se::DeviceAddressBase mem) override; // Return the number of allocations that have been performed. int64_t allocation_count() const; diff --git a/third_party/xla/xla/tests/transfer_manager_test.cc b/third_party/xla/xla/tests/transfer_manager_test.cc index 6a4a188afd94fa..66d84eebb73fb7 100644 --- a/third_party/xla/xla/tests/transfer_manager_test.cc +++ b/third_party/xla/xla/tests/transfer_manager_test.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include "xla/service/stream_pool.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/device_address_allocator.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/local_client_test_base.h" #include "xla/tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 5f422444fd55e9..60993b0f7d19ab 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -1007,8 +1007,8 @@ tsl_gpu_library( "//xla/service/cpu:cpu_executable", "//xla/service/gpu:gpu_symbol_repository", "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", From 5b7e6e0d94c58afd3756c2b69e28a4d32e9ba9fa Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Tue, 9 Dec 2025 12:11:08 -0800 Subject: [PATCH 089/753] Reverts 21040b5b652e63d81f77509a5ce8ab27a5b1d16c PiperOrigin-RevId: 842342466 --- tensorflow/compiler/jit/xla_tensor.cc | 2 +- .../runtime/gpublas_lt_matmul_thunk_test.cc | 3 +- third_party/xla/xla/client/BUILD | 4 +- third_party/xla/xla/client/client_library.h | 2 +- third_party/xla/xla/client/local_client.cc | 4 +- third_party/xla/xla/client/local_client.h | 4 +- third_party/xla/xla/core/collectives/BUILD | 2 +- .../xla/xla/core/collectives/communicator.h | 40 ++++++++-------- third_party/xla/xla/ffi/BUILD | 10 ++-- third_party/xla/xla/ffi/api/BUILD | 4 +- third_party/xla/xla/ffi/api/c_api_internal.h | 2 +- third_party/xla/xla/ffi/api/ffi_test.cc | 48 +++++++++---------- third_party/xla/xla/ffi/call_frame.cc | 20 ++++---- third_party/xla/xla/ffi/call_frame.h | 14 +++--- third_party/xla/xla/ffi/call_frame_test.cc | 20 ++++---- third_party/xla/xla/ffi/ffi.h | 12 ++--- third_party/xla/xla/ffi/ffi_api.cc | 4 +- third_party/xla/xla/ffi/ffi_test.cc | 32 ++++++------- .../xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc | 4 +- .../xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc | 23 ++++----- .../xla/pjrt/pjrt_stream_executor_client.cc | 4 +- .../xla/pjrt/pjrt_stream_executor_client.h | 4 +- .../xla/pjrt/tracked_device_buffer_test.cc | 3 +- third_party/xla/xla/service/BUILD | 2 - .../xla/service/maybe_owning_device_memory.h | 2 - third_party/xla/xla/tests/BUILD | 16 +++---- .../xla/xla/tests/buffer_donation_test.cc | 12 ++--- .../xla/xla/tests/collective_ops_ffi_test.cc | 2 +- third_party/xla/xla/tests/hlo_test_base.cc | 4 +- third_party/xla/xla/tests/hlo_test_base.h | 6 +-- .../xla/tests/local_client_execute_test.cc | 2 +- .../xla/xla/tests/local_client_test_base.cc | 8 ++-- .../xla/xla/tests/local_client_test_base.h | 8 ++-- .../xla/xla/tests/transfer_manager_test.cc | 2 +- third_party/xla/xla/tools/BUILD | 2 +- 35 files changed, 158 insertions(+), 173 deletions(-) diff --git 
a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d6792cd7802d96..e9cdad219dd28d 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -55,7 +55,7 @@ absl::Status XlaTensor::AllocateShapedBuffer(DataType dtype, xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(se::ScopedDeviceAddress buffer, + TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false, subshape.layout().memory_space())); diff --git a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc index 77a6ac88f8ff70..ccdf653ca1862e 100644 --- a/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc +++ b/third_party/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h" #include -#include #include #include #include @@ -183,7 +182,7 @@ class GpuBlasLtThunkBuilder { se::StreamExecutorMemoryAllocator allocator_; se::GpuComputeCapability gpu_comp_; std::deque allocs_; - std::vector> mem_buffers_; + std::vector mem_buffers_; }; void GpuBlasLtMatmulThunkTest::CreateExecuteThunksFromHLO( diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index fac2d9343ff1d0..c2801fa3fa8410 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -128,7 +128,7 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:source_map_util", "//xla/service:stream_pool", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -181,7 +181,7 @@ cc_library( "//xla/service:compile_only_service", "//xla/service:local_service", "//xla/service:platform_util", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/tsl/platform:logging", diff --git a/third_party/xla/xla/client/client_library.h b/third_party/xla/xla/client/client_library.h index 42d0f34202e092..0e4f3a9a24dd22 100644 --- a/third_party/xla/xla/client/client_library.h +++ b/third_party/xla/xla/client/client_library.h @@ -36,7 +36,7 @@ limitations under the License. 
#include "xla/client/local_client.h" #include "xla/service/compile_only_service.h" #include "xla/service/local_service.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index e1f348a755521d..cc383a9aa81b34 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -45,7 +45,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/platform/errors.h" @@ -512,7 +512,7 @@ absl::StatusOr> LocalClient::LoadInternal( absl::StatusOr LocalClient::LiteralToShapedBuffer( const LiteralSlice& literal, int device_ordinal, - se::DeviceAddressAllocator* allocator) { + se::DeviceMemoryAllocator* allocator) { if (allocator == nullptr) { allocator = backend().memory_allocator(); } diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 3c237ef37a1973..3ccda5d43f6794 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -39,7 +39,7 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/service/stream_pool.h" #include "xla/shape_tree.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/statusor.h" @@ -183,7 +183,7 @@ class LocalClient : public Client { // device is used. 
absl::StatusOr LiteralToShapedBuffer( const LiteralSlice& literal, int device_ordinal, - se::DeviceAddressAllocator* allocator = nullptr); + se::DeviceMemoryAllocator* allocator = nullptr); // Transfer the BorrowingLiteral to the device with the given ordinal. absl::StatusOr TransferToLocalServer( diff --git a/third_party/xla/xla/core/collectives/BUILD b/third_party/xla/xla/core/collectives/BUILD index 06d3ef7f6c9aed..1b0398aaaf4801 100644 --- a/third_party/xla/xla/core/collectives/BUILD +++ b/third_party/xla/xla/core/collectives/BUILD @@ -73,7 +73,7 @@ cc_library( "//xla:future", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/core/collectives/communicator.h b/third_party/xla/xla/core/collectives/communicator.h index 4be35fb52163f7..0f60a859db854d 100644 --- a/third_party/xla/xla/core/collectives/communicator.h +++ b/third_party/xla/xla/core/collectives/communicator.h @@ -28,7 +28,7 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/core/collectives/reduction_kind.h" #include "xla/future.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -65,7 +65,7 @@ class Communicator { // Register `buffer_range` once for efficient collective operations (i.e. on // NCCL backend it registers the buffer for zero-copy collective operations). 
// - virtual absl::Status RegisterBufferOnce(se::DeviceAddressBase buffer_range, + virtual absl::Status RegisterBufferOnce(se::DeviceMemoryBase buffer_range, int device_ordinal, bool use_symmetric_buffer) { return Unimplemented("User-managed buffer registration is not supported"); @@ -91,40 +91,40 @@ class Communicator { // Reduce buffers of length `count` in `send_buff` using `reduction_kind` // reduction and leaves identical copies of the result on each `recv_buff`. - virtual Future<> AllReduce(stream_executor::DeviceAddressBase send_buffer, - stream_executor::DeviceAddressBase recv_buffer, + virtual Future<> AllReduce(stream_executor::DeviceMemoryBase send_buffer, + stream_executor::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, ReductionKind reduction_kind, const Executor& executor) = 0; // Copy data in `send_buff` from the root device to the `recv_buff` on // all other devices. - virtual Future<> Broadcast(se::DeviceAddressBase send_buffer, - se::DeviceAddressBase recv_buffer, + virtual Future<> Broadcast(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, RankId root, const Executor& executor) = 0; // Reduce data in `send_buff` from all devices using the `reduction_kind` // operation and leave the reduced result scattered over the devices so that // the `recv_buff` on rank `i` will contain the i-th block of the result. - virtual Future<> ReduceScatter(se::DeviceAddressBase send_buffer, - se::DeviceAddressBase recv_buffer, + virtual Future<> ReduceScatter(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, ReductionKind reduction_kind, const Executor& executor) = 0; // Gather `count` values from all devices into `recv_buffer`, receiving data // from rank `i` at offset `i * sendcount`. 
- virtual Future<> AllGather(se::DeviceAddressBase send_buffer, - se::DeviceAddressBase recv_buffer, + virtual Future<> AllGather(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, const Executor& executor) = 0; // Sends data from `send_buffer` to `target_ranks` and receives data from // `source_rank` into `recv_buffer`. If `source_rank` is not specified, the // output is filled with zeros. - virtual Future<> CollectivePermute(se::DeviceAddressBase send_buffer, - se::DeviceAddressBase recv_buffer, + virtual Future<> CollectivePermute(se::DeviceMemoryBase send_buffer, + se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, std::optional source_rank, absl::Span target_ranks, @@ -133,30 +133,30 @@ class Communicator { // Sends `count` values from `send_buffers` to other ranks and receives data // from other ranks into `recv_buffers`. virtual Future<> AllToAll( - absl::InlinedVector send_buffers, - absl::InlinedVector recv_buffers, + absl::InlinedVector send_buffers, + absl::InlinedVector recv_buffers, PrimitiveType dtype, size_t count, const Executor& executor) = 0; // Send data from `send_buff` to rank `peer`. - virtual Future<> Send(se::DeviceAddressBase send_buffer, PrimitiveType dtype, + virtual Future<> Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; // Receive data from rank `peer` into `recv_buff`. - virtual Future<> Recv(se::DeviceAddressBase recv_buffer, PrimitiveType dtype, + virtual Future<> Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; // Send data from `send_buff` to rank `recv_buff` (one-way send). 
- virtual Future<> Send(se::DeviceAddressBase recv_buffer, - se::DeviceAddressBase send_buffer, PrimitiveType dtype, + virtual Future<> Send(se::DeviceMemoryBase recv_buffer, + se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) { return Unimplemented("One-way send is not implemented"); } // Receive data from rank `peer` into `recv_buff` (one-way recv). - virtual Future<> Recv(se::DeviceAddressBase recv_buffer, - se::DeviceAddressBase send_buffer, PrimitiveType dtype, + virtual Future<> Recv(se::DeviceMemoryBase recv_buffer, + se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) { return Unimplemented("One-way recv is not implemented"); } diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index f14764091594bc..41c825e3599ea2 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -37,7 +37,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", - "//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "//xla/tsl/platform:errors", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -58,7 +58,7 @@ xla_cc_test( ":call_frame", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", - "//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:test", "//xla/tsl/platform:test_benchmark", @@ -149,7 +149,7 @@ cc_library( "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -182,8 +182,8 @@ cc_library( "//xla/ffi/api:c_api_internal", "//xla/hlo/ir:hlo", "//xla/service:platform_util", - 
"//xla/stream_executor:device_address", "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", @@ -299,7 +299,7 @@ xla_cc_test( "//xla/backends/cpu:ffi", "//xla/backends/gpu:ffi", "//xla/ffi/api:c_api", - "//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index dc4551d8e2fecc..41889027b9ddd3 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -91,8 +91,8 @@ xla_cc_test( "//xla/ffi:execution_state", "//xla/ffi:ffi_api", "//xla/ffi:type_registry", - "//xla/stream_executor:device_address", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", diff --git a/third_party/xla/xla/ffi/api/c_api_internal.h b/third_party/xla/xla/ffi/api/c_api_internal.h index d9070080f3a4a6..d0baf4fc3b7bb0 100644 --- a/third_party/xla/xla/ffi/api/c_api_internal.h +++ b/third_party/xla/xla/ffi/api/c_api_internal.h @@ -93,7 +93,7 @@ typedef XLA_FFI_Error* XLA_FFI_INTERNAL_IntraOpThreadPool_Get( typedef XLA_FFI_Error* XLA_FFI_INTERNAL_Stream_Get( XLA_FFI_ExecutionContext* ctx, void** stream); -// Returns a pointer to device memory allocator (`se::DeviceAddressAllocator` +// Returns a pointer to device memory allocator (`se::DeviceMemoryAllocator` // pointer) which allows to allocate memory inside a custom call from the same // allocator as XLA (i.e. it allows to construct scratch memory allocator). 
typedef XLA_FFI_Error* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get( diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index 81578f564956fd..e3345ebe915146 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -45,8 +45,8 @@ limitations under the License. #include "xla/ffi/ffi_api.h" #include "xla/ffi/type_registry.h" #include "xla/primitive_util.h" -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -522,7 +522,7 @@ TEST(FfiTest, DeviceOrdinal) { TEST(FfiTest, AnyBufferArgument) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -544,7 +544,7 @@ TEST(FfiTest, AnyBufferArgument) { TEST(FfiTest, BufferArgument) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -562,7 +562,7 @@ TEST(FfiTest, BufferArgument) { TEST(FfiTest, AnyBufferResult) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -594,7 +594,7 @@ TEST(FfiTest, MissingBufferArgument) { TEST(FfiTest, 
WrongRankBufferArgument) { std::vector storage(4, 0.0); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -611,7 +611,7 @@ TEST(FfiTest, WrongRankBufferArgument) { TEST(FfiTest, WrongTypeBufferArgument) { std::vector storage(4, 0.0); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); @@ -648,7 +648,7 @@ TEST(FfiTest, WrongNumberOfArguments) { TEST(FfiTest, TokenArgument) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); - builder.AddBufferArg(se::DeviceAddressBase(), PrimitiveType::TOKEN, + builder.AddBufferArg(se::DeviceMemoryBase(), PrimitiveType::TOKEN, /*dims=*/{}); auto call_frame = builder.Build(); @@ -665,7 +665,7 @@ TEST(FfiTest, TokenArgument) { TEST(FfiTest, RemainingArgs) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -694,7 +694,7 @@ TEST(FfiTest, RemainingArgs) { TEST(FfiTest, RemainingRets) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -724,7 +724,7 @@ TEST(FfiTest, RemainingRets) { TEST(FfiTest, OptionalArgs) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase 
memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -785,7 +785,7 @@ TEST(FfiTest, OptionalArgs) { TEST(FfiTest, OptionalRets) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -854,7 +854,7 @@ TEST(FfiTest, AutoBinding) { }); std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder::AttributesBuilder attrs; attrs.Insert(kI32, 42); @@ -873,8 +873,7 @@ TEST(FfiTest, AutoBindingResult) { Ffi::BindTo(+[](Result buffer) { return Error::Success(); }); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); - builder.AddBufferRet(se::DeviceAddressBase(), PrimitiveType::F32, - /*dims=*/{}); + builder.AddBufferRet(se::DeviceMemoryBase(), PrimitiveType::F32, /*dims=*/{}); auto call_frame = builder.Build(); auto status = Call(*handler, call_frame); @@ -1410,22 +1409,19 @@ TEST(FfiTest, ScratchAllocator) { static void* kAddr = reinterpret_cast(0xDEADBEEF); // A test only memory allocator that returns a fixed memory address. 
- struct TestDeviceMemoryAllocator final : public se::DeviceAddressAllocator { + struct TestDeviceMemoryAllocator final : public se::DeviceMemoryAllocator { size_t count; TestDeviceMemoryAllocator() - : se::DeviceAddressAllocator(nullptr), count(0) {} + : se::DeviceMemoryAllocator(nullptr), count(0) {} - absl::StatusOr> Allocate(int, - uint64_t size, - bool, - int64_t) final { + absl::StatusOr Allocate(int, uint64_t size, bool, + int64_t) final { count++; - return se::ScopedDeviceAddress( - se::DeviceAddressBase(kAddr, size), 0, this); + return se::OwningDeviceMemory(se::DeviceMemoryBase(kAddr, size), 0, this); } - absl::Status Deallocate(int, se::DeviceAddressBase mem) final { + absl::Status Deallocate(int, se::DeviceMemoryBase mem) final { count--; EXPECT_EQ(mem.opaque(), kAddr); return absl::OkStatus(); @@ -1592,7 +1588,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(BufferR2F32Handler, BufferR2F32Function); TEST(FfiTest, DefineAutoSymbol) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -1608,7 +1604,7 @@ TEST(FfiTest, DefineAutoSymbol) { //===----------------------------------------------------------------------===// static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { - se::DeviceAddressBase memory; + se::DeviceMemoryBase memory; std::vector dims(4, 1); CallFrameBuilder builder(/*num_args=*/num_args, /*num_rets=*/0); diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index f0c17215c2dafd..ad7c71c98f8cd6 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/attribute_map.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/platform/errors.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -47,7 +47,7 @@ namespace xla::ffi { //===----------------------------------------------------------------------===// struct CallFrameBuilder::Buffer { - se::DeviceAddressBase memory; + se::DeviceMemoryBase memory; PrimitiveType type; absl::InlinedVector dims; }; @@ -84,7 +84,7 @@ CallFrameBuilder::CallFrameBuilder(size_t num_args, size_t num_rets) { CallFrameBuilder::~CallFrameBuilder() = default; -void CallFrameBuilder::AddBufferArg(se::DeviceAddressBase memory, +void CallFrameBuilder::AddBufferArg(se::DeviceMemoryBase memory, PrimitiveType type, absl::Span dims) { DCHECK(args_.capacity() > args_.size()) @@ -95,10 +95,10 @@ void CallFrameBuilder::AddBufferArg(se::DeviceAddressBase memory, void CallFrameBuilder::AddTokenArg() { DCHECK(args_.capacity() > args_.size()) << "CallFrame builder `num_args` argument was too small"; - args_.push_back(Buffer{se::DeviceAddressBase(), PrimitiveType::TOKEN, {}}); + args_.push_back(Buffer{se::DeviceMemoryBase(), PrimitiveType::TOKEN, {}}); } -void CallFrameBuilder::AddBufferRet(se::DeviceAddressBase memory, +void CallFrameBuilder::AddBufferRet(se::DeviceMemoryBase memory, PrimitiveType type, absl::Span dims) { DCHECK(rets_.capacity() > rets_.size()) @@ -109,7 +109,7 @@ void CallFrameBuilder::AddBufferRet(se::DeviceAddressBase memory, void CallFrameBuilder::AddTokenRet() { DCHECK(rets_.capacity() > rets_.size()) << "CallFrame builder `num_rets` argument was too small"; - rets_.push_back(Buffer{se::DeviceAddressBase(), PrimitiveType::TOKEN, {}}); + rets_.push_back(Buffer{se::DeviceMemoryBase(), PrimitiveType::TOKEN, {}}); } void CallFrameBuilder::AddAttributes(AttributesMap attrs) { @@ -557,8 +557,8 @@ std::unique_ptr 
CallFrame::FixUpAttrs( //===----------------------------------------------------------------------===// absl::Status CallFrame::UpdateWithBuffers( - absl::Span args, - absl::Span rets) { + absl::Span args, + absl::Span rets) { if (ABSL_PREDICT_FALSE(args.size() != arguments_->args.size())) { return InvalidArgument("Invalid number of updated arguments: %d vs %d", args.size(), arguments_->args.size()); @@ -587,8 +587,8 @@ CallFrame CallFrame::Copy() const { } absl::StatusOr CallFrame::CopyWithBuffers( - absl::Span args, - absl::Span rets) const { + absl::Span args, + absl::Span rets) const { CallFrame clone(CopyArgs(*arguments_), CopyRets(*results_), attributes_); TF_RETURN_IF_ERROR(clone.UpdateWithBuffers(args, rets)); return clone; diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h index 5433d4be990d42..32dceead1d9b4b 100644 --- a/third_party/xla/xla/ffi/call_frame.h +++ b/third_party/xla/xla/ffi/call_frame.h @@ -30,7 +30,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" @@ -76,12 +76,12 @@ class CallFrameBuilder { CallFrame Build(); - void AddBufferArg(se::DeviceAddressBase memory, PrimitiveType type, + void AddBufferArg(se::DeviceMemoryBase memory, PrimitiveType type, absl::Span dims); void AddTokenArg(); - void AddBufferRet(se::DeviceAddressBase memory, PrimitiveType type, + void AddBufferRet(se::DeviceMemoryBase memory, PrimitiveType type, absl::Span dims); void AddTokenRet(); @@ -117,16 +117,16 @@ class CallFrame { // array (buffer) arguments and results are known at compile time. Instead of // rebuilding the call frame from scratch on every execution, we can just // update the arguments and results with new pointers to device memory. 
- absl::Status UpdateWithBuffers(absl::Span args, - absl::Span rets); + absl::Status UpdateWithBuffers(absl::Span args, + absl::Span rets); // Creates a copy of the call frame. CallFrame Copy() const; // Creates a copy of the call frame with updated arguments and results. absl::StatusOr CopyWithBuffers( - absl::Span args, - absl::Span rets) const; + absl::Span args, + absl::Span rets) const; // Builds an XLA_FFI_CallFrame from owned arguments and attributes. XLA_FFI_CallFrame Build( diff --git a/third_party/xla/xla/ffi/call_frame_test.cc b/third_party/xla/xla/ffi/call_frame_test.cc index b58e2d9a2537b6..f73461fc7d297f 100644 --- a/third_party/xla/xla/ffi/call_frame_test.cc +++ b/third_party/xla/xla/ffi/call_frame_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/test.h" #include "xla/tsl/platform/test_benchmark.h" @@ -34,8 +34,8 @@ limitations under the License. 
namespace xla::ffi { TEST(CallFrameTest, UpdateCallFrame) { - se::DeviceAddressBase mem0(reinterpret_cast(0x12345678), 1024); - se::DeviceAddressBase mem1(reinterpret_cast(0x87654321), 1024); + se::DeviceMemoryBase mem0(reinterpret_cast(0x12345678), 1024); + se::DeviceMemoryBase mem1(reinterpret_cast(0x87654321), 1024); std::vector dims = {1, 2, 3, 4}; @@ -116,7 +116,7 @@ TEST(CallFrameTest, UpdateCallFrame) { void BM_AddBufferArg(benchmark::State& state) { size_t num_args = state.range(0); - se::DeviceAddressBase memory(reinterpret_cast(0x12345678), 1024); + se::DeviceMemoryBase memory(reinterpret_cast(0x12345678), 1024); std::vector dims = {1, 2, 3, 4}; for (auto _ : state) { @@ -151,17 +151,17 @@ void BM_AddAttributes(benchmark::State& state) { void BM_UpdateCallFrame(benchmark::State& state) { size_t num_args = state.range(0); - se::DeviceAddressBase memory(reinterpret_cast(0x12345678), 1024); + se::DeviceMemoryBase memory(reinterpret_cast(0x12345678), 1024); std::vector dims = {1, 2, 3, 4}; CallFrameBuilder builder(num_args, /*num_rets=*/0); for (size_t i = 0; i < num_args; ++i) { - builder.AddBufferArg(se::DeviceAddressBase(nullptr, 1024), + builder.AddBufferArg(se::DeviceMemoryBase(nullptr, 1024), PrimitiveType::F32, dims); } CallFrame call_frame = builder.Build(); - std::vector updated_args(num_args, memory); + std::vector updated_args(num_args, memory); for (auto _ : state) { auto updated_call_frame = @@ -173,17 +173,17 @@ void BM_UpdateCallFrame(benchmark::State& state) { void BM_UpdateCallFrameInPlace(benchmark::State& state) { size_t num_args = state.range(0); - se::DeviceAddressBase memory(reinterpret_cast(0x12345678), 1024); + se::DeviceMemoryBase memory(reinterpret_cast(0x12345678), 1024); std::vector dims = {1, 2, 3, 4}; CallFrameBuilder builder(num_args, /*num_rets=*/0); for (size_t i = 0; i < num_args; ++i) { - builder.AddBufferArg(se::DeviceAddressBase(nullptr, 1024), + builder.AddBufferArg(se::DeviceMemoryBase(nullptr, 1024), PrimitiveType::F32, 
dims); } CallFrame call_frame = builder.Build(); - std::vector updated_args(num_args, memory); + std::vector updated_args(num_args, memory); for (auto _ : state) { benchmark::DoNotOptimize( diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 4e1849a190d327..da6303e14faef7 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -50,7 +50,7 @@ limitations under the License. #include "xla/ffi/type_registry.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/primitive_util.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/types.h" // IWYU pragma: keep @@ -137,8 +137,8 @@ class AnyBuffer { return reinterpret_cast(buf_->data); } - se::DeviceAddressBase device_memory() const { - return se::DeviceAddressBase(untyped_data(), size_bytes()); + se::DeviceMemoryBase device_memory() const { + return se::DeviceMemoryBase(untyped_data(), size_bytes()); } private: @@ -182,9 +182,9 @@ class Buffer { return reinterpret_cast*>(untyped_data()); } - se::DeviceAddress> device_memory() const { - return se::DeviceAddress>( - se::DeviceAddressBase(untyped_data(), size_bytes())); + se::DeviceMemory> device_memory() const { + return se::DeviceMemory>( + se::DeviceMemoryBase(untyped_data(), size_bytes())); } private: diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 3f0de64033061e..31287ac7587ef4 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -47,8 +47,8 @@ limitations under the License. 
#include "xla/ffi/ffi_structs.h" #include "xla/ffi/type_registry.h" #include "xla/service/platform_util.h" -#include "xla/stream_executor/device_address.h" #include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/tsl/platform/logging.h" @@ -795,7 +795,7 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Free( absl::Status status = gpu->allocator->Deallocate( args->ctx->device_ordinal, - stream_executor::DeviceAddressBase(args->data, args->size)); + stream_executor::DeviceMemoryBase(args->data, args->size)); if (!status.ok()) { return new XLA_FFI_Error{std::move(status)}; } diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 0369c8cc1946e5..8f0b00244c0a93 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -43,7 +43,7 @@ limitations under the License. #include "xla/ffi/execution_state.h" #include "xla/ffi/ffi_api.h" #include "xla/ffi/type_registry.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" @@ -179,7 +179,7 @@ TEST(FfiTest, CatchExceptionExplicit) { TEST(FfiTest, WrongNumArgs) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); - builder.AddBufferArg(se::DeviceAddressBase(nullptr), PrimitiveType::F32, {}); + builder.AddBufferArg(se::DeviceMemoryBase(nullptr), PrimitiveType::F32, {}); auto call_frame = builder.Build(); auto handler = Ffi::Bind().Arg().Arg().To( @@ -579,7 +579,7 @@ TEST(FfiTest, DecodingErrors) { TEST(FfiTest, AnyBufferArgument) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, 
/*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -614,7 +614,7 @@ TEST(FfiTest, AnyBufferArgument) { TEST(FfiTest, TypedAndRankedBufferArgument) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), storage.size() * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), storage.size() * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -642,8 +642,8 @@ TEST(FfiTest, TypedAndRankedBufferArgument) { TEST(FfiTest, ComplexBufferArgument) { std::vector> storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), - storage.size() * sizeof(std::complex)); + se::DeviceMemoryBase memory(storage.data(), + storage.size() * sizeof(std::complex)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::C64, /*dims=*/{2, 2}); @@ -662,7 +662,7 @@ TEST(FfiTest, ComplexBufferArgument) { TEST(FfiTest, TokenArgument) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); - builder.AddBufferArg(se::DeviceAddressBase(), PrimitiveType::TOKEN, + builder.AddBufferArg(se::DeviceMemoryBase(), PrimitiveType::TOKEN, /*dims=*/{}); auto call_frame = builder.Build(); @@ -679,7 +679,7 @@ TEST(FfiTest, TokenArgument) { TEST(FfiTest, WrongRankBufferArgument) { std::vector storage(4, 0.0); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -697,7 +697,7 @@ TEST(FfiTest, WrongRankBufferArgument) { TEST(FfiTest, WrongTypeBufferArgument) { std::vector storage(4, 0.0); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(int32_t)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(int32_t)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); 
builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); @@ -715,7 +715,7 @@ TEST(FfiTest, WrongTypeBufferArgument) { TEST(FfiTest, RemainingArgs) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -743,7 +743,7 @@ TEST(FfiTest, RemainingArgs) { TEST(FfiTest, RemainingRets) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -772,7 +772,7 @@ TEST(FfiTest, RemainingRets) { TEST(FfiTest, OptionalArgs) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -833,7 +833,7 @@ TEST(FfiTest, OptionalArgs) { TEST(FfiTest, OptionalRets) { std::vector storage(4, 0.0f); - se::DeviceAddressBase memory(storage.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); @@ -975,8 +975,8 @@ TEST(FfiTest, UpdateBufferArgumentsAndResults) { std::vector storage0(4, 0.0f); std::vector storage1(4, 0.0f); - se::DeviceAddressBase memory0(storage0.data(), 4 * sizeof(float)); - se::DeviceAddressBase memory1(storage1.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory0(storage0.data(), 4 * sizeof(float)); + se::DeviceMemoryBase memory1(storage1.data(), 4 * sizeof(float)); std::vector dims = {2, 2}; @@ -1169,7 
+1169,7 @@ TEST(FfiTest, PlatformStream) { //===----------------------------------------------------------------------===// static CallFrameBuilder WithBufferArgs(size_t num_args, size_t rank = 4) { - se::DeviceAddressBase memory; + se::DeviceMemoryBase memory; std::vector dims(4, 1); CallFrameBuilder builder(/*num_args=*/num_args, /*num_rets=*/0); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index 205fc66b41fcc8..e210a480bc74dd 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -1970,7 +1970,7 @@ StreamExecutorGpuClient::RunAsync( const int64_t buffer_size = allocation.size(); if (buffer_size > 0) { TF_ASSIGN_OR_RETURN( - se::ScopedDeviceAddress owning_buffer, + se::OwningDeviceMemory owning_buffer, memory_allocator->Allocate(device_ordinal, buffer_size, /*retry_on_failure=*/true, /*memory_space=*/allocation.color())); @@ -2035,7 +2035,7 @@ StreamExecutorGpuClient::RunAsync( "buffer is not donated; allocating a fresh buffer"; int64_t allocation_size = ShapeUtil::ByteSizeOf( ShapeUtil::GetSubshape(gpu_exec->result_shape(), index)); - absl::StatusOr> allocated_buffer = + absl::StatusOr allocated_buffer = memory_allocator->Allocate(device_ordinal, allocation_size, /*retry_on_failure=*/true, /*memory_space=*/allocation->color()); diff --git a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc index 88fce7477ce884..5e84506057c524 100644 --- a/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc +++ b/third_party/xla/xla/pjrt/gpu/tfrt/tfrt_gpu_executable.cc @@ -775,20 +775,16 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( tuple_buffer.buffers().mutable_element({}); VLOG(3) << "untuple: output_buffers[" << i << "].emplace: " << elem->opaque(); - output_buffers[i].emplace( - stream_executor::ScopedDeviceAddress( - *elem, device->local_device_id().value(), - 
client->allocator())); + output_buffers[i].emplace(stream_executor::OwningDeviceMemory( + *elem, device->local_device_id().value(), client->allocator())); *elem = se::DeviceAddressBase(); } } else { CHECK_EQ(output_buffers.size(), 1); auto* elem = output.buffers().mutable_element({}); VLOG(3) << "output_buffers[0].emplace: " << elem->opaque(); - output_buffers.front().emplace( - stream_executor::ScopedDeviceAddress( - *elem, device->local_device_id().value(), - client->allocator())); + output_buffers.front().emplace(stream_executor::OwningDeviceMemory( + *elem, device->local_device_id().value(), client->allocator())); *elem = se::DeviceAddressBase(); } @@ -913,11 +909,10 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( << "]: " << tracked_buffers[i]->buffer()->buffer().opaque(); if (buffer_is_donated[i]) { input.SetUnownedBuffer( - {i}, - MaybeOwningDeviceAddress(se::ScopedDeviceAddress( - tracked_buffers[i]->buffer()->buffer(), - device->local_hardware_id().value(), - client->allocator()))); + {i}, MaybeOwningDeviceAddress(se::OwningDeviceMemory( + tracked_buffers[i]->buffer()->buffer(), + device->local_hardware_id().value(), + client->allocator()))); } else { input.SetBuffer({i}, MaybeOwningDeviceAddress( tracked_buffers[i]->buffer()->buffer())); @@ -933,7 +928,7 @@ absl::StatusOr TfrtGpuExecutable::ExecuteHelper( ExecutionInput& input = inputs.back(); if (buffer_is_donated[i]) { input.SetUnownedBuffer( - {}, MaybeOwningDeviceAddress(se::ScopedDeviceAddress( + {}, MaybeOwningDeviceAddress(se::OwningDeviceMemory( tracked_buffers[i]->buffer()->buffer(), device->local_hardware_id().value(), client->allocator()))); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index d11f6e966f5ec2..e342a586863001 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -1156,7 +1156,7 @@ 
MakeTupleHelper(PjRtStreamExecutorClient* client, se::Stream* stream = local_device->host_to_device_stream(); TF_ASSIGN_OR_RETURN( - se::ScopedDeviceAddress owned_root_table_memory, + se::OwningDeviceMemory owned_root_table_memory, allocator->Allocate( device_ordinal, transfer_manager->GetByteSizeRequirement(tupled_parameter_shape))); @@ -1673,7 +1673,7 @@ PjRtStreamExecutorClient::RunAsync( auto it = tmp.MutableBuffers()->begin(); for (auto& v : input) { if (v.second.is_donated) { - it->second = MaybeOwningDeviceAddress(se::ScopedDeviceAddress( + it->second = MaybeOwningDeviceAddress(se::OwningDeviceMemory( v.second.buf->mem(), device->local_device_id().value(), run_options.allocator())); tmp.SetUnownedIndex(it->first); diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 4220db893cb1dc..4b656c48fc2517 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -91,8 +91,8 @@ struct PjRtStreamExecutorExecutionOutput { // Donated inputs which must be freed. std::vector> to_be_released; // For PjRtStreamExecutorClient implementations that - // use ScopedDeviceAddress for donated inputs. - std::vector> se_to_be_released; + // use OwningDeviceMemory for donated inputs. + std::vector se_to_be_released; }; class PjRtStreamExecutorDevice : public PjRtDevice { diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc index 2c1b89083b477d..d5bec6ba286977 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer_test.cc @@ -15,7 +15,6 @@ limitations under the License. 
#include "xla/pjrt/tracked_device_buffer.h" -#include #include #include #include @@ -91,7 +90,7 @@ absl::StatusOr> MakeArray( client->backend().transfer_manager()->HostShapeToDeviceShape(shape), [&](const Shape& subshape, const ShapeIndex&) -> absl::Status { TF_ASSIGN_OR_RETURN( - se::ScopedDeviceAddress device_memory, + se::OwningDeviceMemory device_memory, client->backend().memory_allocator()->Allocate( /*device_ordinal=*/0, client->backend().transfer_manager()->GetByteSizeRequirement( diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index e5e8114809599e..b5d097d79b4715 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -4113,8 +4113,6 @@ cc_library( hdrs = ["maybe_owning_device_memory.h"], deps = [ ":maybe_owning_device_address", - "//xla/stream_executor:device_address", - "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/service/maybe_owning_device_memory.h b/third_party/xla/xla/service/maybe_owning_device_memory.h index 40d05599971dcd..897003ffb17429 100644 --- a/third_party/xla/xla/service/maybe_owning_device_memory.h +++ b/third_party/xla/xla/service/maybe_owning_device_memory.h @@ -18,8 +18,6 @@ limitations under the License. 
#include "absl/base/macros.h" #include "xla/service/maybe_owning_device_address.h" -#include "xla/stream_executor/device_address.h" // IWYU pragma: keep -#include "xla/stream_executor/device_address_allocator.h" // IWYU pragma: keep #include "xla/stream_executor/device_memory.h" // IWYU pragma: keep #include "xla/stream_executor/device_memory_allocator.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 4466fb094ab53d..9f617478a6ea7b 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -184,7 +184,7 @@ cc_library( "//xla/service:hlo_runner_pjrt", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", @@ -451,8 +451,8 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:stream_pool", "//xla/service:transfer_manager", - "//xla/stream_executor:device_address", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", @@ -522,8 +522,8 @@ xla_test( "//xla/service:hlo_module_config", "//xla/service:maybe_owning_device_address", "//xla/service:shaped_buffer", - "//xla/stream_executor:device_address", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", @@ -2989,7 +2989,7 @@ xla_test( "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", "//xla/service:collective_ops_utils", - 
"//xla/stream_executor:device_address", + "//xla/stream_executor:device_memory", "//xla/stream_executor:stream", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:errors", @@ -3520,7 +3520,7 @@ xla_test( "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", @@ -3660,7 +3660,7 @@ xla_test( "//xla/service:generic_transfer_manager", "//xla/service:shaped_buffer", "//xla/service:stream_pool", - "//xla/stream_executor:device_address_allocator", + "//xla/stream_executor:device_memory_allocator", "//xla/tests:xla_test_backend_predicates", "//xla/tsl/platform:statusor", "//xla/tsl/platform:test_benchmark", diff --git a/third_party/xla/xla/tests/buffer_donation_test.cc b/third_party/xla/xla/tests/buffer_donation_test.cc index 870a7b659bcb27..324917cbd57df6 100644 --- a/third_party/xla/xla/tests/buffer_donation_test.cc +++ b/third_party/xla/xla/tests/buffer_donation_test.cc @@ -45,8 +45,8 @@ limitations under the License. 
#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -113,7 +113,7 @@ class BufferDonationTest : public HloTestBase { run_options, backend_->StreamBorrowerWithPriority()); std::vector args; - std::vector> inputs_buffers; + std::vector> inputs_buffers; CHECK_EQ(argument_literals.size(), donate_arguments.size()); @@ -130,7 +130,7 @@ class BufferDonationTest : public HloTestBase { ShapedBuffer shaped_buffer = scoped_shaped_buffer.release(); CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice( stream.get(), argument_literal, shaped_buffer)); - ShapeTree input_buffers = shaped_buffer.buffers(); + ShapeTree input_buffers = shaped_buffer.buffers(); inputs_buffers.push_back(input_buffers); ShapeTree owned_buffers( argument_literal.shape()); @@ -138,7 +138,7 @@ class BufferDonationTest : public HloTestBase { [&](const ShapeIndex& index, MaybeOwningDeviceAddress* device_memory) { if (donate_argument) { - *device_memory = se::ScopedDeviceAddress( + *device_memory = se::OwningDeviceMemory( input_buffers.element(index), executor_->device_ordinal(), &memory_allocator); } else { @@ -162,7 +162,7 @@ class BufferDonationTest : public HloTestBase { } ExecutionOutput output = std::move(output_status).value(); - se::DeviceAddressBase result_root_buffer = output.Result().root_buffer(); + se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer(); LOG(INFO) << "result allocation = " << result_root_buffer.opaque() << " size = " << result_root_buffer.size(); diff --git a/third_party/xla/xla/tests/collective_ops_ffi_test.cc 
b/third_party/xla/xla/tests/collective_ops_ffi_test.cc index 21d423965efc0e..f56ef7045eca7b 100644 --- a/third_party/xla/xla/tests/collective_ops_ffi_test.cc +++ b/third_party/xla/xla/tests/collective_ops_ffi_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/service/collective_ops_utils.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_address.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" #include "xla/tests/collective_ops_e2e_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index dce925c25e28d0..6421e9badcbec7 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/service/hlo_runner_pjrt.h" #include "xla/service/platform_util.h" #include "xla/shape.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/hlo_runner_agnostic_reference_mixin.h" @@ -174,7 +174,7 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( reference_preprocessor); } -se::DeviceAddressAllocator* HloTestBase::GetAllocator() { +se::DeviceMemoryAllocator* HloTestBase::GetAllocator() { if (allocator_ == nullptr) { allocator_ = std::make_unique( backend().default_stream_executor()); diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index c378860ec85a40..31efd1fc5ff2bb 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -48,7 +48,7 @@ static_assert(false, #include "xla/service/computation_placer.h" #include "xla/service/hlo_runner.h" #include 
"xla/service/hlo_runner_interface.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_runner_agnostic_reference_mixin.h" #include "xla/tests/hlo_runner_agnostic_test_base.h" @@ -210,7 +210,7 @@ class ABSL_DEPRECATED( static se::Platform* GetTestPlatform(); // Creates or retrieves the allocator. - se::DeviceAddressAllocator* GetAllocator(); + se::DeviceMemoryAllocator* GetAllocator(); ErrorSpec error_spec_{0.0001}; @@ -224,7 +224,7 @@ class ABSL_DEPRECATED( bool allow_mixed_precision_in_hlo_verifier, HloPredicate instruction_can_change_layout_func); - std::unique_ptr allocator_; + std::unique_ptr allocator_; }; } // namespace xla diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index cb0675c889c052..ac4aec28517450 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.cc b/third_party/xla/xla/tests/local_client_test_base.cc index 957b24fc150f8e..29563c202f26a2 100644 --- a/third_party/xla/xla/tests/local_client_test_base.cc +++ b/third_party/xla/xla/tests/local_client_test_base.cc @@ -43,8 +43,8 @@ limitations under the License. 
#include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -57,7 +57,7 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -absl::StatusOr> TestAllocator::Allocate( +absl::StatusOr TestAllocator::Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; @@ -71,7 +71,7 @@ absl::StatusOr> TestAllocator::Allocate( } absl::Status TestAllocator::Deallocate(int device_ordinal, - se::DeviceAddressBase mem) { + se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { absl::MutexLock lock(count_mutex_); diff --git a/third_party/xla/xla/tests/local_client_test_base.h b/third_party/xla/xla/tests/local_client_test_base.h index 3afeae8c003d8c..cb7de54135e8db 100644 --- a/third_party/xla/xla/tests/local_client_test_base.h +++ b/third_party/xla/xla/tests/local_client_test_base.h @@ -37,8 +37,8 @@ limitations under the License. 
#include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" #include "xla/service/transfer_manager.h" -#include "xla/stream_executor/device_address.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" @@ -53,11 +53,11 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { : se::StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).value()) {} - absl::StatusOr> Allocate( + absl::StatusOr Allocate( int device_ordinal, uint64_t size, bool retry_on_failure, int64_t memory_space) override; absl::Status Deallocate(int device_ordinal, - se::DeviceAddressBase mem) override; + se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64_t allocation_count() const; diff --git a/third_party/xla/xla/tests/transfer_manager_test.cc b/third_party/xla/xla/tests/transfer_manager_test.cc index 66d84eebb73fb7..6a4a188afd94fa 100644 --- a/third_party/xla/xla/tests/transfer_manager_test.cc +++ b/third_party/xla/xla/tests/transfer_manager_test.cc @@ -34,7 +34,7 @@ limitations under the License. 
#include "xla/service/stream_pool.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/device_address_allocator.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/local_client_test_base.h" #include "xla/tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 60993b0f7d19ab..5f422444fd55e9 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -1007,8 +1007,8 @@ tsl_gpu_library( "//xla/service/cpu:cpu_executable", "//xla/service/gpu:gpu_symbol_repository", "//xla/service/gpu/autotuning:autotuner_util", - "//xla/stream_executor:device_address_allocator", "//xla/stream_executor:device_description_proto_cc", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", From 27c6b3b944f7e5e9847d9597d8d6ec51e25c7472 Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Tue, 9 Dec 2025 12:33:17 -0800 Subject: [PATCH 090/753] Integrate LLVM at llvm/llvm-project@c6e23ab80753 Updates LLVM usage to match [c6e23ab80753](https://github.com/llvm/llvm-project/commit/c6e23ab80753) PiperOrigin-RevId: 842350626 --- .../xla/third_party/llvm/workspace.bzl | 4 +- .../xla/third_party/shardy/temporary.patch | 1152 +---------------- .../xla/third_party/shardy/workspace.bzl | 4 +- 3 files changed, 9 insertions(+), 1151 deletions(-) diff --git a/third_party/xla/third_party/llvm/workspace.bzl b/third_party/xla/third_party/llvm/workspace.bzl index 5e3d8f2100a1be..dd3d4e4de4509d 100644 --- a/third_party/xla/third_party/llvm/workspace.bzl +++ b/third_party/xla/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "8dee997a8558b460b82b23fb43b197d68258baac" - LLVM_SHA256 = 
"6a26975000c2cb45787813317bfeeadeafa0cba762e9434fb7940481ec4b27de" + LLVM_COMMIT = "c6e23ab80753a01dce270f5f8a133fbec942315d" + LLVM_SHA256 = "5a6b8aacd2d87ce9c4456843a76d0a54fd7cd0ae788ed3f19e7487ecd2ce4326" tf_http_archive( name = name, diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index 1f51d21f432dd8..13d339429b0101 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,1157 +1,15 @@ -diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index f04aa96..509398d 100644 ---- a/third_party/llvm/generated.patch -+++ b/third_party/llvm/generated.patch -@@ -1,1137 +1 @@ - Auto generated patch. Do not edit or delete it, even if empty. --diff -ruN --strip-trailing-cr a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst ----- a/clang/docs/LanguageExtensions.rst --+++ b/clang/docs/LanguageExtensions.rst --@@ -1833,23 +1833,6 @@ -- -- Clang provides a few builtin aliases to improve the throughput of certain metaprogramming facilities. -- ---__builtin_common_reference ----------------------------- --- ---.. code-block:: c++ --- --- template