diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index b7e61115e37d6..748608005e622 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -125,10 +125,10 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks) + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 + GCC_VERSION=11 VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} @@ -146,16 +146,6 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) - CUDA_VERSION=12.8.1 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - ;; pytorch-linux-jammy-py3-clang12-onnx) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index b03606f6defc1..b660ae2dd3eaa 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -1b0418a9a454b2b93ab8d71f40e59d2297157fae +aa01f5c2cd4db2b7bfa53ea98a1a8dfbd6d77c92 diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index 22b7af890c1f6..a29de2cecb870 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -64,14 +64,13 @@ function install_ubuntu() { function install_rhel() { . /etc/os-release - if [[ "${ID}" == "rhel" ]]; then - if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then - echo "RHEL version ${VERSION_ID} not supported" - exit - fi - elif [[ "${ID}" == "almalinux" ]]; then - # Workaround for almalinux8 which used by quay.io/pypa/manylinux_2_28_x86_64 - VERSION_ID="8.8" + if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then + echo "RHEL version ${VERSION_ID} not supported" + exit + fi + # Using testing channel for CD build + if [[ "${ID}" == "almalinux" ]]; then + XPU_DRIVER_VERSION="/testing" fi dnf install -y 'dnf-command(config-manager)' diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index bdc34b4864cd7..044e1d09b54f0 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -397,3 +397,6 @@ scikit-build==0.18.1 pyre-extensions==0.0.32 tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI + +Jinja2==3.1.6 +#Description: required for torch.distributed.debug diff --git a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py index 224f078788702..aea27ca7dddae 100644 --- a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py +++ b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py @@ -84,7 +84,6 @@ def __init__(self, args: Any): self.VLLM_TEST_WHLS_REGEX = [ "xformers/*.whl", "vllm/vllm*.whl", - "flashinfer-python/flashinfer*.whl", ] def prepare(self): diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 01075259e9fe9..7e25c8c6d199c 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1763,12 +1763,14 @@ test_operator_microbenchmark() { mkdir -p "$TEST_REPORTS_DIR" TEST_DIR=$(pwd) + test_inductor_set_cpu_affinity + cd benchmarks/operator_benchmark/pt_extension - python -m pip install . + python -m pip install . 
-v --no-build-isolation cd "${TEST_DIR}"/benchmarks/operator_benchmark - for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do + for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv optimizer; do $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \ --benchmark-name "PyTorch operator microbenchmark" --use-compile diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index c24a50b8b17ed..3771ecc108f87 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -31,23 +31,6 @@ if [[ "$PACKAGE_TYPE" != libtorch ]]; then export PATH="\${python_path}/bin:\$PATH" fi -EXTRA_CONDA_FLAGS="" -NUMPY_PIN="" -PROTOBUF_PACKAGE="defaults::protobuf" - -if [[ "\$python_nodot" = *310* ]]; then - # There's an issue with conda channel priority where it'll randomly pick 1.19 over 1.20 - # we set a lower boundary here just to be safe - NUMPY_PIN=">=1.21.2" - PROTOBUF_PACKAGE="protobuf>=3.19.0" -fi - -if [[ "\$python_nodot" = *39* ]]; then - # There's an issue with conda channel priority where it'll randomly pick 1.19 over 1.20 - # we set a lower boundary here just to be safe - NUMPY_PIN=">=1.20" -fi - # Move debug wheels out of the package dir so they don't get installed mkdir -p /tmp/debug_final_pkgs mv /final_pkgs/debug-*.zip /tmp/debug_final_pkgs || echo "no debug packages to move" @@ -66,12 +49,23 @@ fi if [[ "$PACKAGE_TYPE" != libtorch ]]; then if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}" - retry pip install -q numpy protobuf typing-extensions + + # numpy tests: + # We test 1 version no numpy. 
1 version with numpy 1.x and rest with numpy 2.x + if [[ "\$python_nodot" = *311* ]]; then + retry pip install -q numpy==1.23.5 protobuf typing-extensions + elif [[ "\$python_nodot" = *312* ]]; then + retry pip install -q protobuf typing-extensions + else + retry pip install -q numpy protobuf typing-extensions + fi + else pip install "\$pkg" retry pip install -q numpy protobuf typing-extensions fi fi + if [[ "$PACKAGE_TYPE" == libtorch ]]; then pkg="\$(ls /final_pkgs/*-latest.zip)" unzip "\$pkg" -d /tmp diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 616dfd88ce812..b65b6a7f117ef 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -ee1a1350eb37804b94334768f328144f058f14e9 +32ce8c011855adb15438ddc9bf6c139d23f8cee5 diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 64ee992f566b7..c3b209c216014 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc +617079d944b0e72632311c30ae2bbdf1168b901e diff --git a/.github/ci_configs/vllm/Dockerfile b/.github/ci_configs/vllm/Dockerfile index a57793151de66..13fdb036abfe7 100644 --- a/.github/ci_configs/vllm/Dockerfile +++ b/.github/ci_configs/vllm/Dockerfile @@ -1,4 +1,4 @@ -ARG CUDA_VERSION=12.8.1 +ARG CUDA_VERSION=12.9.1 ARG PYTHON_VERSION=3.12 # BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine, @@ -124,7 +124,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' git clone https://github.com/facebookresearch/xformers.git pushd xformers - git checkout v0.0.32.post2 + git checkout v0.0.33.post1 git submodule update --init --recursive python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose popd @@ -256,7 +256,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" # Use copy mode to avoid hardlink failures with Docker cache mounts ENV UV_LINK_MODE=copy -# Install build and runtime dependencies, this is needed for flashinfer install +# Install build and runtime dependencies COPY requirements/build.txt requirements/build.txt COPY use_existing_torch.py use_existing_torch.py RUN python3 use_existing_torch.py @@ -294,33 +294,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system /wheels/xformers/*.whl --verbose -# Build FlashInfer from source -ARG torch_cuda_arch_list='8.0;8.9;9.0a;10.0a;12.0' -ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} - -# TODO(elainewy): remove this once vllm commit is updated, and install flashinfer from pip -# see https://github.com/pytorch/pytorch/pull/165274#issuecomment-3408531784 -ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -ARG FLASHINFER_GIT_REF="v0.2.14.post1" - -RUN --mount=type=cache,target=/root/.cache/uv \ - git clone --depth 1 --recursive --shallow-submodules \ - --branch ${FLASHINFER_GIT_REF} \ - ${FLASHINFER_GIT_REPO} flashinfer \ - && echo "Building FlashInfer with AOT for arches: ${torch_cuda_arch_list}" \ - && cd flashinfer \ - && python3 -m flashinfer.aot \ - && python3 -m build --no-isolation --wheel --outdir ../wheels/flashinfer \ - && cd .. 
\ - && rm -rf flashinfer - -# Install FlashInfer -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system wheels/flashinfer/*.whl --verbose - # Logging to confirm the torch versions -RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer' -RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm\|^flashinfer' > build_summary.txt +RUN pip freeze | grep -E 'torch|xformers|vllm' +RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio\|^xformers\|^vllm' > build_summary.txt ################### VLLM INSTALLED IMAGE #################### @@ -331,4 +307,3 @@ FROM scratch as export-wheels COPY --from=base /workspace/xformers-dist /wheels/xformers COPY --from=build /workspace/vllm-dist /wheels/vllm COPY --from=vllm-base /workspace/build_summary.txt /wheels/build_summary.txt -COPY --from=vllm-base /workspace/wheels/flashinfer /wheels/flashinfer-python diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index f7df4335cb5b6..d69db191b9464 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -50,6 +50,7 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { "12.6": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | " @@ -67,6 +68,7 @@ "nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux'" ), "12.8": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | " @@ -84,6 +86,7 @@ "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" ), "12.9": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " @@ -101,6 +104,7 @@ "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" ), "13.0": ( + "cuda-bindings==13.0.3; platform_system == 'Linux' | " "nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | " "nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | " "nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | " diff --git a/.github/scripts/prepare_vllm_wheels.sh b/.github/scripts/prepare_vllm_wheels.sh index 62362c7ff207c..0d56a4ef43273 100755 --- a/.github/scripts/prepare_vllm_wheels.sh +++ b/.github/scripts/prepare_vllm_wheels.sh @@ -88,7 +88,7 @@ repackage_wheel() { ${PYTHON_EXECUTABLE} -mpip install wheel==0.45.1 pushd externals/vllm/wheels -for package in xformers flashinfer-python vllm; do +for package in xformers vllm; do repackage_wheel $package done popd diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index b52ec158dd6d6..2434a595f5420 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -327,6 +327,7 @@ jobs: SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }} SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} + DOCKER_IMAGE_S390X: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} 
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} @@ -360,10 +361,12 @@ jobs: # if for some reason cleanup action doesn't stop container # when job is cancelled DOCKER_SHELL_CMD="sleep 12h" + USED_IMAGE="${DOCKER_IMAGE_S390X}" else SHM_OPTS="--shm-size=${SHM_SIZE}" JENKINS_USER="--user jenkins" DOCKER_SHELL_CMD= + USED_IMAGE="${DOCKER_IMAGE}" fi # detached container should get cleaned up by teardown_ec2_linux @@ -426,7 +429,7 @@ jobs: ${JENKINS_USER} \ -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ -w /var/lib/jenkins/workspace \ - "${DOCKER_IMAGE}" \ + "${USED_IMAGE}" \ ${DOCKER_SHELL_CMD} ) echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}" diff --git a/.github/workflows/attention_op_microbenchmark.yml b/.github/workflows/attention_op_microbenchmark.yml index e01bc49621dcf..eec4d21fe2616 100644 --- a/.github/workflows/attention_op_microbenchmark.yml +++ b/.github/workflows/attention_op_microbenchmark.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '8.0 9.0' test-matrix: | @@ -39,7 +39,7 @@ jobs: needs: attn-microbenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }} secrets: inherit @@ -51,7 +51,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' test-matrix: | @@ -66,7 +66,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/b200-distributed.yml b/.github/workflows/b200-distributed.yml index 596a31431e61b..bb85a4ddfc85e 100644 --- a/.github/workflows/b200-distributed.yml +++ b/.github/workflows/b200-distributed.yml @@ -37,7 +37,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' diff --git a/.github/workflows/b200-symm-mem.yml b/.github/workflows/b200-symm-mem.yml index 7fa8a8a730447..ba28066dd5602 100644 --- a/.github/workflows/b200-symm-mem.yml +++ b/.github/workflows/b200-symm-mem.yml @@ -37,7 +37,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm 
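Note on the `_linux-test.yml` hunk above: the added `DOCKER_IMAGE_S390X` env var and `USED_IMAGE` variable switch s390x runners onto the image passed in via `inputs.docker-image` instead of the calculated ECR image, while other runners keep the shm-size and jenkins-user overrides. A minimal Python sketch of that selection, for illustration only (the helper name and dict layout are assumptions, and the branch is assumed to key off an s390x build environment as the surrounding context suggests):

def select_container_options(build_environment: str,
                             docker_image: str,
                             docker_image_s390x: str,
                             shm_size: str = "2g") -> dict:
    # Mirrors the bash branch shown above: s390x runners drop the shm/user
    # overrides, keep the container alive with "sleep 12h" in case the
    # cleanup action does not stop it, and use the directly supplied image.
    if "s390x" in build_environment:
        return {
            "shm_opts": "",
            "jenkins_user": "",
            "shell_cmd": "sleep 12h",
            "image": docker_image_s390x,
        }
    return {
        "shm_opts": f"--shm-size={shm_size}",
        "jenkins_user": "--user jenkins",
        "shell_cmd": "",
        "image": docker_image,
    }

assert select_container_options("linux-s390x-binary", "ecr:img", "input:img")["image"] == "input:img"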
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 408a8f0000504..fa1f083800fe0 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -52,8 +52,7 @@ jobs: pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.10-clang12, pytorch-linux-jammy-py3.11-clang12, @@ -75,7 +74,8 @@ jobs: pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-linter, pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, - pytorch-linux-jammy-py3-clang12-executorch, + # TODO: Re-enable me when docker pin update happens + # pytorch-linux-jammy-py3-clang12-executorch, pytorch-linux-jammy-py3.12-triton-cpu, pytorch-linux-noble-riscv64-py3.12-gcc14 ] diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index 78d38de3ac69a..380b8c2d1e257 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -7,9 +7,18 @@ on: types: - completed workflow_dispatch: + inputs: + branch: + type: string + description: Branch corresponding to the docker images being cached + required: true + run_id: + type: string + description: Workflow run id to pull artifacts from + required: true concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }} + group: ${{ github.workflow }}-${{ github.event.workflow_run.head_branch || github.event.inputs.branch }} cancel-in-progress: true permissions: @@ -30,7 +39,7 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v4.1.7 with: - run-id: ${{ github.event.workflow_run.id }} + run-id: ${{ github.event.workflow_run.id || github.event.inputs.run_id }} path: ./docker-builds-artifacts merge-multiple: true github-token: ${{ secrets.GITHUB_TOKEN }} @@ -51,8 +60,8 @@ jobs: runner: [linux.rocm.gfx942.docker-cache] docker-image: [ "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}", - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" + "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}" + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" ] runs-on: "${{ matrix.runner }}" steps: @@ -91,7 +100,7 @@ jobs: docker_image_tag=${{ matrix.docker-image }} docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":" docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-" - ref_name=${{ github.event.workflow_run.head_branch }} + ref_name=${{ github.event.workflow_run.head_branch || github.event.inputs.branch }} if [[ $ref_name =~ "release/" ]]; then ref_suffix="release" elif [[ $ref_name == "main" ]]; then diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index b8a6403faffbd..6a22e14af09b7 100644 --- 
a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -132,7 +132,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -178,7 +178,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | 
nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -224,7 +224,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; 
platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -381,7 +381,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | 
nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -427,7 +427,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -473,7 +473,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; 
platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -519,7 +519,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -630,7 +630,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; 
platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -676,7 +676,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 
'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -722,7 +722,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -768,7 +768,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; 
platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -879,7 +879,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -925,7 +925,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | 
nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -971,7 +971,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' 
timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1017,7 +1017,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1128,7 +1128,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | 
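For reference on the PYTORCH_EXTRA_INSTALL_REQUIREMENTS values above (generated from generate_binary_build_matrix.py): each value is a single " | "-separated string of pip requirement specifiers with environment markers, and this change prepends a cuda-bindings pin to every CUDA variant. A small illustrative Python sketch of how such a string breaks down (the helper is hypothetical, not part of the scripts):

def split_extra_install_requirements(value: str) -> list[str]:
    # Each entry is a standard pip specifier plus an environment marker,
    # e.g. "cuda-bindings==12.9.4; platform_system == 'Linux'".
    return [spec.strip() for spec in value.split("|") if spec.strip()]

example = (
    "cuda-bindings==12.9.4; platform_system == 'Linux' | "
    "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux'"
)
assert split_extra_install_requirements(example)[0] == "cuda-bindings==12.9.4; platform_system == 'Linux'"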
nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1174,7 +1174,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1220,7 +1220,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | 
nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1266,7 +1266,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1377,7 +1377,7 @@ jobs: ALPINE_IMAGE: 
"arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1423,7 +1423,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; 
platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1469,7 +1469,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1515,7 +1515,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | 
nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1626,7 +1626,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1672,7 +1672,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: 
manywheel-py3_14t-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1718,7 +1718,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 
'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1764,7 +1764,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 21c1d5caa3829..a5f4e85ca58c1 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -127,7 +127,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | 
nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_6-test: # Testing @@ -193,7 +193,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; 
platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_8-test: # Testing @@ -259,7 +259,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_9-test: # Testing @@ -325,7 +325,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + 
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda13_0-test: # Testing @@ -793,7 +793,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_6-test: # Testing @@ -859,7 +859,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | 
nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_8-test: # Testing @@ -925,7 +925,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | 
nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_9-test: # Testing @@ -991,7 +991,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda13_0-test: # Testing @@ -1459,7 +1459,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | 
nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_6-test: # Testing @@ -1525,7 +1525,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_8-test: # Testing @@ -1591,7 +1591,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel - 
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_9-test: # Testing @@ -1657,7 +1657,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | 
nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda13_0-test: # Testing @@ -2125,7 +2125,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_6-test: # Testing @@ -2191,7 +2191,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | 
nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_8-test: # Testing @@ -2257,7 +2257,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_9-test: # Testing @@ 
-2323,7 +2323,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda13_0-test: # Testing @@ -2791,7 +2791,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system 
== 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_6-test: # Testing @@ -2857,7 +2857,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_8-test: # Testing @@ -2923,7 +2923,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; 
platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_9-test: # Testing @@ -2989,7 +2989,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} 
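The hunks above and below all apply the same change: a pinned `cuda-bindings` wheel is prepended to `PYTORCH_EXTRA_INSTALL_REQUIREMENTS`, which is a pipe-separated list of PEP 508 requirement specifiers, each guarded by a `platform_system == 'Linux'` marker. As a minimal sketch (assuming only that the value is split on `|`; the actual consumer of this variable lives outside this diff), the list can be broken back into individual specifiers for inspection:

```bash
# Hedged sketch: split a pipe-separated PYTORCH_EXTRA_INSTALL_REQUIREMENTS value
# (example entries copied from the hunks above) into one specifier per line.
PYTORCH_EXTRA_INSTALL_REQUIREMENTS="cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux'"

printf '%s\n' "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" \
  | tr '|' '\n' \
  | sed 's/^ *//; s/ *$//'
# Expected output:
#   cuda-bindings==13.0.3; platform_system == 'Linux'
#   nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux'
```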
manywheel-py3_13t-cuda13_0-test: # Testing @@ -3457,7 +3457,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_6-test: # Testing @@ -3523,7 +3523,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | 
nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_8-test: # Testing @@ -3589,7 +3589,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_9-test: # Testing @@ -3655,7 +3655,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | 
nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda13_0-test: # Testing @@ -4123,7 +4123,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | 
nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_6-test: # Testing @@ -4189,7 +4189,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_8-test: # Testing @@ -4255,7 +4255,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: 
cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_9-test: # Testing @@ -4321,7 +4321,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda13_0-test: # Testing diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index a0ae234ab5669..3421e2b9af77d 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -30,14 +30,14 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - 
get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -46,11 +46,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 628f624240127..764e631819ccc 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -27,14 +27,14 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -47,11 +47,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} # disable monitor in perf tests for more investigation diff --git a/.github/workflows/inductor-perf-test-b200.yml b/.github/workflows/inductor-perf-test-b200.yml index 7b59e92386a33..11f5f10a55ad8 100644 --- a/.github/workflows/inductor-perf-test-b200.yml +++ b/.github/workflows/inductor-perf-test-b200.yml @@ -80,7 +80,7 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -90,8 +90,8 @@ jobs: # from trunk. 
Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '10.0' test-matrix: | { include: [ @@ -104,12 +104,12 @@ jobs: secrets: inherit test-periodically: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -121,12 +121,12 @@ jobs: secrets: inherit test-weekly: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -138,11 +138,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 8209bf053a772..1c35fc6794537 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -95,8 +95,8 @@ jobs: # from trunk. 
Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '9.0' test-matrix: | { include: [ @@ -132,7 +132,7 @@ jobs: needs: build if: github.event.schedule == '15 0 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -149,7 +149,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -168,7 +168,7 @@ jobs: # needs one round of benchmark if: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' }} with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-cudagraphs_low_precision-${{ inputs.cudagraphs || 'false' }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 19f72ba453414..88a528ba1b075 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -80,15 +80,15 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" # Every bit to make perf run faster helps runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -117,12 +117,12 @@ jobs: secrets: inherit test-nightly: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 
1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -133,12 +133,12 @@ jobs: secrets: inherit test-weekly: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -150,12 +150,12 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index b08d9865d15d3..f3e34d6ecb52f 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -37,8 +37,8 @@ jobs: needs: get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0;8.6' test-matrix: | { include: [ @@ -76,7 +76,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: periodic-dynamo-benchmarks-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.periodic-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit @@ -138,8 +138,8 @@ jobs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { 
include: [ @@ -153,7 +153,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-smoke-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.inductor-smoke-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-smoke-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index dee10a0db3c16..57e5cb856729a 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -38,14 +38,14 @@ jobs: curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf - linux-jammy-rocm-py3_10-inductor-build: - name: rocm-py3.10-inductor-mi300 + linux-noble-rocm-py3_12-inductor-build: + name: rocm-py3.12-inductor-mi300 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 test-matrix: | { include: [ { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, @@ -53,15 +53,15 @@ jobs: ]} secrets: inherit - linux-jammy-rocm-py3_10-inductor-test: + linux-noble-rocm-py3_12-inductor-test: permissions: id-token: write contents: read - name: rocm-py3.10-inductor-mi300 + name: rocm-py3.12-inductor-mi300 uses: ./.github/workflows/_rocm-test.yml - needs: linux-jammy-rocm-py3_10-inductor-build + needs: linux-noble-rocm-py3_12-inductor-build with: - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-rocm-py3_10-inductor-build.outputs.test-matrix }} + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index ca9b57cab2ddb..0902026adb8ce 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -33,8 +33,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -52,7 +52,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 8a913c3b36a11..e524ed548b741 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -49,8 +49,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: 
get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -69,7 +69,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml index 89d6d63c72875..cd27b3a8a97db 100644 --- a/.github/workflows/operator_microbenchmark.yml +++ b/.github/workflows/operator_microbenchmark.yml @@ -18,14 +18,25 @@ permissions: contents: read jobs: + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + # H100 A100 runners opmicrobenchmark-build: if: github.repository_owner == 'pytorch' name: opmicrobenchmark-build uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '8.0 9.0' test-matrix: | @@ -41,7 +52,7 @@ jobs: needs: opmicrobenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} secrets: inherit @@ -51,9 +62,11 @@ jobs: if: github.repository_owner == 'pytorch' name: opmicrobenchmark-build-b200 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: - runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runner: linux.r7i.4xlarge + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' test-matrix: | @@ -68,7 +81,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index ce68ee8bc8e03..f3356cfa4fc77 100644 --- 
a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -50,33 +50,33 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10-mi300 + linux-noble-rocm-py3_12-build: + name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, - { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, - { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b", owners: ["module:rocm", "oncall:distributed"] }, ]} secrets: inherit - linux-jammy-rocm-py3_10-test: + linux-noble-rocm-py3_12-test: permissions: id-token: write contents: read - name: linux-jammy-rocm-py3.10-mi300 + name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-jammy-rocm-py3_10-build + - linux-noble-rocm-py3_12-build - target-determination with: - build-environment: linux-jammy-rocm-py3.10-mi300 - docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} + build-environment: linux-noble-rocm-py3.12-mi300 + docker-image: ${{ needs.linux-noble-rocm-py3_12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-rocm-py3_12-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 5a90db9ab5737..325050392a393 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -90,6 +90,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: 8.6 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -97,7 +98,9 @@ jobs: { config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ 
needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, + { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, ]} secrets: inherit @@ -113,40 +116,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-build: - name: linux-jammy-cuda12.8-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 - cuda-arch-list: 8.6 - test-matrix: | - { include: [ - { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, - { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-test: - name: linux-jammy-cuda12.8-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-build - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.test-matrix }} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-debug-build: - name: linux-jammy-cuda12.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc11-debug-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: 8.9 test-matrix: | { include: [ @@ -160,16 +137,16 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-debug-test: - name: linux-jammy-cuda12.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc11-debug-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-cuda12_8-py3_10-gcc9-debug-build + - linux-jammy-cuda12_8-py3_10-gcc11-debug-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.test-matrix }} secrets: inherit linux-jammy-cuda13_0-py3_10-gcc11-build: diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 51e211a5ad2ad..f2483dff9a94c 100644 --- a/.github/workflows/pull.yml +++ 
b/.github/workflows/pull.yml @@ -318,14 +318,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: - name: cuda12.8-py3.10-gcc9-sm75 + linux-jammy-cuda12_8-py3_10-gcc11-inductor-build: + name: cuda12.8-py3.10-gcc11-sm75 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '7.5' test-matrix: | { include: [ @@ -333,14 +333,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: - name: cuda12.8-py3.10-gcc9-sm75 + linux-jammy-cuda12_8-py3_10-gcc11-inductor-test: + name: cuda12.8-py3.10-gcc11-sm75 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + needs: linux-jammy-cuda12_8-py3_10-gcc11-inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.test-matrix }} secrets: inherit linux-noble-xpu-n-py3_10-build: diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index d20b37be20876..99059a1ff857c 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -48,12 +48,12 @@ jobs: docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" }, ]} secrets: inherit diff --git a/.github/workflows/test-b200.yml b/.github/workflows/test-b200.yml index 07fd9b18fdada..7cc935f46d6c8 100644 --- a/.github/workflows/test-b200.yml +++ b/.github/workflows/test-b200.yml @@ -54,7 +54,7 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 
cuda-arch-list: '10.0' @@ -75,4 +75,4 @@ jobs: docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only - secrets: inherit \ No newline at end of file + secrets: inherit diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 08fcd33402625..5a0273f0b745e 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -26,14 +26,14 @@ jobs: curr_ref_type: ${{ github.ref_type }} build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -42,11 +42,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 667c37727045b..d458bde5f9d30 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -203,9 +203,15 @@ jobs: sync-tag: rocm-build test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, - { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" }, + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4" }, ]} secrets: inherit @@ -223,7 +229,6 @@ jobs: build-environment: linux-jammy-rocm-py3.10 docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} - tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" secrets: inherit inductor-build: @@ -231,8 +236,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: 
linux-jammy-cuda12.8-py3.12-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' secrets: inherit @@ -283,6 +288,7 @@ jobs: name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build.yml needs: get-label-type + if: false # Has been broken for a while with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-executorch diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index b3d8073aad3b3..3a0567f33c8cc 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -18,6 +18,7 @@ on: - rocm-mi200 - rocm-mi300 - rocm-mi355 + - rocm-navi31 - inductor-micro-benchmark - inductor-micro-benchmark-x86 - inductor-cu124 diff --git a/.lintrunner.toml b/.lintrunner.toml index 7a6e241f90c8d..0f46b398ca501 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1751,3 +1751,50 @@ command = [ "python3", "tools/linter/adapters/gb_registry_linter.py", ] + +[[linter]] +code = 'PYLINT' +include_patterns = ['**/*.py'] +exclude_patterns = [ + '.git/**', + 'build_test_custom_build/**', + 'build/**', + 'caffe2/**', + 'docs/caffe2/**', + 'docs/cpp/src/**', + 'docs/src/**', + 'fb/**', + '**/fb/**', + 'functorch/docs/**', + 'functorch/examples/**', + 'functorch/docs/source/tutorials/**', + 'torch/_inductor/fx_passes/serialized_patterns/**', + 'torch/_inductor/autoheuristic/artifacts/**', + 'scripts/**', + 'test/generated_type_hints_smoketest.py', + 'test/test_torchfuzz_repros.py', + # CPython tests + 'test/dynamo/cpython/**', + # Tests from the NumPy test suite + 'test/torch_np/numpy_test/**/*.py', + 'third_party/**', + 'torch/include/**', + 'torch/lib/**', + 'venv/**', + '**/*.pyi', + "tools/experimental/torchfuzz/**", + 'tools/test/test_selective_build.py', +] +command = [ + 'python3', + 'tools/linter/adapters/pylint_linter.py', + '--config=pylintrc', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python3', + 'tools/linter/adapters/pip_init.py', + '--dry-run={{DRYRUN}}', + 'pylint==4.0.2', +] diff --git a/.spin/cmds.py b/.spin/cmds.py index a81717c7423be..374d7699ef7f8 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -191,6 +191,7 @@ def regenerate_clangtidy_files(): "FLAKE8", "GB_REGISTRY", "PYFMT", + "PYLINT", "PYREFLY", "TEST_DEVICE_BIAS", "TEST_HAS_MAIN", @@ -328,3 +329,10 @@ def quicklint(ctx, apply_patches, **kwargs): def quickfix(ctx, **kwargs): """Autofix changed files.""" ctx.invoke(quicklint, apply_patches=True) + + +@click.command() +def regenerate_github_workflows(): + """Regenerate GitHub workflows from templates.""" + cmd = [sys.executable, "scripts/generate_ci_workflows.py"] + spin.util.run(cmd, cwd="./.github") diff --git a/AGENTS.md b/AGENTS.md index 3d5436a02a85d..718217d3e663d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -10,6 +10,7 @@ - Do NOT run pre-commit, it is not setup - To run lint, run 'lintrunner -a' (which will autoapply changes) - Do NOT attempt to install dependencies, you do not have Internet access +- Do NOT create summary files unless explicitly asked - When you are ready to make a PR, do exactly these steps: - git stash -u - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e020abda3925..55bc4bde8a6e4 
100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,6 +171,7 @@ endif() set(CPU_AARCH64 OFF) set(CPU_INTEL OFF) set(CPU_POWER OFF) +set(CPU_RISCV OFF) if(CMAKE_SYSTEM_PROCESSOR MATCHES "(AMD64|x86_64)") set(CPU_INTEL ON) @@ -178,6 +179,8 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)") set(CPU_AARCH64 ON) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(ppc64le)") set(CPU_POWER ON) +elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv64)") + set(CPU_RISCV ON) endif() # For non-supported platforms, turn USE_DISTRIBUTED off by default. It is not @@ -327,8 +330,8 @@ cmake_dependent_option(USE_ITT "Use Intel(R) VTune Profiler ITT functionality" # Ensure that an MKLDNN build is the default for x86 CPUs but optional for # AArch64 (dependent on -DUSE_MKLDNN). cmake_dependent_option( - USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, AArch64, and ppc64le." - "${CPU_INTEL}" "CPU_INTEL OR CPU_AARCH64 OR CPU_POWER" OFF) + USE_MKLDNN "Use MKLDNN. Only available on x86, x86_64, AArch64, ppc64le and riscv64." + "${CPU_INTEL}" "CPU_INTEL OR CPU_AARCH64 OR CPU_POWER OR CPU_RISCV" OFF) cmake_dependent_option( USE_MKLDNN_ACL "Use Compute Library for the Arm architecture." OFF "USE_MKLDNN AND CPU_AARCH64" OFF) @@ -1214,41 +1217,6 @@ else() append_cxx_flag_if_supported("/wd4273" CMAKE_CXX_FLAGS) endif() -if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") - include(CheckCSourceCompiles) - check_c_source_compiles( - "#include -int main() { - float a[] = {1.0, 1.0}; - float32x4x2_t v; - v.val[0] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); - v.val[1] = vcombine_f32 (vcreate_f32 (0UL), vcreate_f32 (0UL)); - vst1q_f32_x2(a, v); - return 0; -}" - HAS_VST1) - - if(NOT HAS_VST1) - string(APPEND CMAKE_CXX_FLAGS " -DMISSING_ARM_VST1") - endif() -endif() - -if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") - include(CheckCSourceCompiles) - check_c_source_compiles( - "#include -int main() { - float a[] = {1.0, 1.0}; - vld1q_f32_x2(a); - return 0; -}" - HAS_VLD1) - - if(NOT HAS_VLD1) - string(APPEND CMAKE_CXX_FLAGS " -DMISSING_ARM_VLD1") - endif() -endif() - # Add code coverage flags to supported compilers if(USE_CPP_CODE_COVERAGE) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/CODEOWNERS b/CODEOWNERS index 137031066090e..7516c4ad7ec06 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -135,7 +135,7 @@ torch/profiler/ @sraikund16 test/functorch/test_aotdispatch.py @ezyang @Chillee # Dataloader -torch/utils/data/ @divyanshk @ramanishsingh @scotts +torch/utils/data/ @divyanshk @ramanishsingh @scotts @aelavender # hipify torch/utils/hipify/ @jeffdaily @jithunnair-amd diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bc0b0fc9bb00f..850753f13b63a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,6 +14,10 @@ aspects of contributing to PyTorch. - [Tips and Debugging](#tips-and-debugging) - [Nightly Checkout & Pull](#nightly-checkout--pull) - [Codebase structure](#codebase-structure) +- [Spin](#spin) + - [Linting](#linting) + - [default lint](#default-lint) + - [Regenerating](#regenerating) - [Unit testing](#unit-testing) - [Python Unit Testing](#python-unit-testing) - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest) @@ -274,6 +278,47 @@ dependencies as well as the nightly binaries into the repo directory. * ... * [.circleci](.circleci) - CircleCI configuration management. [README](.circleci/README.md) +## Spin + +[Spin](https://github.com/scientific-python/spin) is a developer cli tool that +helps running common tasks. +To list the available tasks, run `spin --help`. 
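For a quick sense of the workflow, here is a sketch of a typical session built from the tasks documented in the tables below (illustrative only; it assumes `spin` is already installed in your development environment):

```bash
# set up lintrunner through Spin, but only when the lint config has changed
spin lazy-setup-lint

# lint, and optionally autofix, the files changed in the latest commit
# and the working directory
spin quicklint
spin quickfix

# regenerate derived files, e.g. the GitHub workflows from their Jinja templates
spin regenerate-github-workflows
```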
+Currently, we support the following tasks with Spin:
+
+### Linting
+
+Spin helps with linting by making sure that lintrunner is installed correctly
+and by isolating the lintrunner environment from the general development
+environment using uv.
+
+|Command|Description|
+|-|-|
+|`setup-lint`|update lintrunner and perform a fresh setup|
+|`lazy-setup-lint`|only perform setup if the lint configuration has changed|
+|`lint`|perform the default lint (see below)|
+|`quicklint`|perform lint on all files changed in the latest commit and the working directory|
+|`quickfix`|autofix issues on all files changed in the latest commit and the working directory|
+
+#### default lint
+
+Since some linters take a long time to run, we categorize all linters as either
+fast or slow. In the default lint, only the fast linters are run on all files;
+the slow linters are run on the changed files only.
+
+### Regenerating
+
+PyTorch makes use of a number of code-generation steps, which range from the
+version information in `torch/version.py`, through type stubs and other linter
+support, to GitHub workflows.
+With Spin, we offer a unified interface to these tasks.
+
+|Command|Description|
+|-|-|
+|`regenerate-version`|regenerate `torch/version.py`|
+|`regenerate-type-stubs`|regenerate type stubs for use by static type checkers|
+|`regenerate-clangtidy-files`|regenerate clang-related files needed for linting|
+|`regenerate-github-workflows`|regenerate GitHub workflows from Jinja templates|
+
 ## Unit testing
 
 ### Python Unit Testing
diff --git a/README.md b/README.md
index a0c9b54c95a8b..c2f15d88738da 100644
--- a/README.md
+++ b/README.md
@@ -292,8 +292,13 @@ python tools/amd_build/build_amd.py
 Install PyTorch
 
 ```bash
+# the CMake prefix for a conda environment
 export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}"
 python -m pip install --no-build-isolation -v -e .
+
+# the CMake prefix for a non-conda environment, e.g.
Python venv +# call following after activating the venv +export CMAKE_PREFIX_PATH="${VIRTUAL_ENV}:${CMAKE_PREFIX_PATH}" ``` **On macOS** diff --git a/aten/src/ATen/code_template.h b/aten/src/ATen/code_template.h index 2cde802dac172..edc1124240251 100644 --- a/aten/src/ATen/code_template.h +++ b/aten/src/ATen/code_template.h @@ -232,7 +232,7 @@ struct CodeTemplate { emitIndent(out, indent); emitStringWithIndents(out, indent, strings[i]); if (i + 1 != strings.size()) - out << "\n"; + out << '\n'; } } std::string template_text; diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 46dc550b1f37b..35a729ccc9f39 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -680,7 +680,7 @@ TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) { return false; } if (elem_type->kind() == AnyType::Kind) { - // List of Any can contains heterogenous types + // List of Any can contains heterogeneous types return false; } return true; diff --git a/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h b/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h deleted file mode 100644 index aa40000b6ccdb..0000000000000 --- a/aten/src/ATen/cpu/vec/vec256/missing_vld1_neon.h +++ /dev/null @@ -1 +0,0 @@ -#include diff --git a/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h b/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h deleted file mode 100644 index b3d721531d246..0000000000000 --- a/aten/src/ATen/cpu/vec/vec256/missing_vst1_neon.h +++ /dev/null @@ -1 +0,0 @@ -#include diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 9a55b058001da..bc7607f232011 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1997,6 +1997,10 @@ void scaled_gemm( // Note: alpha_val may change later depending on user-passed argument float alpha_val = 1.0; float beta_val = 0.0; +#ifndef USE_ROCM + // Note: unused, but cublasLtMatmul requires a C pointer that is not result_ptr or nullptr + const void* dummy_C_ptr = mat1_ptr; +#endif // ifndef USE_ROCM CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); @@ -2180,8 +2184,11 @@ void scaled_gemm( mat2_ptr, Bdesc.descriptor(), beta_ptr, - // NOTE: always use result_ptr here, because cuBLASLt w/device beta=0 can't handle nullptr either +#ifdef USE_ROCM result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr +#else + dummy_C_ptr, // also unused, but cuBLAS can't use nullptr or result_ptr +#endif // ifdef USE_ROCM Cdesc.descriptor(), result_ptr, Ddesc.descriptor(), diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 81b4643ac0418..73340604574ad 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -3,252 +3,15 @@ #include #include #include -#include +#include #include -#include -#include - -#include - -#include -#include - -/* -* `cudaEventExternal` is a torch-specific flag that is used to -* indicate that the CUDAEvent will be used only for synchronization -* with work outside of the cuda graph, rather than creation of -* cross-stream dependencies within a cuda graph. 
Resources: -* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events -* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 -* https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e -*/ -#define cudaEventExternal 0x08 namespace at::cuda { -/* -* CUDAEvents are movable not copyable wrappers around CUDA's events. -* -* CUDAEvents are constructed lazily when first recorded unless it is -* reconstructed from a cudaIpcEventHandle_t. The event has a device, and this -* device is acquired from the first recording stream. However, if reconstructed -* from a handle, the device should be explicitly specified; or if ipc_handle() is -* called before the event is ever recorded, it will use the current device. -* Later streams that record the event must match this device. -*/ -struct TORCH_CUDA_CPP_API CUDAEvent { - // Constructors - // Default value for `flags` is specified below - it's cudaEventDisableTiming - CUDAEvent() noexcept = default; - CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} - - CUDAEvent( - DeviceIndex device_index, const cudaIpcEventHandle_t* handle) : device_index_(device_index) { - CUDAGuard guard(device_index_); - - AT_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); - is_created_ = true; - } - - // Note: event destruction done on creating device to avoid creating a - // CUDA context on other devices. - ~CUDAEvent() { - try { - if (is_created_) { - CUDAGuard guard(device_index_); - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_deletion(at::kCUDA, reinterpret_cast(event_)); - } - AT_CUDA_CHECK(cudaEventDestroy(event_)); - } - } catch (...) { /* No throw */ } - } - - CUDAEvent(const CUDAEvent&) = delete; - CUDAEvent& operator=(const CUDAEvent&) = delete; - - CUDAEvent(CUDAEvent&& other) noexcept { moveHelper(std::move(other)); } - CUDAEvent& operator=(CUDAEvent&& other) noexcept { - if (this != &other) { - moveHelper(std::move(other)); - } - return *this; - } - - operator cudaEvent_t() const { return event(); } - - // Less than operator (to allow use in sets) - friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { - return left.event_ < right.event_; - } - - std::optional device() const { - if (is_created_) { - return at::Device(at::kCUDA, device_index_); - } else { - return {}; - } - } - - bool isCreated() const { return is_created_; } - DeviceIndex device_index() const {return device_index_;} - cudaEvent_t event() const { return event_; } - - // Note: cudaEventQuery can be safely called from any device - bool query() const { - if (!is_created_) { - return true; - } - - cudaError_t err = cudaEventQuery(event_); - if (err == cudaSuccess) { - return true; - } else if (err != cudaErrorNotReady) { - C10_CUDA_CHECK(err); - } else { - // ignore and clear the error if not ready - (void)cudaGetLastError(); - } - - return false; - } - - void record() { record(getCurrentCUDAStream()); } - - void recordOnce(const CUDAStream& stream) { - if (!was_recorded_) record(stream); - } - - // Note: cudaEventRecord must be called on the same device as the event. 
- void record(const CUDAStream& stream) { - if (!is_created_) { - createEvent(stream.device_index()); - } - - TORCH_CHECK(device_index_ == stream.device_index(), "Event device ", device_index_, - " does not match recording stream's device ", stream.device_index(), "."); - CUDAGuard guard(device_index_); - -#ifndef USE_ROCM - // it is an error to use cudaEventRecordExternal when not doing stream capture - unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventRecordExternal : cudaEventRecordDefault; - AT_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); -#else - AT_CUDA_CHECK(cudaEventRecord(event_, stream)); -#endif - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_record(at::kCUDA, - reinterpret_cast(event_), - reinterpret_cast(stream.stream()) - ); - } - was_recorded_ = true; - } - - // Note: cudaStreamWaitEvent must be called on the same device as the stream. - // The event has no actual GPU resources associated with it. - void block(const CUDAStream& stream) { - if (is_created_) { - CUDAGuard guard(stream.device_index()); -#ifndef USE_ROCM - // it is an error to use cudaEventWaitExternal when not doing stream capture - unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != c10::cuda::CaptureStatus::None && external_) ? cudaEventWaitExternal : cudaEventWaitDefault; - AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); -#else - AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); -#endif - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_wait(at::kCUDA, - reinterpret_cast(event_), - reinterpret_cast(stream.stream()) - ); - } - } - } - - // Note: cudaEventElapsedTime can be safely called from any device - float elapsed_time(const CUDAEvent& other) const { - TORCH_CHECK_VALUE( - !(flags_ & cudaEventDisableTiming) && !(other.flags_ & cudaEventDisableTiming), - "Both events must be created with argument 'enable_timing=True'."); - TORCH_CHECK_VALUE( - is_created_ && other.isCreated(), - "Both events must be recorded before calculating elapsed time."); - TORCH_CHECK( - query() && other.query(), - "Both events must be completed before calculating elapsed time."); - - float time_ms = 0; - // We do not strictly have to set the device index to the same as our event, - // but if we don't and the current device is not initialized, it will - // create a new cuda context, which will consume a lot of memory. - CUDAGuard guard(device_index_); - // raise cudaErrorNotReady if either event is recorded but not yet completed - AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); - return time_ms; - } - - // Note: cudaEventSynchronize can be safely called from any device - void synchronize() const { - if (is_created_) { - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_synchronization(at::kCUDA, reinterpret_cast(event_)); - } - AT_CUDA_CHECK(cudaEventSynchronize(event_)); - } - } - - // Note: cudaIpcGetEventHandle must be called on the same device as the event - void ipc_handle(cudaIpcEventHandle_t * handle) { - if (!is_created_) { - // this CUDAEvent object was initially constructed from flags but event_ - // is not created yet. 
- createEvent(getCurrentCUDAStream().device_index()); - } - CUDAGuard guard(device_index_); - AT_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); - } - -private: - unsigned int flags_ = cudaEventDisableTiming; - bool is_created_ = false; - bool was_recorded_ = false; - bool external_ = false; - DeviceIndex device_index_ = -1; - cudaEvent_t event_{}; - - void createEvent(DeviceIndex device_index) { - external_ = (flags_ & cudaEventExternal) != 0; -#ifdef USE_ROCM - TORCH_CHECK(!external_, "External events are disallowed in rocm"); -#endif - flags_ &= ~cudaEventExternal; - device_index_ = device_index; - CUDAGuard guard(device_index_); - AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); - const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); - if (C10_UNLIKELY(interp)) { - (*interp)->trace_gpu_event_creation(at::kCUDA, reinterpret_cast(event_)); - } - is_created_ = true; - } - - void moveHelper(CUDAEvent&& other) { - std::swap(flags_, other.flags_); - std::swap(is_created_, other.is_created_); - std::swap(was_recorded_, other.was_recorded_); - std::swap(device_index_, other.device_index_); - std::swap(event_, other.event_); - } -}; - -// EventPool - Thread-safe pool of CUDA events to avoid expensive cudaEventCreate -// calls. cudaEventCreate when concurrently invoked from multiple threads can be -// very expensive (especially on certain device/driver combinations). +// EventPool - Thread-safe pool of CUDA events to avoid expensive +// cudaEventCreate calls. cudaEventCreate when concurrently invoked from +// multiple threads can be very expensive (especially on certain device/driver +// combinations). using CUDAEventPtr = std::unique_ptr>; diff --git a/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h new file mode 100644 index 0000000000000..f2741a32889fb --- /dev/null +++ b/aten/src/ATen/hip/impl/HIPEventMasqueradingAsCUDA.h @@ -0,0 +1,86 @@ +#pragma once + +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! 
+namespace c10 { namespace hip { + +// See Note [Masquerading as CUDA] for motivation + +struct HIPEventMasqueradingAsCUDA { + HIPEventMasqueradingAsCUDA() noexcept = default; + HIPEventMasqueradingAsCUDA(unsigned int flags) noexcept + : event_(HIPEvent(flags)) {} + HIPEventMasqueradingAsCUDA( + DeviceIndex device_index, + const hipIpcEventHandle_t* handle) + : event_(HIPEvent(device_index, handle)) {} + + ~HIPEventMasqueradingAsCUDA() = default; + + HIPEventMasqueradingAsCUDA(const HIPEventMasqueradingAsCUDA&) = delete; + HIPEventMasqueradingAsCUDA& operator=(const HIPEventMasqueradingAsCUDA&) = delete; + HIPEventMasqueradingAsCUDA(HIPEventMasqueradingAsCUDA&& other) noexcept = default; + HIPEventMasqueradingAsCUDA& operator=(HIPEventMasqueradingAsCUDA&& other) noexcept = default; + + operator hipEvent_t() const { + return event_.event(); + } + + // Less than operator (to allow use in sets) + friend bool operator<( + const HIPEventMasqueradingAsCUDA& left, + const HIPEventMasqueradingAsCUDA& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + // Unsafely coerce HIP device into CUDA device + return Device(c10::DeviceType::CUDA, event_.device_index()); + } + bool isCreated() const { + return event_.isCreated(); + } + DeviceIndex device_index() const { + return event_.device_index(); + } + hipEvent_t event() const { + return event_.event(); + } + bool query() const { + return event_.query(); + } + void record() { + return event_.record(); + } + + void recordOnce(const HIPStreamMasqueradingAsCUDA& stream) { + event_.recordOnce(stream.hip_stream()); + } + + void record(const HIPStreamMasqueradingAsCUDA& stream) { + event_.record(stream.hip_stream()); + } + + void block(const HIPStreamMasqueradingAsCUDA& stream) { + event_.block(stream.hip_stream()); + } + + float elapsed_time(const HIPEventMasqueradingAsCUDA& other) const { + return event_.elapsed_time(other.event_); + } + + void synchronize() const { + event_.synchronize(); + } + + void ipc_handle(hipIpcEventHandle_t* handle) { + event_.ipc_handle(handle); + } + + private: + HIPEvent event_; +}; + +}} // namespace c10::hip diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp index 3fe27c7a0825b..7b3d790f067b5 100644 --- a/aten/src/ATen/miopen/Descriptors.cpp +++ b/aten/src/ATen/miopen/Descriptors.cpp @@ -114,8 +114,8 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo // that is the common case, so we can catch most client errors with this test. TORCH_CHECK(t.is_contiguous(memory_format), "MIOpen filters (a.k.a. weights) must be contiguous in desired memory_format\n", - "Weight sizes: ", t.sizes(), "\n", - "Weight strides: ", t.strides(), "\n", + "Weight sizes: ", t.sizes(), '\n', + "Weight strides: ", t.strides(), '\n', "cuDNN suggested memory_format: ", memory_format); int size[MIOPEN_DIM_MAX]; diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 07bdc19ec8ff7..2cc7cf913cdcb 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -2909,13 +2909,26 @@ Tensor linalg_matrix_norm( // Check A, dim, and dtype _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/abs_ord != 2.); - auto max_min = [ord, keepdim](const Tensor& A, int64_t dim) { return ord > 0 ? 
A.amax(dim, keepdim) : A.amin(dim, keepdim); }; + auto max_min_wrapper = [ord, keepdim](const Tensor &A, int64_t dim) { + if (A.size(dim) == 0 && ord > 0) { + auto new_shape(DimVector(A.sizes())); + auto dim_ = maybe_wrap_dim(dim, A.dim()); + if (keepdim) { + new_shape[dim_] = 1; + } else { + new_shape.erase(std::begin(new_shape) + dim_); + } + return at::zeros(new_shape, A.options()); + } else { + return ord > 0 ? A.amax(dim, keepdim) : A.amin(dim, keepdim); + } + }; if (abs_ord == 2.) { // Move dims to the end auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A.dim()); auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A; - auto result = max_min(at::linalg_svdvals(A_.permute(permutation)), -1); + auto result = max_min_wrapper(at::linalg_svdvals(A_.permute(permutation)), -1); if (keepdim) { auto permutation_reverse = create_reverse_permutation(std::move(permutation)); result = result.unsqueeze(-1).permute(permutation_reverse); @@ -2932,7 +2945,7 @@ Tensor linalg_matrix_norm( if (!keepdim && (dim_[0] < dim_[1])) { dim_[1]--; } - return max_min(at::linalg_vector_norm(A, 1., {dim_[0]}, keepdim, opt_dtype), dim_[1]); + return max_min_wrapper(at::linalg_vector_norm(A, 1., {dim_[0]}, keepdim, opt_dtype), dim_[1]); } } @@ -3541,9 +3554,9 @@ Tensor _dyn_quant_matmul_4bit_cpu( const int64_t out_features) { auto M = inp.size(0); TORCH_CHECK( - inp.dtype() == kFloat, + inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features), __func__, - " : expect input to be 32-bit float tensor."); + " : expect input to be float32 or bfloat16 tensor."); TORCH_CHECK( block_size == in_features || (!(block_size % 32) && !(in_features % block_size)), diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 221f621ea1e06..26ec55c11d823 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -353,6 +353,9 @@ void remainder_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "remainder_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t { TORCH_CHECK(b != 0, "ZeroDivisionError"); + if (a == std::numeric_limits::min() && b == scalar_t(-1)) { + return 0; + } scalar_t r = a % b; if ((r != 0) && (c10::is_negative(r) != c10::is_negative(b))) { r += b; @@ -813,8 +816,43 @@ void smooth_l1_kernel(TensorIteratorBase& iter, double beta) { } void huber_kernel(TensorIterator& iter, double delta) { - AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() { + // Special-case kHalf: compute in float for numerical stability + if (iter.dtype() == kHalf) { + const float delta_val(static_cast(delta)); + const Vectorized delta_vec(static_cast(delta)); + const Vectorized point_five_vec(static_cast(0.5)); + cpu_kernel_vec( + iter, + // scalar lambda: convert half -> float, compute in float, cast back to half + [&delta_val] (at::Half a, at::Half b) -> at::Half { + float af = static_cast(a); + float bf = static_cast(b); + float z = std::abs(af - bf); + float out = z < delta_val + ? 
0.5f * z * z + : delta_val * (z - 0.5f * delta_val); + return static_cast(out); + }, + [&delta_vec, &point_five_vec] (Vectorized a, Vectorized b) { + auto [a0, a1] = convert_half_float(a); + auto [b0, b1] = convert_half_float(b); + auto z = (a0 - b0).abs(); + a0 = Vectorized::blendv( + point_five_vec * z * z, + delta_vec * (z - point_five_vec * delta_vec), + z >= delta_vec); + z = (a1 - b1).abs(); + a1 = Vectorized::blendv( + point_five_vec * z * z, + delta_vec * (z - point_five_vec * delta_vec), + z >= delta_vec); + return convert_float_half(a0, a1); + } + ); + return; + } + else { + AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "huber_cpu", [&]() { using Vec = Vectorized; const scalar_t delta_val(delta); const Vec delta_val_vec(delta_val); @@ -835,6 +873,7 @@ void huber_kernel(TensorIterator& iter, double delta) { z >= delta_val_vec); }); }); + } } void sigmoid_backward_kernel(TensorIteratorBase& iter) { @@ -999,6 +1038,9 @@ void fmod_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "fmod_cpu", [&]() { cpu_kernel(iter, [=](scalar_t x, scalar_t d) -> scalar_t { TORCH_CHECK(d != 0, "ZeroDivisionError"); + if (x == std::numeric_limits::min() && d == scalar_t(-1)) { + return 0; + } return x % d; }); }); diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index 33aae4fbf27a5..1ffaa7bcd90b7 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -793,6 +794,139 @@ bool can_use_kleidiai( } #endif +static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16( + size_t m, + size_t n, + size_t k, + const uint16_t* lhs_bf16, + const uint8_t* rhs_qs4cx, + const float* rhs_scales, + uint16_t* dst_bf16, + float scalar_min, + float scalar_max, + const float* bias) { + // Roundup lambda for internal stride calculations + auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; }; + + // Cast bfloat16 to float32 inline + auto cast_bf16_to_f32 = [](uint16_t bf16_val) { + uint32_t tmp = static_cast(bf16_val) << 16; + float f; + std::memcpy(&f, &tmp, sizeof(f)); + return f; + }; + + // Cast float32 to bfloat16 inline + auto cast_f32_to_bf16 = [](float f) { + uint32_t bits; + std::memcpy(&bits, &f, sizeof(bits)); + return static_cast(bits >> 16); + }; + + // Quantization pack lambda (channelwise QA8DX) + auto quant_pack_8bit_channelwise = + [&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) { + constexpr int8_t kI8Min = std::numeric_limits::lowest(); + constexpr int8_t kI8Max = std::numeric_limits::max(); + + const size_t dst_stride = + K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + for (size_t i = 0; i < M; ++i) { + const uint16_t* row_ptr = src_bf16 + i * K; + // find min/max + float mn = FLT_MAX, mx = -FLT_MAX; + for (size_t j = 0; j < K; ++j) { + float v = cast_bf16_to_f32(row_ptr[j]); + mn = std::min(mn, v); + mx = std::max(mx, v); + } + float rmin = std::min(0.0f, mn); + float rmax = std::max(0.0f, mx); + constexpr float qmin = static_cast(kI8Min); + constexpr float qmax = static_cast(kI8Max); + float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin); + float recip = scale ? 1.0f / scale : 0.0f; + int32_t zp; + float des_min = rmin * scale; + float des_max = rmax * scale; + float err_min = qmin + des_min; + float err_max = qmax + des_max; + float zp_f = + (err_min + err_max) > 0 ? 
qmin - des_min : qmax - des_max; + zp_f = std::clamp(zp_f, qmin, qmax); + zp = std::lrintf(zp_f); + int8_t* out_ptr = dst_qa8dx + i * dst_stride; + // store header + *reinterpret_cast(out_ptr) = recip; + *reinterpret_cast(out_ptr + sizeof(float)) = -zp; + out_ptr += sizeof(float) + sizeof(int32_t); + // quantize + for (size_t j = 0; j < K; ++j) { + float v = cast_bf16_to_f32(row_ptr[j]); + int32_t q = static_cast(std::round(v * scale)) + zp; + q = std::clamp( + q, static_cast(kI8Min), static_cast(kI8Max)); + *out_ptr++ = static_cast(q); + } + } + }; + + // MatMul lambda (MXN x MXK -> MNXK BF16) + auto matmul_kernel = [&](size_t M, + size_t N, + size_t K, + const int8_t* lhs, + const uint8_t* rhs, + const float* scales, + uint16_t* dst, + float lo, + float hi) { + const size_t lhs_stride = + K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = roundup(K, 2) / 2; + for (size_t i = 0; i < M; ++i) { + const int8_t* lhs_row = lhs + i * lhs_stride; + for (size_t j = 0; j < N; ++j) { + int32_t acc = 0; + const int8_t* lptr = lhs_row; + const uint8_t* rptr = rhs + j * rhs_stride; + float lhs_scale = *reinterpret_cast(lptr); + int32_t lhs_off = + *reinterpret_cast(lptr + sizeof(float)); + lptr += sizeof(float) + sizeof(int32_t); + for (size_t t = 0; t < K; ++t) { + int32_t lv = static_cast(lptr[t]); + uint8_t bv = rptr[t / 2]; + int32_t rv = ((t & 1) == 0) ? (static_cast(bv & 0xF) - 8) + : (static_cast(bv >> 4) - 8); + acc += lv * rv + lhs_off * rv; + } + float res = static_cast(acc) * scales[j] * lhs_scale; + if (bias) { + res += bias[j]; + } + res = std::clamp(res, lo, hi); + *dst++ = cast_f32_to_bf16(res); + } + } + }; + + // allocate and run + std::unique_ptr packed( + new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]); + quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get()); + matmul_kernel( + m, + n, + k, + packed.get(), + rhs_qs4cx, + rhs_scales, + dst_bf16, + scalar_min, + scalar_max); +} + /** * The Int4 quantized weights must be represented as a uint8 tensor * For matrix multiplication with a weight shape of (N x K) @@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel( #if AT_KLEIDIAI_ENABLED() if (can_use_kleidiai(scales_zeros, K, block_size)) { const int64_t weight_packed_size = - kleidiai::kai_pack_rhs_int4_size(N, K, block_size); + kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type()); packed_weights.resize_({weight_packed_size}); kleidiai::kai_pack_int4_rhs( packed_weights, weights, scales_zeros, bias, N, K, block_size); } else #endif { - TORCH_CHECK( - bias.has_value() == 0, - __func__, - " : Bias is unsupported in reference implementation"); packed_weights = packed_weights.to(kFloat); - auto weight_reshaped = weights.view({-1}).to(kFloat); - auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat); - auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0); + auto weight_reshaped = weights.reshape({-1}).to(kFloat); + auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat); + std::vector tensors_to_cat = {weight_reshaped, scales_zeros_reshaped}; + if (bias.has_value()) { + tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat)); + } + auto res = at::cat(tensors_to_cat, 0); packed_weights.resize_(res.sizes()).copy_(res); } } @@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( const float* rhs_scales_f32, float* dst_f32, float scalar_min, - float scalar_max) { + float scalar_max, + const float* bias) { const size_t input_size_8bit = m * (k + 
sizeof(int32_t) + sizeof(float)); auto lhs_qa8dx_buffer = std::make_unique(input_size_8bit); @@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( // required format for matmul auto input_quant_pack_8bit_channelwise = [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + constexpr int8_t kI8Min = std::numeric_limits::lowest(); + constexpr int8_t kI8Max = std::numeric_limits::max(); + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); @@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( } // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + constexpr float qmin = static_cast(kI8Min); + constexpr float qmax = static_cast(kI8Max); const float rmin0 = std::min(0.0f, min0); const float rmax0 = std::max(0.0f, max0); @@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( zero_point0 = std::min(zero_point0, qmax); // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); + const int32_t nudged_zero_point0 = std::lrintf(zero_point0); int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride; @@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = std::max(v0_s32, static_cast(INT8_MIN)); - v0_s32 = std::min(v0_s32, static_cast(INT8_MAX)); + v0_s32 = std::max(v0_s32, static_cast(kI8Min)); + v0_s32 = std::min(v0_s32, static_cast(kI8Max)); dst_ptr[0] = (int8_t)v0_s32; dst_ptr += sizeof(int8_t); } @@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( main_acc = main_acc * lhs_scale; + if (bias) { + main_acc += bias[n_idx]; + } + // Clamp (min-max) operation main_acc = std::max(main_acc, scalar_min); main_acc = std::min(main_acc, scalar_max); @@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( const float* rhs_scales_fp32, float* dst_f32, float scalar_min, - float scalar_max) { + float scalar_max, + const float* bias) { // Lambda for LHS quantization auto lhs_quant_pack = [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + constexpr int8_t kI8Min = std::numeric_limits::lowest(); + constexpr int8_t kI8Max = std::numeric_limits::max(); + const size_t dst_stride = (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); @@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( min0 = std::min(src0_0, min0); } - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; + constexpr float qmin = static_cast(kI8Min); + constexpr float qmax = static_cast(kI8Max); const float rmin0 = std::min(0.0f, min0); const float rmax0 = std::max(0.0f, max0); @@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( zero_point0 = std::max(zero_point0, qmin); zero_point0 = std::min(zero_point0, qmax); - const int32_t nudged_zero_point0 = lrintf(zero_point0); + const int32_t nudged_zero_point0 = std::lrintf(zero_point0); int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride; @@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( const float src0_0 = src_ptr[k_idx]; int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); v0_s32 = std::max( - std::min( - v0_s32 + nudged_zero_point0, static_cast(INT8_MAX)), - static_cast(INT8_MIN)); + std::min(v0_s32 + nudged_zero_point0, static_cast(kI8Max)), + static_cast(kI8Min)); dst_ptr[0] = (int8_t)v0_s32; dst_ptr += sizeof(int8_t); } @@ -1118,6 +1263,11 @@ void 
ref_dyn_quant_matmul_4bit_groupwise_kernel( } main_acc = main_acc * lhs_scale; + + if (bias) { + main_acc += bias[col_idx]; + } + main_acc = std::max(main_acc, scalar_min); main_acc = std::min(main_acc, scalar_max); @@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( } /** - * Dynamic Input Quant 4 bit weights matmul execution flow - (INT4 Weights + FP scales + FP32 Bias) - FP32 Input Packed Buffer - | | - Quantize Cast - to INT8 to INT8 - | | - v v - INT8 Input INT8 Weights - \ / - \ / - \ / - INT8 Matrix Multiplication - | - v - FP32 Dequantized and Accumulate in FP32 - | - v - FP32 Final Output - - * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires - * Float32 Scales. If not provided, we will use fallback implementation. + * Dynamic INT4 weight-only MatMul with per-row input quantization. + * + * Execution Flow: + * + * (INT4 Weights + FP Scales [+ optional Bias]) + * + * Input (FP32 or BF16) Packed Weight Buffer + * | | + * Row-wise Quantization (INT8) | + * | | + * INT8 Input Activation INT4 Quantized Weights + Scales + * \ / + * \ / + * Quantized Matrix Multiply + * | + * Output Tensor (BF16 or FP32) + * + * Notes: + * - Groupwise kernels expect BF16 scales + * - Channelwise kernels expect FP32 scales + * - Bias is currently unsupported in fallback path */ void dyn_quant_matmul_4bit_kernel( const Tensor& output, @@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel( const int64_t block_size) { #if AT_KLEIDIAI_ENABLED() const int64_t weight_packed_size = - kleidiai::kai_pack_rhs_int4_size(N, K, block_size); + kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type()); if (weight_packed_size == packed_weights.numel()) { // KleidiAI interface internally handles the Channelwise and groupwise // distinction - kleidiai::kai_quant_pack_lhs_int4_mm( - output, inp, packed_weights, M, N, K, block_size); + kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size); } else #endif { - float* lhs_f32 = reinterpret_cast(inp.data_ptr()); - const auto weights_size = N * K / 2; - // The weights needs to be in uint8_t data type after quantization - auto extracted_weights = - (packed_weights.narrow(0, 0, weights_size)).to(kByte); - auto float32_scales = - (packed_weights.narrow( - 0, weights_size, packed_weights.size(0) - weights_size)) - .to(kFloat); - uint8_t* rhs_4bit = - reinterpret_cast(extracted_weights.data_ptr()); - float* rhs_scales_f32 = reinterpret_cast(float32_scales.data_ptr()); - float* dst_f32 = reinterpret_cast(output.data_ptr()); - if (block_size == K) { - ref_dyn_quant_matmul_4bit_channelwise_kernel( - M, - N, - K, - lhs_f32, - rhs_4bit, - rhs_scales_f32, - dst_f32, - -FLT_MAX, - FLT_MAX); - } else if (!(block_size % 32) && !(K % block_size)) { - ref_dyn_quant_matmul_4bit_groupwise_kernel( - M, - N, - K, - block_size, - lhs_f32, - rhs_4bit, - rhs_scales_f32, - dst_f32, - -FLT_MAX, - FLT_MAX); + { + void* input = inp.data_ptr(); + void* dst = output.data_ptr(); + + // Extract weights, sclaes and biases form from packed tensor + const int weights_elements = N * K / 2; + const int scale_elements = N * (K / block_size); + TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size"); + + auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte); + auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat); + auto float32_scales = 
extracted_scales_and_bias.narrow(0, 0, scale_elements); + + int bias_elements = packed_weights.numel() - (weights_elements + scale_elements); + float* weight_scales = float32_scales.data_ptr(); + + void* bias_data = nullptr; + if (bias_elements) { + auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements); + TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension"); + bias_data = float32_bias.data_ptr(); + + } + // 2 elements of 4 bit weights are packed into 1 uint8 packet + uint8_t* weights_4bit = reinterpret_cast(extracted_weights.data_ptr()); + + // Dispatch to reference kernels + if (inp.scalar_type() == at::kBFloat16) { + // BF16 input, BF16 output + constexpr float BF16_MAX = 3.38953139e+38f; + constexpr float BF16_MIN = -BF16_MAX; + if (block_size == K) { + ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16( + M, N, K, + (uint16_t*)input, weights_4bit, weight_scales, + (uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data); + } else { + TORCH_CHECK(false, "Unsupported block size for BF16 fallback"); + } + } else if (inp.scalar_type() == at::kFloat) { + // FP32 input, FP32 output + if (block_size == K) { + ref_dyn_quant_matmul_4bit_channelwise_kernel( + M, N, K, + (float*)input, weights_4bit, weight_scales, + (float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data); + } else if (!(block_size % 32) && !(K % block_size)) { + ref_dyn_quant_matmul_4bit_groupwise_kernel( + M, N, K, block_size, + (float*)input, weights_4bit, weight_scales, + (float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data); + } else { + TORCH_CHECK(false, "Unsupported block size for FP32 fallback"); + } } else { - TORCH_CHECK( - block_size == K || (!(block_size % 32) && !(K % block_size)), - __func__, - ": Group size should be multiple 32 or in_features [", - K, - "]. Provided ", - block_size); + TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel"); } - } } - +} } // anonymous namespace - +} ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel) ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel) REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel) diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h index 8aba425e89637..8fa84b4445798 100644 --- a/aten/src/ATen/native/cpu/moments_utils.h +++ b/aten/src/ATen/native/cpu/moments_utils.h @@ -46,8 +46,11 @@ C10_ALWAYS_INLINE void AddMomentsVec( const T c = n == 0 ? 
static_cast(0) : static_cast(m0_add) / static_cast(n); const Vec c_vec(c); const Vec delta = m1_add - m1; - m1 += c_vec * delta; - m2 += m2_add + delta * delta * c_vec * Vec(static_cast(m0)); + const Vec m2_tmp = m2 + m2_add; + const Vec c_vec_delta = c_vec * delta; + const Vec m0_delta = delta * Vec(static_cast(m0)); + m1 = m1 + c_vec_delta; + m2 = fmadd(m0_delta, c_vec_delta, m2_tmp); m0 = n; } @@ -65,9 +68,11 @@ UpdateMomentsVec( Vec m2_vec(0); for (const auto j : c10::irange(m0)) { const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size()); + const Vec tmpVec = c_vecs[j]; const Vec delta_vec = x_vec - m1_vec; - m1_vec += delta_vec * c_vecs[j]; - m2_vec += delta_vec * (x_vec - m1_vec); + m1_vec = fmadd(tmpVec, delta_vec, m1_vec); + const Vec tmpVec2 = x_vec - m1_vec; + m2_vec = fmadd(delta_vec, tmpVec2, m2_vec); } AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0); } @@ -89,13 +94,16 @@ UpdateMomentsVec( fVec m2_fvec0(0), m2_fvec1(0); for (const auto j : c10::irange(m0)) { const Vec x_bvec = Vec::loadu(X_ptr + j * Vec::size()); + const fVec tmpVec = c_vecs[j]; auto [x_fvec0, x_fvec1] = convert_to_float(x_bvec); const fVec delta_fvec0 = x_fvec0 - m1_fvec0; const fVec delta_fvec1 = x_fvec1 - m1_fvec1; - m1_fvec0 += delta_fvec0 * c_vecs[j]; - m1_fvec1 += delta_fvec1 * c_vecs[j]; - m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0); - m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1); + m1_fvec0 = fmadd(delta_fvec0, tmpVec, m1_fvec0); + m1_fvec1 = fmadd(delta_fvec1, tmpVec, m1_fvec1); + const fVec delta_fvec2 = x_fvec0 - m1_fvec0; + const fVec delta_fvec3 = x_fvec1 - m1_fvec1; + m2_fvec0 = fmadd(delta_fvec0, delta_fvec2, m2_fvec0); + m2_fvec1 = fmadd(delta_fvec1, delta_fvec3, m2_fvec1); } AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0); AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0); diff --git a/aten/src/ATen/native/cuda/ActivationEluKernel.cu b/aten/src/ATen/native/cuda/ActivationEluKernel.cu index 5ad1f806f9ba5..9fc29aa5539b5 100644 --- a/aten/src/ATen/native/cuda/ActivationEluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationEluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu index cd5a0ae85e61c..87781c44e3348 100644 --- a/aten/src/ATen/native/cuda/ActivationGeluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGeluKernel.cu @@ -5,7 +5,6 @@ #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationGluKernel.cu b/aten/src/ATen/native/cuda/ActivationGluKernel.cu index e28a6d61ea152..8a782a129c9fb 100644 --- a/aten/src/ATen/native/cuda/ActivationGluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationGluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu index 2a0be3f5d27bf..f0968b957aa6d 100644 --- a/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardshrinkKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu index fcacef37ceaf0..813a8c07ccfac 100644 --- a/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardsigmoidKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff 
--git a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu index 1642d0909f7f0..651cdef82543b 100644 --- a/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardswishKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu index a18072f7a27bc..85aa7ccd22a9e 100644 --- a/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationHardtanhKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu index 72130739898fe..340a6f97d00de 100644 --- a/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLeakyReluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu index 9a1d672428b48..2175920917852 100644 --- a/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationLogSigmoidKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationMishKernel.cu b/aten/src/ATen/native/cuda/ActivationMishKernel.cu index 0db0e96bb180a..25ba9810e37cf 100644 --- a/aten/src/ATen/native/cuda/ActivationMishKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationMishKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu index f7ddfd8502a18..ebdfe245b6166 100644 --- a/aten/src/ATen/native/cuda/ActivationSiluKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSiluKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu index 64ffc21123707..65f4f3679f862 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftplusKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu index 0c2dc63dbcf45..712c86e0e5216 100644 --- a/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationSoftshrinkKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu index 2d1cb4a47d7d8..430f9cbfa78bb 100644 --- a/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu +++ b/aten/src/ATen/native/cuda/ActivationThresholdKernel.cu @@ -5,8 +5,6 @@ #include -#include - #include #include #include diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 0d34bd52f211a..169a2ab92615f 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -29,7 +29,7 @@ Scalar _local_scalar_dense_cuda(const Tensor& self) { std::nullopt /* memory format */ ); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - 
at::cuda::memcpy_and_sync((void *)value.const_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); + at::cuda::memcpy_and_sync(value.mutable_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(*value.const_data_ptr()); }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); return r; diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index a917b0d6163fa..3f4f998d92cd6 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -346,8 +346,9 @@ void dispatch_bf16_grouped_kernel_on_tile_size( bool small = (M <= 128 || N <= 128); cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); const bool sm10x = properties != nullptr && properties->major == 10; + const bool sm11x = properties != nullptr && properties->major == 11; - if (sm10x) { + if (sm10x || sm11x) { if (small){ bf16bf16_grouped_gemm_impl_sm90_sm100< cutlass::arch::Sm100, diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index fd406829707a1..5c8b98105bb26 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -5,11 +5,69 @@ #include #endif +// ROCm 6.3 is planned to have these functions, but until then here they are. #if defined(USE_ROCM) #include #include #include -#define ATOMICADD unsafeAtomicAdd + +__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); + union { + __hip_bfloat162_raw bf162_raw; + vec_short2 vs2; + } u{static_cast<__hip_bfloat162_raw>(value)}; + u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2); + return static_cast<__hip_bfloat162>(u.bf162_raw); +#else + static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw)); + union u_hold { + __hip_bfloat162_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} + +__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16) + // The api expects an ext_vector_type of half + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + static_assert(sizeof(vec_fp162) == sizeof(__half2_raw)); + union { + __half2_raw h2r; + vec_fp162 fp16; + } u {static_cast<__half2_raw>(value)}; + u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16); + return static_cast<__half2>(u.h2r); +#else + static_assert(sizeof(__half2_raw) == sizeof(unsigned int)); + union u_hold { + __half2_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while 
(!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} +#define ATOMICADD preview_unsafeAtomicAdd #define NATIVE_ZERO_BF16 __float2bfloat16(0.0f) #else #define ATOMICADD atomicAdd diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index a80c51fa6a9cb..e739d7d2ecee2 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -282,7 +282,7 @@ void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) using traits = function_traits; using output_t = typename traits::result_type; static_assert(is_tuple::value, "f's return type must be `thrust::tuple`"); - constexpr int num_outputs = thrust::tuple_size::value; + constexpr int num_outputs = std::tuple_size::value; constexpr int num_inputs = traits::arity; constexpr int ntensors = num_outputs + num_inputs; diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index c6d3c25200d50..4c5eabd049687 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -242,7 +242,7 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const int64_t max_target_length = 0; auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong)); - auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr(); + auto tg_batch_offsets_data = tg_batch_offsets.template mutable_data_ptr(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; for (int64_t i = 0; i < batch_size; i++) { @@ -304,12 +304,12 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const ctc_loss_log_alpha_gpu_kernel<<>>( log_alpha.mutable_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), log_probs.size(0), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), max_target_length, + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), log_probs.size(0), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), max_target_length, neg_log_likelihood.mutable_data_ptr(), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, BLANK); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(neg_log_likelihood, log_alpha); @@ -613,7 +613,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ int64_t max_target_length; auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong))); - auto tg_batch_offsets_data = tg_batch_offsets.mutable_data_ptr(); + auto tg_batch_offsets_data = tg_batch_offsets.template mutable_data_ptr(); if (targets.dim() == 1) { // concatenated targets int64_t pos = 0; max_target_length = 0; @@ -663,11 +663,11 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ dim3 grid(1, (batch_size+threads_batch-1)/threads_batch); ctc_loss_backward_log_beta_gpu_kernel<<>> (log_beta.mutable_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), log_probs.size(0), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), max_target_length, + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), log_probs.size(0), + targets.const_data_ptr(), target_lengths_t.template 
const_data_ptr(), max_target_length, log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, BLANK); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -717,14 +717,14 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ (grad.mutable_data_ptr(), grad_out.const_data_ptr(), grad_out.stride(0), log_alpha.const_data_ptr(), log_beta.const_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), neg_log_likelihood.const_data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, zero_infinity); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // small problem, use naive algorithm @@ -740,14 +740,14 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ (grad.mutable_data_ptr(), grad_out.const_data_ptr(), grad_out.stride(0), log_alpha.const_data_ptr(), log_beta.const_data_ptr(), - log_probs.const_data_ptr(), input_lengths_t.const_data_ptr(), log_probs.size(0), - targets.const_data_ptr(), target_lengths_t.const_data_ptr(), max_target_length, + log_probs.const_data_ptr(), input_lengths_t.template const_data_ptr(), log_probs.size(0), + targets.const_data_ptr(), target_lengths_t.template const_data_ptr(), max_target_length, neg_log_likelihood.const_data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), - tg_batch_offsets.const_data_ptr(), tg_target_stride, + tg_batch_offsets.template const_data_ptr(), tg_target_stride, batch_size, num_labels, BLANK, zero_infinity); C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors } @@ -765,7 +765,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_ (batch_size+threads_batch-1)/threads_batch); ctc_loss_zero_padded_gradients<<>>( grad.mutable_data_ptr(), - input_lengths_t.const_data_ptr(), + input_lengths_t.template const_data_ptr(), grad.stride(0), grad.stride(1), grad.stride(2), diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 8811f8dc5117e..d4eb1b792e7f1 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -212,7 +212,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { std::nullopt /* memory format */ ); at::cuda::memcpy_and_sync( - (void*)pinned_num_nonzeros_h.const_data_ptr(), + pinned_num_nonzeros_h.template data_ptr(), num_nonzeros.get(), sizeof(int) * num_chunks, cudaMemcpyDeviceToHost, @@ -220,7 +220,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { int64_t num_nonzeros_h = 0; for (int64_t idx = 0; idx < num_chunks; idx++) { - num_nonzeros_h += (int)*(pinned_num_nonzeros_h.const_data_ptr() + idx); + num_nonzeros_h += pinned_num_nonzeros_h.template 
const_data_ptr()[idx]; } // num_nonzeros_h = (int)*(pinned_num_nonzeros_h.const_data_ptr()); // expected output size is num_nonzeros x ndim @@ -267,8 +267,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out) { ((int*)num_nonzeros.get()) + idx, remaining, stream)); - curr_nonzeros += - (int)*(pinned_num_nonzeros_h.const_data_ptr() + idx); + curr_nonzeros += pinned_num_nonzeros_h.template const_data_ptr()[idx]; } if (num_nonzeros_h > 0 && self.dim() > 1) { TensorDims dims; diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 382a5a065b300..8971e05094651 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -958,8 +958,9 @@ void dispatch_fp8_rowwise_kernel_on_sm( const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9; const bool sm9x = properties != nullptr && properties->major == 9; const bool sm10x = properties != nullptr && properties->major == 10; + const bool sm11x = properties != nullptr && properties->major == 11; const bool sm12x = properties != nullptr && properties->major == 12; - if (!(sm89 || sm9x || sm10x || sm12x)) { + if (!(sm89 || sm9x || sm10x || sm11x || sm12x)) { TORCH_CHECK( false, "Rowwise scaling is not currently supported on your device"); } @@ -968,7 +969,7 @@ void dispatch_fp8_rowwise_kernel_on_sm( dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< /*ArchTag=*/cutlass::arch::Sm90, Types...>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (sm10x) { + } else if (sm10x || sm11x) { dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< /*ArchTag=*/cutlass::arch::Sm100, Types...>(XQ, WQ, x_scale, w_scale, bias, out); diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index ab38c1975d147..031e5b3c4f14e 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -44,16 +44,13 @@ void isneginf_kernel_impl(TensorIteratorBase &iter) { void clamp_kernel_impl(TensorIteratorBase& iter) { AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_cuda", [&] { gpu_kernel(iter, []GPU_LAMBDA(scalar_t v, scalar_t lower, scalar_t upper) -> scalar_t { - // Propagate nan, which doesn't propagate automatically for ROCm - if (at::_isnan(v)) { - return v; - } if (at::_isnan(lower)) { - return lower; - } if (at::_isnan(upper)) { - return upper; - } else { - return ::min(::max(v, lower), upper); - } + scalar_t result = ::min(::max(v, lower), upper); + + result = at::_isnan(upper) ? upper : result; + result = at::_isnan(lower) ? lower : result; + result = at::_isnan(v) ? 
v : result; + + return result; }); }); } diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index b46bbaa6500b9..5ccc1143d4dc1 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -806,7 +806,7 @@ static void upsample_gen2d_aa_out_cuda_template( using accscalar_t = at::acc_type; auto idata = input.packed_accessor64(); - auto odata = output_c.packed_accessor64(); + auto odata = output_c.template packed_accessor64(); const accscalar_t height_scale = area_pixel_compute_scale( input_height, output_height, align_corners, scales_h); diff --git a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu index aae4625f6a39e..159de54156dd6 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest3d.cu @@ -189,7 +189,7 @@ static void upsample_nearest3d_out_cuda_template( using accscalar_t = at::acc_type; auto idata = input.const_data_ptr(); - auto odata = output_c.mutable_data_ptr(); + auto odata = output_c.template mutable_data_ptr(); const float depth_scale = compute_scales_value(scales_d, input_depth, output_depth); const float height_scale = compute_scales_value(scales_h, input_height, output_height); diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index cd48c16a32eb9..d144a9954ed33 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -3,7 +3,6 @@ #include -#include #include #include @@ -63,9 +62,7 @@ __global__ void RowwiseMomentsCUDAKernel( val_shared_ptr); } if (threadIdx.x == 0) { - T_ACC m1; - T_ACC m2; - thrust::tie(m2, m1) = welford_op.project(val); + auto [m2, m1] = welford_op.project(val); mean[i] = m1; rstd[i] = c10::cuda::compat::rsqrt(m2 + static_cast(eps)); } diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 730a7ea910961..937008f1e83bd 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -1,10 +1,9 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include -#include - #include #include #include @@ -86,9 +85,7 @@ __global__ void RowwiseMomentsCUDAKernel( val_shared_ptr); if (threadIdx.x == 0) { - T_ACC m1; - T_ACC m2; - thrust::tie(m2, m1) = welford_op.project(val); + auto [m2, m1] = welford_op.project(val); if constexpr (!rms_norm){ mean[i] = m1; rstd[i] = c10::cuda::compat::rsqrt(m2 + eps); diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index d5102910c6471..74a3f0afb9c11 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -774,7 +774,7 @@ void raw_cudnn_convolution_forward_out_32bit( args, "Forward algorithm: ", static_cast(fwdAlgPerf.algo), - "\n"); + '\n'); }); } diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.cpp b/aten/src/ATen/native/kleidiai/kai_kernels.cpp index ce0f10bf6df1f..1313f98f90109 100644 --- a/aten/src/ATen/native/kleidiai/kai_kernels.cpp +++ b/aten/src/ATen/native/kleidiai/kai_kernels.cpp @@ -21,18 +21,27 @@ void kai_pack_int4_rhs( const int64_t n, const int64_t k, const int64_t bl) { - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases if (bl == k) { // Channelwise - auto kernel_packet = kai_select_channelwise_matmul_ukernel( - kai_kernel_id:: - 
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); - auto& params = kernel_packet.rhs_pack_params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - kai_pack_rhs_channelwise_int4( - kernel_packet, weight_packed, weight, scales, bias, n, k); + if (weight.scalar_type() == at::kBFloat16) { + auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod); + auto& params = kernel_packet.rhs_pack_params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + kai_pack_rhs_channelwise_int4( + kernel_packet, weight_packed, weight, scales, bias, n, k); + } else { + auto kernel_packet = kai_select_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); + auto& params = kernel_packet.rhs_pack_params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + kai_pack_rhs_channelwise_int4( + kernel_packet, weight_packed, weight, scales, bias, n, k); + } } else if (!(bl % 32) && !(k % bl)) { // Groupwise auto kernel_packet = kai_select_groupwise_matmul_ukernel( @@ -63,19 +72,29 @@ void kai_pack_int4_rhs( size_t kai_pack_rhs_int4_size( const int64_t n, const int64_t k, - const int64_t bl) { + const int64_t bl, + at::ScalarType tensor_dtype) { size_t packed_size = n * k; - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases if (bl == k) { - // Channelwise - auto kernel_packet = kai_select_channelwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); - const auto& ukernel = kernel_packet.ukernel; - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + if (tensor_dtype == at::kBFloat16) { + auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod); + const auto& ukernel = kernel_packet.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + } else { + auto kernel_packet = kai_select_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); + const auto& ukernel = kernel_packet.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + } } else if (!(bl % 32) && !(k % bl)) { // Groupwise auto kernel_packet = kai_select_groupwise_matmul_ukernel( @@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise( const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride; const int64_t m_idx = thread_id * vec_per_thread; auto lhs_packed_ptr = lhs_packed_base + - kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( - m_idx, k, mr, kr, sr); + kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr); const int64_t vec_num = (thread_id == num_threads - 1) ? 
(m - vec_per_thread * thread_id) : vec_per_thread; @@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise( const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride; const int64_t m_idx = thread_id * vec_per_thread; auto lhs_packed_ptr = lhs_packed_base + - kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( - m_idx, k, mr, kr, sr); + kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr); const int64_t vec_num = (thread_id == num_threads - 1) ? (m - vec_per_thread * thread_id) : vec_per_thread; @@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise( }); } -void kai_quant_pack_lhs_int4_mm( +static void kai_quant_pack_lhs_int4_mm_bf16_channelwise( const Tensor& output, const Tensor& input, const Tensor& weight, const int64_t m, const int64_t n, + const int64_t k) { + // Kernel IDs for GEMM and GEMV + constexpr kai_kernel_id gemm_id = + kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm; + constexpr kai_kernel_id gemv_id = + kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod; + + // Get total threads and select kernel + const int64_t total_threads = at::get_num_threads(); + auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id); + if (cpuinfo_has_arm_i8mm() && m > 1) { + kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id); + } + + // Thread blocking parameters + const int64_t n_step = kernel_packet.ukernel.get_n_step(); + const size_t mr = kernel_packet.ukernel.get_mr(); + const size_t kr = kernel_packet.ukernel.get_kr(); + const size_t sr = kernel_packet.ukernel.get_sr(); + + const size_t lhs_packed_size = + kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr); + auto lhs_packed = std::make_unique(lhs_packed_size); + uint8_t* dst_act_mtx_bf16 = reinterpret_cast(output.data_ptr()); + const uint8_t* lhs_native_mtx_bf16 = + reinterpret_cast(input.data_ptr()); + const uint8_t* rhs_packed_mtx_qs4cx = + reinterpret_cast(weight.data_ptr()); + uint8_t* lhs_packed_base = lhs_packed.get(); + + constexpr int32_t element_size = sizeof(uint16_t); + const size_t lhs_stride = k * element_size; + const size_t dst_stride = n * element_size; + + // LHS quantization packing + int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr); + int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread; + const size_t src_stride = vec_per_thread * lhs_stride; + + auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) { + const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride; + const int64_t m_idx = thread_id * vec_per_thread; + auto lhs_packed_ptr = lhs_packed_base + + kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr); + const int64_t vec_num = (thread_id == num_threads - 1) + ? 
(m - vec_per_thread * thread_id) + : vec_per_thread; + + kernel_packet.kai_run_lhs_quant_pack( + vec_num, + k, + mr, + kr, + sr, + 0, + (const uint16_t*)lhs_src_ptr, + lhs_stride, + lhs_packed_ptr); + }; + + at::parallel_for( + 0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) { + for (int64_t thread_id = begin; thread_id < end; ++thread_id) { + lhs_quant_pack(thread_id); + } + }); + + // Matrix multiplication + vec_per_thread = get_vec_per_thread(n, total_threads, n_step); + num_threads = (n + vec_per_thread - 1) / vec_per_thread; + + auto mm = [=, &kernel_packet](int64_t thread_id) { + const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx + + kernel_packet.ukernel.get_rhs_packed_offset( + thread_id * vec_per_thread, k); + auto dst_ptr = dst_act_mtx_bf16 + + kernel_packet.ukernel.get_dst_offset( + 0, thread_id * vec_per_thread, dst_stride); + const int64_t vec_num = (thread_id == num_threads - 1) + ? (n - vec_per_thread * thread_id) + : vec_per_thread; + + kernel_packet.ukernel.run_matmul( + m, + vec_num, + k, + lhs_packed_base, + rhs_packed_ptr, + (uint16_t*)dst_ptr, + dst_stride, + element_size, // dst_stride_col + -FLT_MAX, + FLT_MAX); + }; + + at::parallel_for( + 0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) { + for (int64_t thread_id = begin; thread_id < end; ++thread_id) { + mm(thread_id); + } + }); +} +void kai_quant_pack_lhs_int4_mm( + const at::Tensor& output, + const at::Tensor& input, + const at::Tensor& weight, + const int64_t m, + const int64_t n, const int64_t k, const int64_t bl) { // Prefer Channelwise kernel over Groupwise kernel for conflicting cases if (bl == k) { - kleidiai::kai_quant_pack_lhs_int4_mm_channelwise( - output, input, weight, m, n, k); - } else if (!(bl % 32) && !(k % bl)) { + const auto input_dtype = input.dtype(); + + if (input_dtype == at::kBFloat16) { + if (cpuinfo_has_arm_bf16()) { + kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise( + output, input, weight, m, n, k); + } else { + TORCH_CHECK( + false, + "BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support."); + } + } else if (input_dtype == at::kFloat) { + kleidiai::kai_quant_pack_lhs_int4_mm_channelwise( + output, input, weight, m, n, k); + } else { + TORCH_CHECK( + false, + "Unsupported input data type: Only Bfloat16 and Float inputs are supported."); + } + } else if ((bl % 32 == 0) && (k % bl == 0)) { kleidiai::kai_quant_pack_lhs_int4_mm_groupwise( output, input, weight, m, n, k, bl); } diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.h b/aten/src/ATen/native/kleidiai/kai_kernels.h index 9b522d7f7705a..a4179cefd06cf 100644 --- a/aten/src/ATen/native/kleidiai/kai_kernels.h +++ b/aten/src/ATen/native/kleidiai/kai_kernels.h @@ -25,7 +25,8 @@ void kai_pack_int4_rhs( size_t kai_pack_rhs_int4_size( const int64_t n, const int64_t k, - const int64_t bl); + const int64_t bl, + at::ScalarType tensor_dtype = at::kFloat); /** * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul ) diff --git a/aten/src/ATen/native/kleidiai/kai_pack.h b/aten/src/ATen/native/kleidiai/kai_pack.h index 4ff3371ab5e2a..d9f08333591ed 100644 --- a/aten/src/ATen/native/kleidiai/kai_pack.h +++ b/aten/src/ATen/native/kleidiai/kai_pack.h @@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4( AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); } - float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; + float* bias_ptr = + bias.has_value() ? 
bias.value().to(kFloat).data_ptr() : NULL; auto& params = kernel.rhs_pack_params; kernel.kai_run_rhs_pack( @@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4( auto weight_packed_data = reinterpret_cast(weight_packed.data_ptr()); const auto weight_data = weight.data_ptr(); - const auto scales_data = scales.data_ptr(); + + const auto scales_data = scales.to(kFloat).data_ptr(); if (weight_data == nullptr) { AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); @@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4( AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); } - float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; + float* bias_ptr = + bias.has_value() ? bias.value().to(kFloat).data_ptr() : NULL; auto& params = kernel.rhs_pack_params; kernel.kai_run_rhs_pack( diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp index 0de198d7dc012..783133b83e670 100644 --- a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp +++ b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp @@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel( const kai_kernel_id id) { return channelwise_8bit_4bit_kernels.at(id); } + +// Kernel Mapping - BF16 Channelwise +std::unordered_map + bf16_channelwise_8bit_4bit_kernels = { + {kai_kernel_id:: + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + {{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod, + kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}}, + {kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + {{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm, + kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}}; + +kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel( + const kai_kernel_id id) { + return bf16_channelwise_8bit_4bit_kernels.at(id); +} } // namespace at::native::kleidiai #endif diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h index 8480469cdea86..cfcf7a81ba85f 
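// Not part of the patch: a condensed C++ sketch of the dispatch rules the
// kleidiai changes above add for 4-bit weights (Path and select_int4_kernel_path
// are illustrative names only). Channelwise kernels are used when the block
// length equals k; the new bf16 channelwise path additionally requires CPU bf16
// support and prefers the i8mm GEMM kernel only when m > 1, falling back to the
// dotprod GEMV kernel otherwise.
#include <cstdint>

enum class Path {
  Bf16ChannelwiseGemm,   // matmul_clamp_bf16_..._8x8_neon_i8mm
  Bf16ChannelwiseGemv,   // matmul_clamp_bf16_..._1x8_neon_dotprod
  F32Channelwise,
  Groupwise,
  Unsupported
};

Path select_int4_kernel_path(bool input_is_bf16, bool cpu_has_bf16,
                             bool cpu_has_i8mm, int64_t m, int64_t k,
                             int64_t bl) {
  if (bl == k) {  // channelwise: one scale per output channel
    if (input_is_bf16) {
      if (!cpu_has_bf16) return Path::Unsupported;
      return (cpu_has_i8mm && m > 1) ? Path::Bf16ChannelwiseGemm
                                     : Path::Bf16ChannelwiseGemv;
    }
    return Path::F32Channelwise;
  }
  // Groupwise: block length must be a multiple of 32 and divide k.
  if (bl > 0 && bl % 32 == 0 && k % bl == 0) {
    return Path::Groupwise;
  }
  return Path::Unsupported;
}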
100644 --- a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h +++ b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h @@ -10,21 +10,32 @@ #include #include #include +#include +#include +#include #include +#include #include #include namespace at::native::kleidiai { enum class kai_kernel_id { + // FP32 inputs, 4-bit weights, FP32 output matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod = - 0, // Groupwise 4 bit GEMV + 0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD) matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm = - 1, // Groupwise 4 bit GEMM + 1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM) matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod = - 2, // Channelwise 4 bit GEMV + 2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD) matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm = - 3 // Channelwise 4 bit GEMM + 3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM) + + // BF16 inputs, 4-bit weights, BF16 output + matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod = + 4, // Channelwise 4-bit GEMV with BF16 input/output + matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm = + 5 // Channelwise 4-bit GEMM with BF16 input/output }; // Channelwise Kernel mapping @@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + size_t(*kai_get_lhs_quant_pack_offset)( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr + ); kai_matmul_ukernel_f32_qa8dxp_qs4cxp( const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel) @@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { kai_get_rhs_packed_size( &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), - kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {} + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){} }; struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(const kai_kernel_id id); +// bf16 Channelwise Kernel mapping +struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp { + struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const void* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* rhs, + const float* bias, + const float* scale, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + size_t(*kai_get_lhs_quant_pack_offset)( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr + ); + + kai_matmul_ukernel_bf16_qa8dxp_qs4cxp( + const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon), + 
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){} + }; + +struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp +kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id); + // Groupwise Kernel mapping struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; @@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); + size_t(*kai_get_lhs_quant_pack_offset)( + size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr + ); kai_matmul_ukernel_f32_qa8dxp_qs4c32p( const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel) @@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { kai_get_rhs_packed_size( &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), - kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {} + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), + kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {} }; struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index d9f126938b301..fcdf39b8a9f4b 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -147,6 +147,19 @@ class MetalShaderLibrary { const std::optional alpha = std::nullopt, const std::optional scalar_arg_type = std::nullopt); + template + void exec_unary_kernel_with_params( + TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name); + template + void exec_binary_kernel_with_params( + TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name); + protected: virtual MTLLibrary_t getLibrary(); virtual MTLLibrary_t getLibrary( diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index cb488a3f5f117..5ca0ebe3de9bb 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -7,10 +7,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -630,4 +632,147 @@ inline bool needsGather(const TensorBase& t) { return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()); } +template +void MetalShaderLibrary::exec_unary_kernel_with_params(TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name) { + using namespace at::mps; + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_unary_kernel_with_params(sub_iter, name, params, params_type_name); + } + return; + } + + auto inputTensor = iter.input(0); + auto outputTensor = iter.output(0); + uint32_t length = iter.numel(); + if (length == 0) { + return; + } + auto kernel_name = fmt::format("{}_{}_{}_{}{}", + name, + iter.is_contiguous() ? 
"dense" : "strided", + scalarToMetalTypeString(outputTensor), + scalarToMetalTypeString(inputTensor), + fmt::format("_{}", params_type_name)); + @autoreleasepool { + auto cplState = getPipelineStateForFunc(kernel_name); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + auto computeEncoder = mpsStream->commandEncoder(); + + getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor}); + + [computeEncoder setComputePipelineState:cplState]; + bind_iter_tensors(computeEncoder, iter); + if (!iter.is_contiguous()) { + mtl_setArgs<2>(computeEncoder, + outputTensor.sizes(), + inputTensor.strides(), + outputTensor.strides(), + inputTensor.ndimension()); + } + detail::mtl_setArg(computeEncoder, params, iter.is_contiguous() ? 2 : 6); + mtl_dispatch1DJob(computeEncoder, cplState, length); + + getMPSProfiler().endProfileKernel(cplState); + }); + } +} + +template +void MetalShaderLibrary::exec_binary_kernel_with_params(TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name) { + using namespace mps; + // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?) + // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with + // double as common dtype (because Python floating point are always 64-bit values) + TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS"); + + // Skip for empty iterators + if (iter.numel() == 0) { + return; + } + + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_binary_kernel_with_params(sub_iter, name, params, params_type_name); + } + return; + } + + auto convert_double_scalar = [](Tensor& t) { + if (t.dim() != 0) { + return; + } + if (t.scalar_type() == kDouble) { + t = t.to(kFloat); + } else if (t.scalar_type() == kComplexDouble) { + t = t.to(kComplexFloat); + } + }; + + Tensor input = iter.input(0); + Tensor other = iter.input(1); + Tensor out = iter.output(); + + convert_double_scalar(input); + convert_double_scalar(other); + + MPSStream* mpsStream = getCurrentMPSStream(); + const auto cast_needed = input.scalar_type() != other.scalar_type(); + const auto suffix = iter.is_contiguous() ? "dense" : "strided"; + // TODO: Implicitly pass both input and output types to non-cast kernels + const auto kernel_name = cast_needed + ? fmt::format("{}_{}_cast_{}_{}", name, suffix, scalarToMetalTypeString(out), params_type_name) + : fmt::format("{}_{}_{}_{}_{}", + name, + suffix, + scalarToMetalTypeString(out), + scalarToMetalTypeString(input), + params_type_name); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = mpsStream->commandEncoder(); + auto binaryPSO = getPipelineStateForFunc(kernel_name); + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other}); + [computeEncoder setComputePipelineState:binaryPSO]; + // Set input and output tensors + bind_iter_tensors(computeEncoder, iter); + // Iterator is contiguous if all of its elements are dense in storage, + // i.e. 
it's true for both row-first and column-first tensors + if (iter.is_contiguous()) { + detail::mtl_setArg(computeEncoder, params, 3); + if (cast_needed) { + std::array size_and_types = {static_cast(c10::elementSize(input.scalar_type())), + static_cast(c10::elementSize(other.scalar_type())), + static_cast(input.scalar_type()), + static_cast(other.scalar_type())}; + mtl_setBytes(computeEncoder, size_and_types, 4); + } + } else { + // Please note that shapes and strides of the iterator might be + // different than that of its operands, for example binary op + // between 4x4 tensor and scalar will result in 1D 16 element iterator + std::array ndim_and_types = {iter.ndim(), + static_cast(input.scalar_type()), + static_cast(other.scalar_type()), + static_cast(out.scalar_type())}; + mtl_setArgs<3>( + computeEncoder, params, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), ndim_and_types); + } + mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel()); + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); +} + } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/kernels/Activation.h b/aten/src/ATen/native/mps/kernels/Activation.h new file mode 100644 index 0000000000000..34ad90dd7a2a3 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Activation.h @@ -0,0 +1,16 @@ +#pragma once + +template +struct ELUParams { + T alpha; + T scale; + T input_scale; +}; + +template +struct ELUBackwardParams { + T alpha; + T scale; + T input_scale; + bool is_result; +}; diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal index ae1fda66c3b38..7d1f3aa5bacf6 100644 --- a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -1,3 +1,4 @@ +#include #include #include #include @@ -99,6 +100,59 @@ REGISTER_BINARY_OP(hardswish_backward, float, float); REGISTER_BINARY_OP(hardswish_backward, half, half); REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); +struct elu_functor { + template + inline T operator()(const T self_, const ELUParams params) { + using op_T = opmath_t; + auto alpha = static_cast(params.alpha); + auto scale = static_cast(params.scale); + auto input_scale = static_cast(params.input_scale); + auto self = static_cast(self_); + auto neg_res = alpha * (::metal::precise::exp(self * input_scale) - 1); + return static_cast(scale * (self < 0 ? neg_res : self)); + } +}; + +struct elu_backward_functor { + template + inline T operator()( + const T grad_output_, + const T self_, + ELUBackwardParams params) { + using op_T = opmath_t; + auto alpha = static_cast(params.alpha); + auto scale = static_cast(params.scale); + auto input_scale = static_cast(params.input_scale); + auto grad_output = static_cast(grad_output_); + auto self = static_cast(self_); + + if (params.is_result) { + auto neg_coef = input_scale * (self + alpha * scale); + return static_cast(grad_output * (self <= 0 ? neg_coef : scale)); + } else { + auto neg_coef = input_scale * alpha * scale * + ::metal::precise::exp(self * input_scale); + return static_cast(grad_output * (self <= 0 ? 
neg_coef : scale)); + } + } +}; + +#define REGISTER_ELU_OP(T) \ + typedef ELUParams ELUParams_##T; \ + REGISTER_UNARY_ALPHA_OP(elu, T, ELUParams_##T, T); + +REGISTER_ELU_OP(float); +REGISTER_ELU_OP(half); +REGISTER_ELU_OP(bfloat); + +#define REGISTER_ELU_BACKWARD_OP(T) \ + typedef ELUBackwardParams ELUBackwardParams_##T; \ + REGISTER_BINARY_ALPHA_OP(elu_backward, T, ELUBackwardParams_##T, T); + +REGISTER_ELU_BACKWARD_OP(float); +REGISTER_ELU_BACKWARD_OP(half); +REGISTER_ELU_BACKWARD_OP(bfloat); + struct leaky_relu_functor { template inline T operator()(const T x, const T negative_slope) { diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index e437ea5ed7989..802c648c888d5 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -11,8 +11,6 @@ #include #include #include -#include -#include #include #include #include @@ -119,6 +117,10 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(log_softmax_mps_out) (const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), + "log_softmax for complex is not supported for MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "log_softmax for bool is not supported for MPS"); using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -162,6 +164,10 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(log_softmax_backward_mps_out) (const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) { + TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()), + "log_softmax for complex is not supported for MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kBool, "log_softmax for bool is not supported for MPS"); using namespace mps; using CachedGraph = MPSUnaryGradCachedGraph; @@ -202,6 +208,7 @@ Tensor relu_mps(const Tensor& self) { } std::tuple log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // NOTE: buffer is only used by CPU dispatch, we just ignore it here using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -698,194 +705,6 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } } -static void elu_variants_out_mps(const Tensor& self, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - const Tensor& result, - std::string func_name) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - auto resultMemFormat = result.suggest_memory_format(); - bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat)); - Tensor out; - if (executeGatherOp) { - out = at::empty_like(result, MemoryFormat::Contiguous); - } - - // Empty output - if (result.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to()) + ":" + - std::to_string(scale.to()) + ":" + std::to_string(input_scale.to()); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, 
[&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - // scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) - - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - - MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:inputScaleTensor - name:nil]; - MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; - MPSGraphTensor* exponentMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:exponentTensor - secondaryTensor:unitTensor - name:nil]; - MPSGraphTensor* alphaTimesTensor = [mpsGraph multiplicationWithPrimaryTensor:exponentMinusOneTensor - secondaryTensor:alphaTensor - name:nil]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* fusedOutput = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:inputTensor - falsePredicateTensor:alphaTimesTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:fusedOutput - secondaryTensor:scaleTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? 
out : result, nil, false); - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - if (out.has_storage()) { - result.copy_(out); - } - } -} - -// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) -TORCH_IMPL_FUNC(elu_out_mps) -(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result) { - elu_variants_out_mps(self, alpha, scale, input_scale, result, "elu_out_mps"); -} - -TORCH_IMPL_FUNC(elu_backward_out_mps) -(const Tensor& grad_output, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - bool is_result, - const Tensor& self_or_result, - const Tensor& grad_input) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - auto gradMemFormat = grad_input.suggest_memory_format(); - bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && - grad_input.is_contiguous(gradMemFormat)); - Tensor out; - if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) { - out = at::empty_like(grad_input, MemoryFormat::Contiguous); - } - - // Empty output - if (grad_input.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + - std::to_string(alpha.to()) + ":" + std::to_string(scale.to()) + ":" + - std::to_string(input_scale.to()) + ":" + std::to_string(is_result); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result); - MPSGraphTensor* lessThanZeroGradTensor = nil; - - if (is_result) { - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor - secondaryTensor:alphaTensor - name:nil]; - auto constMul = scale.to() * input_scale.to(); - MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:resultPlusAlphaTensor - secondaryTensor:constMulTensor - name:nil]; - } else { - MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor - secondaryTensor:inputScaleTensor - name:nil]; - MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; - auto constMul = scale.to() * input_scale.to() * alpha.to(); - MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:expTensor - secondaryTensor:constMulTensor - name:nil]; - } - - MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor - 
secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:scaleTensor - falsePredicateTensor:lessThanZeroGradTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = selfOrResultTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - }); - - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp); - Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp); - Placeholder gradInputPlaceholder = - Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false); - - auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfOrResultPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, gradInputPlaceholder); - if (out.has_storage()) { - grad_input.copy_(out); - } - } -} - TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor& output) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -896,6 +715,7 @@ static void elu_variants_out_mps(const Tensor& self, if (output.numel() == 0) return; + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // this can't pass anyway because a 0-dimensional tensor has "size" 1, which // can't be evenly halved, but give a nicer error message here. TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); @@ -1009,6 +829,7 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int (const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Not implemented for long"); // Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * // \log(1 + \exp(\beta * x))` element-wise. 
// For numerical stability the implementation reverts to the linear function @@ -1159,6 +980,8 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int (const Tensor& self, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS"); if (result.numel() == 0) return; @@ -1207,6 +1030,8 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS"); Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); if (grad_input.numel() == 0) @@ -1396,6 +1221,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // Empty output if (result.numel() == 0) diff --git a/aten/src/ATen/native/mps/operations/ActivationKernel.mm b/aten/src/ATen/native/mps/operations/ActivationKernel.mm index cec8bfa2312e4..f6d3ad986ade0 100644 --- a/aten/src/ATen/native/mps/operations/ActivationKernel.mm +++ b/aten/src/ATen/native/mps/operations/ActivationKernel.mm @@ -1,8 +1,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include #include +#include #include namespace at::native { @@ -41,6 +43,30 @@ static void hardswish_backward_kernel(at::TensorIterator& iter) { lib.exec_binary_kernel(iter, "hardswish_backward"); } +static void elu_kernel(TensorIteratorBase& iter, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_mps", [&]() { + ELUParams params{alpha.to(), scale.to(), input_scale.to()}; + lib.exec_unary_kernel_with_params( + iter, "elu", params, fmt::format("ELUParams_{}", mps::scalarToMetalTypeString(iter.common_dtype()))); + }); +} + +static void elu_backward_kernel(TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_backward_mps", [&]() { + ELUBackwardParams params{ + alpha.to(), scale.to(), input_scale.to(), is_result}; + lib.exec_binary_kernel_with_params( + iter, + "elu_backward", + params, + fmt::format("ELUBackwardParams_{}", mps::scalarToMetalTypeString(iter.common_dtype()))); + }); +} + static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) { lib.exec_unary_kernel(iter, "leaky_relu", negative_slope); } @@ -56,6 +82,8 @@ static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& n REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel); REGISTER_DISPATCH(hardswish_stub, hardswish_kernel); REGISTER_DISPATCH(hardswish_backward_stub, hardswish_backward_kernel); +REGISTER_DISPATCH(elu_stub, elu_kernel); +REGISTER_DISPATCH(elu_backward_stub, elu_backward_kernel); REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel); 
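// The registrations above route ELU through the new Metal functors. As a quick
// reference, this is a minimal host-side C++ sketch of the same math; it is not
// part of the patch, and ref_elu / ref_elu_backward are illustrative names. It
// follows scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1))) and
// the matching backward for both values of is_result.
#include <cmath>

template <typename T>
T ref_elu(T x, T alpha, T scale, T input_scale) {
  const T neg = alpha * (std::exp(x * input_scale) - T(1));
  return scale * (x < T(0) ? neg : x);
}

template <typename T>
T ref_elu_backward(T grad, T self_or_result, T alpha, T scale, T input_scale,
                   bool is_result) {
  // When is_result is true, self_or_result already holds elu(x) rather than x,
  // so the negative-branch coefficient can be recovered without another exp.
  const T neg_coef = is_result
      ? input_scale * (self_or_result + alpha * scale)
      : input_scale * alpha * scale * std::exp(self_or_result * input_scale);
  return grad * (self_or_result <= T(0) ? neg_coef : scale);
}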
REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel); diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 92f2b9c6fbf74..d75456c1ad3f0 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -80,6 +80,11 @@ static void grid_sampler_2d_mps_impl(Tensor& output, MPSGraphTensor* outputTensor_ = nil; }; + // Crashes with + // MPSGraphUtilities.mm:97:0: error: 'mps.sample_grid' op operand #0 must be tensor of mps native type values, but got + // 'tensor<2x3x5x20xcomplex>' + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "grid_sampler_2d is not supported for complex on MPS"); @autoreleasepool { std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index ca19d121bb718..00f9c96b78af8 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -240,7 +240,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, bool check_errors) { using namespace mps; - TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), + TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS doesn't support complex types."); TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False."); @@ -364,8 +364,7 @@ static void linalg_solve_out_mps_impl(const Tensor& A, const Tensor& info) { using namespace mps; - TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), - "linalg.lu_factor(): MPS doesn't support complex types."); + TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS only supports floats."); Tensor A_t, B_t; // If 'left' is false, reinterpret the problem so that Ax = B becomes A^T â‹… (x^T) = B^T // Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output @@ -1058,7 +1057,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const using namespace mps; checkInputsSolver(A, B, left, "linalg.solve_triangular"); - TORCH_CHECK(!A.is_complex() && !B.is_complex(), "linalg.solve.triangular(); Not supported for complex yet!"); + TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat, + "linalg.solve.triangular(); Only float is supported!"); Tensor A_t, B_t; std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr); at::native::resize_output(out, B_t.sizes()); diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index f0bbcdabfa5cd..11ee09d6e23f2 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -416,6 +416,8 @@ static void nllnd_loss_forward_impl(Tensor& output, int64_t reduction, int64_t ignore_index, bool is2D) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(output.scalar_type()), + "nlld_loss for complex is not supported for MPS"); std::vector reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end()); reshapedTarget.push_back(1); @@ -824,6 +826,9 @@ static void smooth_l1_loss_backward_impl(const 
Tensor& grad_output, Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) { std::string op_name = __func__; using namespace mps; + TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "huber_loss for complex is not supported for MPS"); TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") TORCH_CHECK(output.is_mps()); diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 2d466f7c79436..ecd5f12df17f8 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -597,6 +597,7 @@ static void avg_pool2d_template(const Tensor& input, bool count_include_pad, const std::optional divisor_override, const std::string& op_name) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), "Not implemented for complex"); const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt)); const bool is_backward_pass = grad_output.defined(); const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; @@ -915,6 +916,8 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& output, const Tensor& indices) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "Max pooling for complex is not supported for MPS"); bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); if (use_graph) { auto indices_memory_format = indices.suggest_memory_format(); @@ -967,6 +970,8 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& indices, const Tensor& grad_input) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "Max pooling for complex is not supported for MPS"); mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 3747f314adfa1..e634eefee2058 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -269,17 +269,22 @@ static void reduction_out_mps(const Tensor& input_t, name:nil]; castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil]; } else if (reduction_type == MPSReductionType::NANSUM) { - // Create a 0 tensor of the same shape as inputTensor - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; - // Find NaNs - MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; - // Replace NaNs with 0 - MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask - truePredicateTensor:zeros - falsePredicateTensor:castInputTensor - name:nil]; - // Sum - castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; + // Integral types cannot contain NaN, so just do regular sum + if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) { + castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil]; + } else { + 
// Create a 0 tensor of the same shape as inputTensor + auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; + // Find NaNs + auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; + // Replace NaNs with 0 + auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask + truePredicateTensor:zeros + falsePredicateTensor:castInputTensor + name:nil]; + // Sum + castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; + } } MPSGraphTensor* outputTensor = castOutputTensor; @@ -442,6 +447,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, const std::optional& correction, bool keepdim, StdVarType stdVarType) { + TORCH_CHECK_NOT_IMPLEMENTED(input_t.scalar_type() != kLong, "Not implemented for MPS"); using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); diff --git a/aten/src/ATen/native/mps/operations/SoftMax.mm b/aten/src/ATen/native/mps/operations/SoftMax.mm index 8f70e216dcae8..8eb24d0cb68bf 100644 --- a/aten/src/ATen/native/mps/operations/SoftMax.mm +++ b/aten/src/ATen/native/mps/operations/SoftMax.mm @@ -39,6 +39,7 @@ static void get_shapes(MPSShape* input_shape_readonly, TORCH_IMPL_FUNC(softmax_mps_out) (const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) { TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS"); + TORCH_CHECK(c10::isFloatingType(input_.scalar_type()), "softmax only supported for floating types"); static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); if (input_.numel() == 0) { diff --git a/aten/src/ATen/native/mps/operations/SummaryOps.mm b/aten/src/ATen/native/mps/operations/SummaryOps.mm index e709ec2d4f618..21cae885c3685 100644 --- a/aten/src/ATen/native/mps/operations/SummaryOps.mm +++ b/aten/src/ATen/native/mps/operations/SummaryOps.mm @@ -18,6 +18,10 @@ MPSStream* stream = getCurrentMPSStream(); bool has_weights = weights.defined(); + // Crashes with + // MPSGraphUtilities.mm:190:0: error: 'mps.scatter' op operand #2 must be tensor of int values, but got 'tensor<5xi1>' + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "bincount is not supported for Bool"); + @autoreleasepool { std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9a1c7c790afaa..4fa24ff378d72 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4617,7 +4617,7 @@ dispatch: CompositeExplicitAutograd: permute MPS: permute_mps - SparseCPU, SparseCUDA: permute_sparse_coo + SparseCPU, SparseCUDA, SparseMPS: permute_sparse_coo tags: core - func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) @@ -12064,8 +12064,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: elu_out - MPS: elu_out_mps + CPU, CUDA, MPS: elu_out - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor structured_delegate: elu.out @@ -12078,8 +12077,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: elu_backward_out - MPS: elu_backward_out_mps + CPU, CUDA, MPS: elu_backward_out - func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor 
self_or_result) -> Tensor structured_delegate: elu_backward.grad_input diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu b/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu index e624295642422..203dafdccfcc6 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorBinaryOps.cu @@ -88,7 +88,7 @@ void _nested_op_dense_esuhm_kernel(Tensor& result, const Tensor& self, const Ten const scalar_t* self_data_ptr = self_buffer.const_data_ptr(); const scalar_t* other_data_ptr = other.const_data_ptr(); scalar_t* result_data_ptr = result_buffer.data_ptr(); - int64_t* result_offsets_ptr = result_offsets.data_ptr(); + int64_t* result_offsets_ptr = result_offsets.template data_ptr(); nested_op_dense_kernelLauncher( self_data_ptr, diff --git a/aten/src/ATen/native/sparse/cuda/SoftMax.cu b/aten/src/ATen/native/sparse/cuda/SoftMax.cu index 7e3b502bf6f41..2ee8de3fd5edf 100644 --- a/aten/src/ATen/native/sparse/cuda/SoftMax.cu +++ b/aten/src/ATen/native/sparse/cuda/SoftMax.cu @@ -307,7 +307,7 @@ std::tuple compute_pool_max( int64_t* offsets_ptr = offsets.data_ptr(); auto sorted_indices = at::empty({nnz}, indices.options()); - thrust_ptr sorted_indices_thrust_ptr(sorted_indices.data_ptr()); + thrust_ptr sorted_indices_thrust_ptr(sorted_indices.template data_ptr()); thrust::sequence( policy, sorted_indices_thrust_ptr, sorted_indices_thrust_ptr + nnz, 0); @@ -326,17 +326,17 @@ std::tuple compute_pool_max( sorted_indices_thrust_ptr + nnz, thrust::make_constant_iterator(int64_t(1)), thrust::make_discard_iterator(), - thrust_ptr(pool_sizes.data_ptr()), + thrust_ptr(pool_sizes.template data_ptr()), [offsets_ptr] __device__(int64_t x, int64_t y) { return offsets_ptr[x] == offsets_ptr[y]; }); auto new_sz = thrust::distance( - thrust_ptr(pool_sizes.data_ptr()), new_end.second); + thrust_ptr(pool_sizes.template data_ptr()), new_end.second); pool_sizes.resize_({new_sz}); auto pool_offsets = pool_sizes.clone(); thrust_ptr pool_offsets_thrust_ptr( - pool_offsets.data_ptr()); + pool_offsets.template data_ptr()); thrust::exclusive_scan( policy, pool_offsets_thrust_ptr, @@ -353,9 +353,9 @@ std::tuple compute_pool_max( auto mx_buffer_ptr = mx_buffer.data_ptr(); - auto pool_sizes_ptr = pool_sizes.data_ptr(); - auto sorted_indices_ptr = sorted_indices.data_ptr(); - auto pool_offsets_ptr = pool_offsets.data_ptr(); + auto pool_sizes_ptr = pool_sizes.template data_ptr(); + auto sorted_indices_ptr = sorted_indices.template data_ptr(); + auto pool_offsets_ptr = pool_offsets.template data_ptr(); thrust::for_each( policy, diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu index f8923dd1a61c1..bb4b095d7f12e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu @@ -500,7 +500,7 @@ Tensor reduce_sparse_csr_dim0_cuda_template(const Tensor& sparse, ReductionOp ro AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "reduce_sparse_csr_dim0_cuda_indices", [&]() { index_t* col_indices_ptr = col_indices.data_ptr(); - index_t* new_col_indices_ptr = new_col_indices.data_ptr(); + index_t* new_col_indices_ptr = new_col_indices.template data_ptr(); reduce_sparse_csr_dim0_cuda_kernel<<>>(new_values_acc_ptr, new_col_indices_ptr, new_nnz, diff --git a/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm b/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm index 3da1cb5da53c8..3b8fd096f495c 100644 --- 
a/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm +++ b/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm @@ -488,32 +488,58 @@ Tensor addmm_sparse_dense_mps( TORCH_CHECK(t_.sparse_dim() == src_.sparse_dim(), "mul(sparse, sparse): must have same sparse_dim, got ", t_.sparse_dim(), " vs ", src_.sparse_dim()); - TORCH_CHECK(t_.sizes().equals(src_.sizes()), - "mul(sparse, sparse): sizes must match exactly (no broadcasting)."); - // Coalesce and early-exit on structurally empty operands + // Coalesce and structural info auto lhs = t_.coalesce(); auto rhs = src_.coalesce(); const int64_t lhs_nnz = lhs._nnz(); const int64_t rhs_nnz = rhs._nnz(); - if (!lhs_nnz || !rhs_nnz) { - r_.resize_as_(lhs); - return r_.zero_(); - } + const int64_t sd = lhs.sparse_dim(); // dtype checks and promotion auto commonDtype = at::result_type(lhs, rhs); TORCH_CHECK(canCast(commonDtype, r_.scalar_type()), "Can't convert result type ", commonDtype, " to output ", r_.scalar_type()); - const int64_t ndim_i = lhs.sparse_dim(); + // sparse sizes must match exactly, dense tails may broadcast + TORCH_CHECK(lhs.sizes().slice(0, sd).equals(rhs.sizes().slice(0, sd)), + "mul(sparse, sparse): sparse sizes must match exactly."); + + // dense tails and broadcasted dense tail + auto lhs_dense = lhs.sizes().slice(sd); + auto rhs_dense = rhs.sizes().slice(sd); + std::vector out_dense_vec = at::infer_size(lhs_dense, rhs_dense); + at::IntArrayRef out_dense(out_dense_vec); + + // full output sizes: [sparse_sizes] + [out_dense] + std::vector out_sizes; + out_sizes.reserve(sd + static_cast(out_dense.size())); + out_sizes.insert(out_sizes.end(), lhs.sizes().begin(), lhs.sizes().begin() + sd); + out_sizes.insert(out_sizes.end(), out_dense.begin(), out_dense.end()); + r_.sparse_resize_(out_sizes, sd, static_cast(out_dense.size())); + + const auto device = r_.device(); + + // if either is structurally empty, produce an empty sparse result with correct shape + if (!lhs_nnz || !rhs_nnz) { + Tensor out_indices = at::empty({sd, 0}, at::device(device).dtype(at::kLong)); + + std::vector out_val_sizes; + out_val_sizes.reserve(1 + out_dense.size()); + out_val_sizes.push_back(0); + out_val_sizes.insert(out_val_sizes.end(), out_dense.begin(), out_dense.end()); + + Tensor out_values = at::empty(out_val_sizes, at::device(device).dtype(r_.scalar_type())); + + alias_into_sparse(r_, out_indices, out_values); + r_._coalesced_(true); + return r_; + } - // ndim_i == 0, at most one structural entry - if (ndim_i == 0) { - r_.resize_as_(lhs); + if (sd == 0) { const bool has = (lhs_nnz && rhs_nnz); - auto out_indices = lhs._indices().narrow(1, 0, has ? 1 : 0); + auto out_indices = at::empty({0, has ? 
1 : 0}, lhs._indices().options()); Tensor lhs_vals = lhs._values().to(commonDtype); Tensor rhs_vals = rhs._values().to(commonDtype); @@ -531,7 +557,6 @@ Tensor addmm_sparse_dense_mps( } // General path, intersect keys, then gather + multiply on GPU - const auto device = r_.device(); auto stream = getCurrentMPSStream(); auto lhs_indices = lhs._indices().contiguous(); @@ -540,8 +565,8 @@ Tensor addmm_sparse_dense_mps( auto rhs_values = rhs._values().to(commonDtype).contiguous(); // Flatten sparse indices to keys - auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i)); - auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i)); + auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, sd)); + auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, sd)); // Intersect sorted keys (search the shorter in the longer) const bool A_is_lhs = (lhs_nnz <= rhs_nnz); @@ -555,35 +580,49 @@ Tensor addmm_sparse_dense_mps( const auto M = static_cast(M_int64); // number of structural matches - r_.resize_as_(lhs); + auto lhs_match = outA_idx.narrow(0, 0, M_int64); + auto rhs_match = outB_idx.narrow(0, 0, M_int64); - auto out_indices = at::empty({ndim_i, static_cast(M)}, at::device(device).dtype(at::kLong)); - auto lhs_match = outA_idx.narrow(0, 0, M); - auto rhs_match = outB_idx.narrow(0, 0, M); - auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec(); int64_t cols64 = 1; - for (auto s : dense_sizes_vec) cols64 *= s; + for (auto s : out_dense) cols64 *= s; const uint32_t cols = static_cast(std::max(cols64, 1)); - auto to2d = [&](Tensor t, int64_t nnz) -> Tensor { - const int64_t t_cols = t.numel() / nnz; - if (t_cols == cols64) { - return t.view({nnz, cols64}); + // to broadcast [nnz, *in_dense] -> [nnz, *out_dense] -> [nnz, cols] + auto broadcast_to_out2d = [&](const Tensor& vals, int64_t nnz, at::IntArrayRef in_dense) -> Tensor { + const int64_t d_in = in_dense.size(); + const int64_t d_out = out_dense.size(); + + std::vector view_shape; + view_shape.reserve(1 + d_out); + view_shape.push_back(nnz); + for (int64_t i = 0; i < d_out - d_in; ++i) { + view_shape.push_back(1); } - return t.view({nnz, 1}).expand({nnz, cols64}).contiguous(); + view_shape.insert(view_shape.end(), in_dense.begin(), in_dense.end()); + + std::vector expand_shape; + expand_shape.reserve(1 + d_out); + expand_shape.push_back(nnz); + expand_shape.insert(expand_shape.end(), out_dense.begin(), out_dense.end()); + + Tensor v = vals.view(view_shape).expand(expand_shape); + return (cols64 > 0) ? 
v.contiguous().view({nnz, cols64}) + : v.contiguous().view({nnz, 0}); }; - // make both sides 2d [nnz, cols] buffers so the kernel can index it - auto lhs_vals2d = to2d(lhs_values, lhs_nnz); - auto rhs_vals2d = to2d(rhs_values, rhs_nnz); + // make both sides broadcasted 2d [nnz, cols] buffers so the kernel can index it + auto lhs_vals2d = broadcast_to_out2d(lhs_values, lhs_nnz, lhs_dense); + auto rhs_vals2d = broadcast_to_out2d(rhs_values, rhs_nnz, rhs_dense); std::vector out_val_sizes; - out_val_sizes.reserve(1 + dense_sizes_vec.size()); + out_val_sizes.reserve(1 + out_dense.size()); out_val_sizes.push_back(static_cast(M)); - out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end()); + out_val_sizes.insert(out_val_sizes.end(), out_dense.begin(), out_dense.end()); auto out_values = at::empty(out_val_sizes, lhs_values.options()); - if (M > 0) { + Tensor out_indices; + if (M > 0 && cols64 > 0) { + out_indices = at::empty({sd, M}, at::device(device).dtype(at::kLong)); dispatch_sync_with_rethrow(stream->queue(), ^() { @autoreleasepool { auto pso = lib.getPipelineStateForFunc( @@ -602,11 +641,19 @@ Tensor addmm_sparse_dense_mps( lhs_match, rhs_match, lhs_indices, out_indices, out_values, - std::array{static_cast(ndim_i), static_cast(lhs_nnz)}, + std::array{static_cast(sd), static_cast(lhs_nnz)}, std::array{M, cols}); [enc dispatchThreads:grid threadsPerThreadgroup:tgs]; } }); + } else if (M > 0) { + // just select the matching coordinates + Tensor src_indices_for_out = A_is_lhs ? lhs_indices : rhs_indices; + Tensor src_match_for_out = A_is_lhs ? lhs_match : rhs_match; + out_indices = src_indices_for_out.index_select(1, src_match_for_out); + } else { + // M == 0 + out_indices = at::empty({sd, 0}, at::device(device).dtype(at::kLong)); } if (r_.scalar_type() != commonDtype) { diff --git a/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal b/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal index dbd1a4548f9ee..96993de59e5f3 100644 --- a/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal +++ b/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal @@ -313,7 +313,7 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2); constant uint2& dims_output [[buffer(8)]], \ uint3 gid [[thread_position_in_grid]]); -INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL); +INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_FUSED_GATHER_MUL); #define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \ diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index a522e7ab76cf4..923b7119a42fc 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -65,6 +65,7 @@ list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cuda_event_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_exchange_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu diff --git a/aten/src/ATen/test/cuda_event_test.cpp b/aten/src/ATen/test/cuda_event_test.cpp new file mode 100644 index 0000000000000..7c58688e1ef9d --- /dev/null +++ b/aten/src/ATen/test/cuda_event_test.cpp @@ -0,0 +1,36 @@ +#include + +#include +#include +#include + +TEST(CUDAEventTest, testCUDAExternalEvent) { + if (!at::cuda::is_available()) { + return; + } + + // Create two external CUDA events + unsigned int flags = cudaEventDefault | cudaEventExternal; + 
auto event1 = at::cuda::CUDAEvent(flags); + auto event2 = at::cuda::CUDAEvent(flags); + // Ensure external CUDAEvent remain valid and functional after being moved. + auto start_event = std::move(event1); + auto end_event = std::move(event2); + + auto stream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(stream); + + auto graph = at::cuda::CUDAGraph(); + graph.capture_begin(); + start_event.record(); + at::cuda::sleep(100000); + end_event.record(); + graph.capture_end(); + + // External events should correctly record timestamps even when used inside + // CUDA graphs, and elapsed_time() between them should be positive. + stream.synchronize(); + graph.replay(); + at::cuda::device_synchronize(); + EXPECT_TRUE(start_event.elapsed_time(end_event) > 0); +} diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index f7206cc340973..7ee77f53d5377 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -843,22 +843,22 @@ class AssertVectorized std::stringstream stream; stream.precision(std::numeric_limits::max_digits10); stream << "Failure Details:\n"; - stream << additionalInfo << "\n"; + stream << additionalInfo << '\n'; if (hasSeed) { - stream << "Test Seed to reproduce: " << testSeed << "\n"; + stream << "Test Seed to reproduce: " << testSeed << '\n'; } if (argSize > 0) { stream << "Arguments:\n"; - stream << "#\t " << arg0 << "\n"; + stream << "#\t " << arg0 << '\n'; if (argSize == 2) { - stream << "#\t " << arg1 << "\n"; + stream << "#\t " << arg1 << '\n'; } if (argSize == 3) { - stream << "#\t " << arg2 << "\n"; + stream << "#\t " << arg2 << '\n'; } } stream << "Expected:\n#\t" << exp << "\nActual:\n#\t" << act; @@ -890,7 +890,7 @@ class AssertVectorized else if (checkWithTolerance) { for (const auto i : c10::irange(sizeX)) { - EXPECT_EQ(nearlyEqual(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << "\n" << getDetail(i / unitStorageCount); + EXPECT_EQ(nearlyEqual(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << '\n' << getDetail(i / unitStorageCount); if (::testing::Test::HasFailure()) return true; } @@ -1116,11 +1116,11 @@ void test_binary_fp8( if (is_bit_wise) { EXPECT_EQ(static_cast(ref_res_scalar), static_cast(res_scalar)) << "Test failed for input0: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << '\n'; } else { EXPECT_EQ(ref_res_scalar, res_scalar) << "Test failed for input0: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e4m3fn_to_fp32_value(f8_1.x) << '\n'; } } else { at::vec::cvtfp8e5m2_fp32(_mm512_castsi512_si128(res), res_fp32_512); @@ -1128,11 +1128,11 @@ void test_binary_fp8( if (is_bit_wise) { EXPECT_EQ(static_cast(ref_res_scalar), static_cast(res_scalar)) << "Test failed for input0: " << c10::detail::fp8e5m2_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << '\n'; } else { EXPECT_EQ(ref_res_scalar, res_scalar) << "Test failed for input0: " << c10::detail::fp8e5m2_to_fp32_value(f8_0.x) - << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << "\n"; + << " input1: " << c10::detail::fp8e5m2_to_fp32_value(f8_1.x) << '\n'; } } } diff --git 
a/benchmarks/dynamo/microbenchmarks/.gitignore b/benchmarks/dynamo/microbenchmarks/.gitignore new file mode 100644 index 0000000000000..c627000badbf8 --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/.gitignore @@ -0,0 +1 @@ +*.prof diff --git a/benchmarks/dynamo/microbenchmarks/optree_tree_map.py b/benchmarks/dynamo/microbenchmarks/optree_tree_map.py new file mode 100644 index 0000000000000..6421bd900e663 --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/optree_tree_map.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 + +import argparse +import time +from pathlib import Path + +import optree + +import torch +import torch._dynamo +from torch._dynamo.debug_utils import profile_to_file + + +PROFILE_PATH = Path(__file__).with_name("optree_tree_map.prof") + + +def make_tensor_tree(depth: int, branching_factor: int, tensor_size: int, device: str): + """Create a moderately deep pytree populated with tensors.""" + + def _make_level(level: int): + if level == 0: + return torch.randn(tensor_size, tensor_size, device=device) + + children = [_make_level(level - 1) for _ in range(branching_factor)] + return { + "tensor": torch.randn(tensor_size, tensor_size, device=device), + "list": list(children), + "tuple": tuple(children), + } + + return _make_level(depth) + + +def add_leaf(lhs: torch.Tensor, *rest: torch.Tensor) -> torch.Tensor: + out = lhs + for other in rest: + out = out + other + return out + + +def optree_tree_map_loop(lhs, rhs, loop_iters): + tree = lhs + for _ in range(loop_iters): + tree = optree.tree_map( + add_leaf, + tree, + rhs, + namespace="torch", + ) + return tree + + +def _capture_compile_profile(args, lhs, rhs) -> None: + profile_path = Path(args.profile_out) + profile_path.parent.mkdir(parents=True, exist_ok=True) + + @profile_to_file(str(profile_path)) + def _run_compile() -> None: + torch._dynamo.reset() + compiled = torch.compile( + optree_tree_map_loop, + backend="eager", + fullgraph=True, + ) + compiled(lhs, rhs, args.loop_iters) + + print(f"Collecting compile-only cProfile at {profile_path}") + _run_compile() + + +def _parse_args(): + parser = argparse.ArgumentParser() + default_device = "cuda" if torch.cuda.is_available() else "cpu" + parser.add_argument("--device", default=default_device, help="Device to run on") + parser.add_argument( + "--loop-iters", + type=int, + default=50, + help="Number of tree_map calls per compiled invocation", + ) + parser.add_argument( + "--tree-depth", type=int, default=2, help="Depth of the constructed pytree" + ) + parser.add_argument( + "--branching-factor", + type=int, + default=2, + help="Branching factor for list/tuple nodes", + ) + parser.add_argument( + "--tensor-size", + type=int, + default=1, + help="Edge length for square tensor leaves", + ) + parser.add_argument( + "--profile-out", + default=str(PROFILE_PATH), + help="Destination .prof file for the compile-time cProfile", + ) + return parser.parse_args() + + +def main() -> None: + args = _parse_args() + + lhs = make_tensor_tree( + args.tree_depth, args.branching_factor, args.tensor_size, args.device + ) + rhs = make_tensor_tree( + args.tree_depth, args.branching_factor, args.tensor_size, args.device + ) + + t0 = time.perf_counter() + _capture_compile_profile(args, lhs, rhs) + t1 = time.perf_counter() + print(f"Took {t1 - t0:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index 4c131843b372b..7a8f0988a1fbf 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py 
+++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -266,7 +266,13 @@ def _print_perf_result(self, results, test_case): print( f"{mode} Execution Time (us) : {results['reported_run_time_us'][0]:.3f}" ) - print(f"Peak Memory (KB) : {results['peak_memory']}\n") + print(f"Peak Memory (KB) : {results['peak_memory']}") + # Calculate and print memory bandwidth if operator provides memory traffic + if results.get("memory_bandwidth_gb_s") is not None: + print( + f"Memory Bandwidth (GB/s) : {results['memory_bandwidth_gb_s']:.2f}" + ) + print() def _perf_result_to_dict(self, results, test_case): """This function is the parallel of _print_perf_result, which instead of @@ -711,6 +717,17 @@ def run(self): result_dict = dict() result_dict["reported_run_time_us"] = [r[0] for r in results] result_dict["peak_memory"] = results[0][1] + + # Calculate memory bandwidth if operator provides memory traffic + memory_traffic_bytes = test_case.op_bench.get_memory_traffic_bytes() + if memory_traffic_bytes is not None: + execution_time_s = result_dict["reported_run_time_us"][0] / 1e6 + result_dict["memory_bandwidth_gb_s"] = ( + memory_traffic_bytes / execution_time_s / 1e9 + ) + else: + result_dict["memory_bandwidth_gb_s"] = None + self._print_perf_result(results=result_dict, test_case=test_case) # output results to csv diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py index fa022417da451..ad9180a53da04 100644 --- a/benchmarks/operator_benchmark/benchmark_pytorch.py +++ b/benchmarks/operator_benchmark/benchmark_pytorch.py @@ -118,6 +118,44 @@ def test_name(self, **kargs): name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "") return name + def get_memory_traffic_bytes(self): + """Return the number of bytes read/written by this operator. + + Override this method in subclasses for operations with non-standard memory patterns + (e.g., matmul which is compute-bound rather than memory-bound). + + The framework will use this value along with execution time to compute + and report memory bandwidth in GB/s. + + Default implementation assumes a pointwise-like operation: + - Reads: all input tensors + - Writes: output tensor (estimated as size of largest input) + + This default works correctly for: + - Element-wise operations (add, mul, relu, etc.) + - Activations (gelu, sigmoid, etc.) + - Optimizers (SGD, Adam, etc.) + - Reductions (sum, mean, etc. - may underestimate writes) + + Returns: + int or None: Total bytes transferred (reads + writes), or None if not applicable + """ + if not hasattr(self, "inputs") or not self.inputs: + return None + + input_tensors = [v for v in self.inputs.values() if isinstance(v, torch.Tensor)] + if not input_tensors: + return None + + # Calculate total bytes read from all inputs + bytes_read = sum(t.numel() * t.element_size() for t in input_tensors) + + # Estimate output size as the largest input (common for pointwise ops) + largest_input = max(input_tensors, key=lambda t: t.numel()) + bytes_written = largest_input.numel() * largest_input.element_size() + + return bytes_read + bytes_written + class PyTorchOperatorTestCase: """This class includes all the information needed to benchmark an operator. 
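# Illustrative sketch of how the bandwidth reporting above fits together: the
# default get_memory_traffic_bytes() models a pointwise operator (read every
# input tensor once, write an output sized like the largest input), and
# benchmark_core.py divides that byte count by the measured run time to report
# GB/s. The tensor shapes and the 100 us timing below are assumed values for
# this example only, not numbers taken from the patch.
import torch

a = torch.randn(1_000_000, dtype=torch.float32)
b = torch.randn(1_000_000, dtype=torch.float32)

# Pointwise default: read both inputs, write one output sized like the larger input.
bytes_read = a.numel() * a.element_size() + b.numel() * b.element_size()
bytes_written = max(a.numel(), b.numel()) * a.element_size()
memory_traffic_bytes = bytes_read + bytes_written  # 12 MB total for two 1M-element float32 inputs

execution_time_s = 100e-6  # assume the harness measured 100 us per iteration
memory_bandwidth_gb_s = memory_traffic_bytes / execution_time_s / 1e9
print(f"Memory Bandwidth (GB/s) : {memory_bandwidth_gb_s:.2f}")  # -> 120.00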
diff --git a/benchmarks/operator_benchmark/pt/addmm_test.py b/benchmarks/operator_benchmark/pt/addmm_test.py index 3e94a9cd7f3dc..5d9cd14ec696e 100644 --- a/benchmarks/operator_benchmark/pt/addmm_test.py +++ b/benchmarks/operator_benchmark/pt/addmm_test.py @@ -52,6 +52,26 @@ def init(self, M, N, K, device, dtype): def forward(self, input_one, mat1, mat2): return torch.addmm(input_one, mat1, mat2) + def get_memory_traffic_bytes(self): + """Override for addmm: input + (mat1 @ mat2) -> (M, K) + addmm computes: input_one (M, K) + mat1 (M, N) @ mat2 (N, K) + Memory traffic: read(M*K + M*N + N*K) + write(M*K) + """ + input_one = self.inputs["input_one"] + mat1 = self.inputs["mat1"] + mat2 = self.inputs["mat2"] + + M, K = input_one.shape + M_check, N = mat1.shape + N_check, K_check = mat2.shape + assert M == M_check and K == K_check and N == N_check, ( + "Matrix dimensions must match" + ) + + bytes_per_element = input_one.element_size() + total_elements = M * K + M * N + N * K + M * K + return total_elements * bytes_per_element + op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark) op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark) @@ -84,6 +104,26 @@ def init(self, B, M, N, K, device, dtype): def forward(self, input_one, batch1, batch2): return torch.addbmm(input_one, batch1, batch2) + def get_memory_traffic_bytes(self): + """Override for addbmm: input + sum(batch1[i] @ batch2[i]) -> (M, N) + addbmm computes: input_one (M, N) + sum over batch of batch1 (B, M, K) @ batch2 (B, K, N) + Memory traffic: read(M*N + B*M*K + B*K*N) + write(M*N) + """ + input_one = self.inputs["input_one"] + batch1 = self.inputs["batch1"] + batch2 = self.inputs["batch2"] + + M, N = input_one.shape + B, M_check, K = batch1.shape + B_check, K_check, N_check = batch2.shape + assert M == M_check and N == N_check and B == B_check and K == K_check, ( + "Dimensions must match" + ) + + bytes_per_element = input_one.element_size() + total_elements = M * N + B * M * K + B * K * N + M * N + return total_elements * bytes_per_element + addbmm_long_configs = op_bench.cross_product_configs( B=[8, 32], diff --git a/benchmarks/operator_benchmark/pt/bmm_test.py b/benchmarks/operator_benchmark/pt/bmm_test.py index f867f6ac09f8d..234bff20fb499 100644 --- a/benchmarks/operator_benchmark/pt/bmm_test.py +++ b/benchmarks/operator_benchmark/pt/bmm_test.py @@ -52,6 +52,20 @@ def init(self, B, M, N, K, device, dtype, op_func): def forward(self, batch1, batch2): return self.op_func(batch1, batch2) + def get_memory_traffic_bytes(self): + """Override for bmm: (B, M, N) @ (B, N, K) -> (B, M, K) + Memory traffic: read(B*M*N + B*N*K) + write(B*M*K) + """ + batch1 = self.inputs["batch1"] + batch2 = self.inputs["batch2"] + B, M, N = batch1.shape + B_check, N_check, K = batch2.shape + assert B == B_check and N == N_check, "Batch dimensions must match for bmm" + + bytes_per_element = batch1.element_size() + total_elements = B * (M * N + N * K + M * K) + return total_elements * bytes_per_element + op_bench.generate_pt_tests_from_op_list( batched_binary_ops, @@ -90,6 +104,25 @@ def init(self, B, M, N, K, device, dtype, op_func): def forward(self, input_, batch1, batch2): return self.op_func(input_, batch1, batch2) + def get_memory_traffic_bytes(self): + """Override for baddbmm: input + (batch1 @ batch2) -> (B, M, K) + Memory traffic: read(B*M*K + B*M*N + B*N*K) + write(B*M*K) + """ + input_ = self.inputs["input_"] + batch1 = self.inputs["batch1"] + batch2 = self.inputs["batch2"] + B, M, K = input_.shape + 
B_check1, M_check, N = batch1.shape + B_check2, N_check, K_check = batch2.shape + assert B == B_check1 == B_check2, "Batch dimensions must match" + assert M == M_check and K == K_check and N == N_check, ( + "Matrix dimensions must match" + ) + + bytes_per_element = input_.element_size() + total_elements = B * (M * K + M * N + N * K + M * K) + return total_elements * bytes_per_element + op_bench.generate_pt_tests_from_op_list( batched_ternary_ops, diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index eb94921989ccf..c7fa3f5c2381d 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -22,6 +22,24 @@ def init(self, IC, OC, kernel, stride, N, L, device): def forward(self, input): return self.conv1d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv1d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv1d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × L + input_elements = input_tensor.numel() + # Weight: OC × IC × kernel + weight_elements = self.conv1d.weight.numel() + # Output: N × OC × L_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class ConvTranspose1dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, L, device): @@ -34,6 +52,24 @@ def init(self, IC, OC, kernel, stride, N, L, device): def forward(self, input): return self.convtranspose1d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for ConvTranspose1d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.convtranspose1d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × L + input_elements = input_tensor.numel() + # Weight: IC × OC × kernel + weight_elements = self.convtranspose1d.weight.numel() + # Output: N × OC × L_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + op_bench.generate_pt_test( configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark @@ -43,15 +79,12 @@ def forward(self, input): Conv1dBenchmark, ) - -if not torch.backends.mkldnn.is_acl_available(): - # convtranpose1d crashes with ACL, see https://github.com/pytorch/pytorch/issues/165654 - op_bench.generate_pt_test( - configs.convtranspose_1d_configs_short - + configs.conv_1d_configs_short - + configs.conv_1d_configs_long, - ConvTranspose1dBenchmark, - ) +op_bench.generate_pt_test( + configs.convtranspose_1d_configs_short + + configs.conv_1d_configs_short + + configs.conv_1d_configs_long, + ConvTranspose1dBenchmark, +) """ @@ -70,6 +103,24 @@ def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): def forward(self, input): return self.conv2d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv2d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv2d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × H × W + input_elements = input_tensor.numel() + # Weight: OC × (IC/G) × kernel × 
kernel + weight_elements = self.conv2d.weight.numel() + # Output: N × OC × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class ConvTranspose2dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): @@ -82,6 +133,24 @@ def init(self, IC, OC, kernel, stride, N, H, W, G, pad, device): def forward(self, input): return self.convtranspose2d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for ConvTranspose2d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.convtranspose2d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × H × W + input_elements = input_tensor.numel() + # Weight: IC × (OC/G) × kernel × kernel + weight_elements = self.convtranspose2d.weight.numel() + # Output: N × OC × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class Conv2dPointwiseBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, stride, N, H, W, G, pad, device): @@ -95,6 +164,24 @@ def init(self, IC, OC, stride, N, H, W, G, pad, device): def forward(self, input): return self.conv2d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv2dPointwise: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv2d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × H × W + input_elements = input_tensor.numel() + # Weight: OC × (IC/G) × 1 × 1 + weight_elements = self.conv2d.weight.numel() + # Output: N × OC × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + op_bench.generate_pt_test( configs.conv_2d_configs_short + configs.conv_2d_configs_long, Conv2dBenchmark @@ -137,6 +224,24 @@ def init(self, IC, OC, kernel, stride, N, D, H, W, device): def forward(self, input): return self.conv3d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for Conv3d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.conv3d(input_tensor) + + bytes_per_element = input_tensor.element_size() + # Input: N × IC × D × H × W + input_elements = input_tensor.numel() + # Weight: OC × IC × kernel × kernel × kernel + weight_elements = self.conv3d.weight.numel() + # Output: N × OC × D_out × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + class ConvTranspose3dBenchmark(op_bench.TorchBenchmarkBase): def init(self, IC, OC, kernel, stride, N, D, H, W, device): @@ -149,6 +254,24 @@ def init(self, IC, OC, kernel, stride, N, D, H, W, device): def forward(self, input): return self.convtranspose3d(input) + def get_memory_traffic_bytes(self): + """Calculate memory traffic for ConvTranspose3d: read(input + weight) + write(output)""" + input_tensor = self.inputs["input"] + # Run forward to get output shape + with torch.no_grad(): + output = self.convtranspose3d(input_tensor) + + bytes_per_element = 
input_tensor.element_size() + # Input: N × IC × D × H × W + input_elements = input_tensor.numel() + # Weight: IC × OC × kernel × kernel × kernel + weight_elements = self.convtranspose3d.weight.numel() + # Output: N × OC × D_out × H_out × W_out + output_elements = output.numel() + + total_elements = input_elements + weight_elements + output_elements + return total_elements * bytes_per_element + op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark) op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark) diff --git a/benchmarks/operator_benchmark/pt/matmul_test.py b/benchmarks/operator_benchmark/pt/matmul_test.py index d0c58aa16e8f3..4bde44d60f381 100644 --- a/benchmarks/operator_benchmark/pt/matmul_test.py +++ b/benchmarks/operator_benchmark/pt/matmul_test.py @@ -59,6 +59,22 @@ def init(self, M, N, K, trans_a, trans_b, device, dtype=torch.float): def forward(self, input_one, input_two): return torch.matmul(input_one, input_two) + def get_memory_traffic_bytes(self): + """Override for matmul: (M, N) @ (N, K) -> (M, K) + Memory traffic: read(M*N + N*K) + write(M*K) + """ + input_one = self.inputs["input_one"] + input_two = self.inputs["input_two"] + + # input_one and input_two are properly shaped for matmul regardless of transpose + M, N = input_one.shape + N_check, K = input_two.shape + assert N == N_check, "Matrix dimensions must match for matmul" + + bytes_per_element = input_one.element_size() + total_elements = M * N + N * K + M * K + return total_elements * bytes_per_element + op_bench.generate_pt_test(mm_long_configs + mm_short_configs, MatMulBenchmark) op_bench.generate_pt_gradient_test(mm_long_configs, MatMulBenchmark) diff --git a/benchmarks/operator_benchmark/pt/mm_test.py b/benchmarks/operator_benchmark/pt/mm_test.py index f9e0743ba7125..07e0b596960fe 100644 --- a/benchmarks/operator_benchmark/pt/mm_test.py +++ b/benchmarks/operator_benchmark/pt/mm_test.py @@ -47,6 +47,20 @@ def init(self, M, N, K, device, dtype, op_func): def forward(self, input_one, input_two): return self.op_func(input_one, input_two) + def get_memory_traffic_bytes(self): + """Override for matmul: (M, N) @ (N, K) -> (M, K) + Memory traffic: read(M*N + N*K) + write(M*K) + """ + input_one = self.inputs["input_one"] + input_two = self.inputs["input_two"] + M, N = input_one.shape + N_check, K = input_two.shape + assert N == N_check, "Matrix dimensions must match for matmul" + + bytes_per_element = input_one.element_size() + total_elements = M * N + N * K + M * K + return total_elements * bytes_per_element + op_bench.generate_pt_tests_from_op_list( ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark diff --git a/benchmarks/operator_benchmark/pt/optimizer_test.py b/benchmarks/operator_benchmark/pt/optimizer_test.py new file mode 100644 index 0000000000000..53bab9773def4 --- /dev/null +++ b/benchmarks/operator_benchmark/pt/optimizer_test.py @@ -0,0 +1,65 @@ +import operator_benchmark as op_bench + +import torch +import torch.optim as optim + + +"""Microbenchmarks for optimizer operators.""" + + +optimizer_list = op_bench.op_list( + attr_names=["op_name", "op_func"], + attrs=[ + ["adamw", optim.AdamW], + ["adam", optim.Adam], + ["sgd", optim.SGD], + ["rmsprop", optim.RMSprop], + ["adagrad", optim.Adagrad], + ], +) + +optimizer_configs_long = op_bench.cross_product_configs( + shape=[(100000,), (1000000,), (10000000,)], + device=["cuda"], + tags=["long"], +) + + +class OptimizerBenchmark(op_bench.TorchBenchmarkBase): + def init(self, op_func, device, shape): + 
self.op_func = op_func + self.param = torch.randn( + shape, device=device, requires_grad=True, dtype=torch.float32 + ) + self.param.grad = torch.randn(shape, device=device) + + kwargs = {"momentum": 0.9} if op_func == optim.SGD else {} + self.optimizer = op_func([self.param], lr=0.001, **kwargs) + + self.inputs = {"dummy": self.param} + + def forward(self, dummy): + self.optimizer.step() + return self.param + + def get_memory_traffic_bytes(self): + # Memory traffic calculation for bandwidth + total_elements = self.param.numel() + bytes_per_element = self.param.element_size() + # SGD w/ momentum: read(param, grad, momentum) + write(param, momentum) = 5x + # Adam/AdamW: read(param, grad, exp_avg, exp_avg_sq) + write(param, exp_avg, exp_avg_sq) = 7x + # Adagrad/RMSprop: read(param, grad, state) + write(param, state) = 5x + if self.op_func in (optim.Adam, optim.AdamW): + memory_multiplier = 7 + else: + memory_multiplier = 5 + return total_elements * bytes_per_element * memory_multiplier + + +op_bench.generate_pt_tests_from_op_list( + optimizer_list, optimizer_configs_long, OptimizerBenchmark +) + + +if __name__ == "__main__": + op_bench.benchmark_runner.main() diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index e9af132df28a9..b120d987514e9 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -147,7 +147,7 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> @dataclass(frozen=True) class ExperimentConfig: - shape: tuple[int] # [B, Hq, M, Hkv, N, D] + shape: tuple[int, ...] # [B, Hq, M, Hkv, N, D] attn_type: str dtype: torch.dtype calculate_bwd_time: bool @@ -257,7 +257,7 @@ def generate_inputs( def generate_jagged_inputs( - shape: tuple[int], + shape: tuple[int, ...], query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -720,7 +720,7 @@ def print_results(results: list[Experiment], save_path: Optional[str] = None): dropout_p = 0.0 -def generate_score_mod(attn_type: str, shape: tuple[int]) -> Callable | None: +def generate_score_mod(attn_type: str, shape: tuple[int, ...]) -> Callable | None: B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 from attn_gym.mods import generate_alibi_bias, generate_tanh_softcap @@ -762,7 +762,7 @@ def score_mod_w_offset(score, b, h, m, n): prefix_length = 512 -def generate_block_mask(attn_type: str, shape: tuple[int]): +def generate_block_mask(attn_type: str, shape: tuple[int, ...]): B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 @@ -837,7 +837,7 @@ def decoding_w_cached_seq_len(b, h, m, n): return block_mask, mask_mod_kwargs -def get_kernel_options(attn_type: str, shape: tuple[int]): +def get_kernel_options(attn_type: str, shape: tuple[int, ...]): B, Hq, M, Hkv, N, D = shape is_decoding = M == 1 kernel_opt_training_dict = { @@ -924,7 +924,7 @@ def get_backend_context(backend: str): def generate_FA_callable( - attn_type: str, shape: tuple[int], dtype: torch.dtype, backend: str, **kwargs + attn_type: str, shape: tuple[int, ...], dtype: torch.dtype, backend: str, **kwargs ) -> Callable | None: if dtype not in [torch.float16, torch.bfloat16]: return None @@ -983,7 +983,7 @@ def offsets_to_lengths( def generate_FD_callable( - attn_type: str, shape: tuple[int], dtype: torch.dtype + attn_type: str, shape: tuple[int, ...], dtype: torch.dtype ) -> Callable | None: if dtype not in [torch.float16, torch.bfloat16]: return None @@ -1030,7 +1030,10 @@ def flash_attn_with_kvcache_renamed(q, k, v, **kwargs): def generate_attn_mask_linear_score_mod( - shape: 
tuple[int], block_mask: BlockMask, score_mod: Callable, dtype: torch.dtype + shape: tuple[int, ...], + block_mask: BlockMask, + score_mod: Callable, + dtype: torch.dtype, ): B, Hq, M, N = shape if block_mask is None and score_mod is None: @@ -1055,7 +1058,7 @@ def generate_attn_mask_linear_score_mod( def generate_eager_sdpa( attn_type: str, - shape: tuple[int], + shape: tuple[int, ...], dtype: torch.dtype, block_mask: BlockMask, score_mod: Callable | None = None, diff --git a/build_variables.bzl b/build_variables.bzl index 258e739300c1e..ba856c5a97ba4 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -521,6 +521,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/comm.cpp", "torch/csrc/distributed/c10d/control_collectives/StoreCollectives.cpp", "torch/csrc/distributed/c10d/control_plane/Handlers.cpp", + "torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp", "torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp", "torch/csrc/distributed/c10d/cuda/StreamBlock.cpp", "torch/csrc/distributed/c10d/debug.cpp", @@ -731,6 +732,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/cuda/comm.cpp", "torch/csrc/cuda/memory_snapshot.cpp", "torch/csrc/cuda/CUDAPluggableAllocator.cpp", + "torch/csrc/cuda/shim_common.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp", "torch/csrc/inductor/aoti_torch/shim_cuda.cpp", "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", diff --git a/c10/core/Layout.h b/c10/core/Layout.h index a85f2ee6911ce..7cd25b04c5bb6 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -3,30 +3,9 @@ #include #include -#include -#include +#include namespace c10 { -enum class Layout : int8_t { - Strided, - Sparse, - SparseCsr, - Mkldnn, - SparseCsc, - SparseBsr, - SparseBsc, - Jagged, - NumOptions -}; - -constexpr auto kStrided = Layout::Strided; -constexpr auto kSparse = Layout::Sparse; -constexpr auto kSparseCsr = Layout::SparseCsr; -constexpr auto kMkldnn = Layout::Mkldnn; -constexpr auto kSparseCsc = Layout::SparseCsc; -constexpr auto kSparseBsr = Layout::SparseBsr; -constexpr auto kSparseBsc = Layout::SparseBsc; -constexpr auto kJagged = Layout::Jagged; inline Layout layout_from_backend(Backend backend) { C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index 8c8531d014713..7271a281e5ddb 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -3,46 +3,18 @@ #include #include +#include + #include -#include #include -// Memory format is not the property of a Tensor. It is the way to tell an -// operator how the result should be organized in memory and nothing more. That -// means memory format should never be used as return value for any tensor state -// interrogation functions (internally and externally). -// -// Possible options are: -// Preserve: -// If any of the input tensors is in channels_last format, operator output -// should be in channels_last format -// -// Contiguous: -// Regardless of input tensors format, the output should be contiguous -// Tensor. -// -// ChannelsLast: -// Regardless of input tensors format, the output should be in channels_last -// format. 
- namespace c10 { -enum class MemoryFormat : int8_t { - Contiguous, - Preserve, - ChannelsLast, - ChannelsLast3d, - NumOptions -}; // If you are seeing this, it means that this call site was not checked if // the memory format could be preserved, and it was switched to old default // behaviour of contiguous #define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format() -inline MemoryFormat get_contiguous_memory_format() { - return MemoryFormat::Contiguous; -} - inline std::ostream& operator<<( std::ostream& stream, at::MemoryFormat memory_format) { diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index 00fc03bbd0fcf..56bc75e01adb1 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -48,7 +48,7 @@ void warnDeprecatedDataPtr() { TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); } -void StorageImpl::incref_pyobject() const { +void StorageImpl::incref_pyobject() const noexcept { // Because intrusive_ptr incref uses relaxed memory order, we need to // do an acquire fence to ensure that the kHasPyObject bit was // observed before the load of the PyObject* below. @@ -59,12 +59,12 @@ void StorageImpl::incref_pyobject() const { (*pyobj_slot_.pyobj_interpreter())->incref(obj); } -void StorageImpl::decref_pyobject() const { +void StorageImpl::decref_pyobject() const noexcept { PyObject* obj = pyobj_slot_.load_pyobj(); (*pyobj_slot_.pyobj_interpreter())->decref(obj); } -bool StorageImpl::try_incref_pyobject() const { +bool StorageImpl::try_incref_pyobject() const noexcept { c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); if (C10_UNLIKELY(!interp)) { return false; diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index c7dbd5c1f005b..8df32f552c754 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -105,11 +105,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { data_ptr_.clear(); } - void incref_pyobject() const override final; + void incref_pyobject() const noexcept override final; - void decref_pyobject() const override final; + void decref_pyobject() const noexcept override final; - bool try_incref_pyobject() const override final; + bool try_incref_pyobject() const noexcept override final; size_t nbytes() const { // OK to do this instead of maybe_as_int as nbytes is guaranteed positive diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 94a7375cc32fb..c890d6d084eb3 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -988,7 +988,7 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { } } -void TensorImpl::incref_pyobject() const { +void TensorImpl::incref_pyobject() const noexcept { // Because intrusive_ptr incref uses relaxed memory order, we need to // do an acquire fence to ensure that the kHasPyObject bit was // observed before the load of the PyObject* below. 
@@ -999,12 +999,12 @@ void TensorImpl::incref_pyobject() const { (*pyobj_slot_.pyobj_interpreter())->incref(obj); } -void TensorImpl::decref_pyobject() const { +void TensorImpl::decref_pyobject() const noexcept { PyObject* obj = pyobj_slot_.load_pyobj(); (*pyobj_slot_.pyobj_interpreter())->decref(obj); } -bool TensorImpl::try_incref_pyobject() const { +bool TensorImpl::try_incref_pyobject() const noexcept { c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); if (C10_UNLIKELY(!interp)) { return false; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 71a0195dde773..42b6bb1e80d2e 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2178,11 +2178,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return &pyobj_slot_; } - void incref_pyobject() const override final; + void incref_pyobject() const noexcept override final; - void decref_pyobject() const override final; + void decref_pyobject() const noexcept override final; - bool try_incref_pyobject() const override final; + bool try_incref_pyobject() const noexcept override final; private: // See NOTE [std::optional operator usage in CUDA] diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 2604f677858d1..fd80c45fcc79e 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -43,6 +43,7 @@ set(C10_CUDA_HEADERS CUDACachingAllocator.h CUDADeviceAssertionHost.h CUDAException.h + CUDAEvent.h CUDAFunctions.h CUDAGuard.h CUDAMacros.h diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 9e7823a394302..1d70edde5a4ca 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1765,7 +1765,12 @@ class DeviceCachingAllocator { auto node_get_dependencies = [](cudaGraphNode_t n, cudaGraphNode_t* deps, size_t* count) -> void { #if (defined(CUDA_VERSION) && CUDA_VERSION >= 13000) - C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count)); + if (deps == nullptr) { + C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, nullptr, count)); + } else { + cudaGraphEdgeData edgeData; + C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, &edgeData, count)); + } #else C10_CUDA_CHECK(cudaGraphNodeGetDependencies(n, deps, count)); #endif diff --git a/c10/cuda/CUDAEvent.h b/c10/cuda/CUDAEvent.h new file mode 100644 index 0000000000000..6e5205044879f --- /dev/null +++ b/c10/cuda/CUDAEvent.h @@ -0,0 +1,278 @@ +#pragma once + +#include +#include +#include +#include + +/* + * `cudaEventExternal` is a torch-specific flag that is used to + * indicate that the CUDAEvent will be used only for synchronization + * with work outside of the cuda graph, rather than creation of + * cross-stream dependencies within a cuda graph. Resources: + * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-c-programming-guide/index.html#cross-stream-dependencies-and-events + * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3457b81d1d32c6a00f6132fbc2693d47 + * https://docs.nvidia.com/cuda/archive/12.9.0/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e + */ +#define cudaEventExternal 0x08 + +namespace c10::cuda { + +/* + * CUDAEvents are movable not copyable wrappers around CUDA's events. + * + * CUDAEvents are constructed lazily when first recorded unless it is + * reconstructed from a cudaIpcEventHandle_t. The event has a device, and this + * device is acquired from the first recording stream. 
However, if reconstructed + * from a handle, the device should be explicitly specified; or if ipc_handle() + * is called before the event is ever recorded, it will use the current device. + * Later streams that record the event must match this device. + */ +struct CUDAEvent { + // Constructors + // Default value for `flags` is specified below - it's cudaEventDisableTiming + CUDAEvent() noexcept = default; + CUDAEvent(unsigned int flags) noexcept : flags_{flags} {} + + CUDAEvent(DeviceIndex device_index, const cudaIpcEventHandle_t* handle) + : device_index_(device_index) { + CUDAGuard guard(device_index_); + + C10_CUDA_CHECK(cudaIpcOpenEventHandle(&event_, *handle)); + is_created_ = true; + } + + // Note: event destruction done on creating device to avoid creating a + // CUDA context on other devices. + ~CUDAEvent() { + if (is_created_) { + CUDAGuard guard(device_index_); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_deletion( + c10::kCUDA, reinterpret_cast(event_)); + } + C10_CUDA_CHECK_WARN(cudaEventDestroy(event_)); + } + } + + CUDAEvent(const CUDAEvent&) = delete; + CUDAEvent& operator=(const CUDAEvent&) = delete; + + CUDAEvent(CUDAEvent&& other) noexcept { + moveHelper(std::move(other)); + } + CUDAEvent& operator=(CUDAEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator cudaEvent_t() const { + return event(); + } + + // Less than operator (to allow use in sets) + friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) { + return left.event_ < right.event_; + } + + std::optional device() const { + if (is_created_) { + return c10::Device(c10::kCUDA, device_index_); + } else { + return {}; + } + } + + bool isCreated() const { + return is_created_; + } + DeviceIndex device_index() const { + return device_index_; + } + cudaEvent_t event() const { + return event_; + } + + // Note: cudaEventQuery can be safely called from any device + bool query() const { + if (!is_created_) { + return true; + } + + cudaError_t err = cudaEventQuery(event_); + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; + } + + void record() { + record(getCurrentCUDAStream()); + } + + void recordOnce(const CUDAStream& stream) { + if (!was_recorded_) + record(stream); + } + + // Note: cudaEventRecord must be called on the same device as the event. + void record(const CUDAStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK( + device_index_ == stream.device_index(), + "Event device ", + device_index_, + " does not match recording stream's device ", + stream.device_index(), + "."); + CUDAGuard guard(device_index_); + +#ifndef USE_ROCM + // it is an error to use cudaEventRecordExternal when not doing stream + // capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != + c10::cuda::CaptureStatus::None && + external_) + ? 
cudaEventRecordExternal + : cudaEventRecordDefault; + C10_CUDA_CHECK(cudaEventRecordWithFlags(event_, stream, flags)); +#else + C10_CUDA_CHECK(cudaEventRecord(event_, stream)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_record( + c10::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream())); + } + was_recorded_ = true; + } + + // Note: cudaStreamWaitEvent must be called on the same device as the stream. + // The event has no actual GPU resources associated with it. + void block(const CUDAStream& stream) { + if (is_created_) { + CUDAGuard guard(stream.device_index()); +#ifndef USE_ROCM + // it is an error to use cudaEventWaitExternal when not doing stream + // capture + unsigned int flags = (c10::cuda::currentStreamCaptureStatusMayInitCtx() != + c10::cuda::CaptureStatus::None && + external_) + ? cudaEventWaitExternal + : cudaEventWaitDefault; + C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, flags)); +#else + C10_CUDA_CHECK(cudaStreamWaitEvent(stream, event_)); +#endif + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_wait( + c10::kCUDA, + reinterpret_cast(event_), + reinterpret_cast(stream.stream())); + } + } + } + + // Note: cudaEventElapsedTime can be safely called from any device + float elapsed_time(const CUDAEvent& other) const { + TORCH_CHECK_VALUE( + !(flags_ & cudaEventDisableTiming) && + !(other.flags_ & cudaEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + // We do not strictly have to set the device index to the same as our event, + // but if we don't and the current device is not initialized, it will + // create a new cuda context, which will consume a lot of memory. + CUDAGuard guard(device_index_); + // raise cudaErrorNotReady if either event is recorded but not yet completed + C10_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + // Note: cudaEventSynchronize can be safely called from any device + void synchronize() const { + if (is_created_) { + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_synchronization( + c10::kCUDA, reinterpret_cast(event_)); + } + C10_CUDA_CHECK(cudaEventSynchronize(event_)); + } + } + + // Note: cudaIpcGetEventHandle must be called on the same device as the event + void ipc_handle(cudaIpcEventHandle_t* handle) { + if (!is_created_) { + // this CUDAEvent object was initially constructed from flags but event_ + // is not created yet. 
+ createEvent(getCurrentCUDAStream().device_index()); + } + CUDAGuard guard(device_index_); + C10_CUDA_CHECK(cudaIpcGetEventHandle(handle, event_)); + } + + private: + unsigned int flags_ = cudaEventDisableTiming; + bool is_created_ = false; + bool was_recorded_ = false; + bool external_ = false; + DeviceIndex device_index_ = -1; + cudaEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + external_ = (flags_ & cudaEventExternal) != 0; +#ifdef USE_ROCM + TORCH_CHECK(!external_, "External events are disallowed in rocm"); +#endif + flags_ &= ~cudaEventExternal; + device_index_ = device_index; + CUDAGuard guard(device_index_); + C10_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_)); + const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); + if (C10_UNLIKELY(interp)) { + (*interp)->trace_gpu_event_creation( + c10::kCUDA, reinterpret_cast(event_)); + } + is_created_ = true; + } + + void moveHelper(CUDAEvent&& other) { + // Transfer ownership of all state from other to this + flags_ = other.flags_; + is_created_ = other.is_created_; + was_recorded_ = other.was_recorded_; + external_ = other.external_; + device_index_ = other.device_index_; + event_ = other.event_; + + // Reset other to a valid empty state to prevent double-free + // The moved-from object must not attempt to destroy the event + other.is_created_ = false; + other.event_ = cudaEvent_t{}; + } +}; + +} // namespace c10::cuda diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 674eb00035c50..48bf95bb976d8 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -320,7 +320,7 @@ void mallocAsync( TORCH_INTERNAL_ASSERT( 0 <= device && device < device_count, "Invalid device index ", - device, + static_cast(device), ": did you call init?"); // If stream is a null (default) stream, @@ -370,7 +370,7 @@ void mallocAsync( OutOfMemoryError, false, "Allocation on device ", - device, + static_cast(device), " would exceed allowed memory. (out of memory)", "\nCurrently allocated : ", format_size(pytorch_used_bytes[device]), diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index b305008d44f8c..49bad41dda866 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -17,8 +17,13 @@ std::string get_cuda_error_help(cudaError_t error) noexcept { default: help_text.append("\nSearch for `") .append(cudaGetErrorName(error)) +#if defined(USE_ROCM) + .append( + "' in https://rocm.docs.amd.com/projects/HIP/en/latest/index.html for more information."); +#else .append( "' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information."); +#endif break; } return help_text; diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 1ff0c9a12ac78..380e7939ff76c 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -20,22 +20,6 @@ } \ } while (0) -#define C10_CUDA_DRIVER_CHECK_GOTO(EXPR, NEXT) \ - do { \ - CUresult __err = EXPR; \ - if (__err != CUDA_SUCCESS) { \ - const char* err_str; \ - CUresult get_error_str_err [[maybe_unused]] = \ - c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ - if (get_error_str_err != CUDA_SUCCESS) { \ - TORCH_WARN("CUDA driver error: unknown error"); \ - } else { \ - TORCH_WARN("CUDA driver error: ", err_str); \ - } \ - goto NEXT; \ - } \ - } while (0) - // The integer in the second column specifies the requested CUDA Driver API // version. 
The dynamic loader will accept a driver with a newer version, but it // ensures that the requested symbol exists in *at least* the specified version diff --git a/c10/util/C++17.h b/c10/util/C++17.h index 5dafb245f92e8..8e88a1ec50cc4 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -34,15 +34,6 @@ namespace c10 { -// std::is_pod is deprecated in C++20, std::is_standard_layout and -// std::is_trivial are introduced in C++11, std::conjunction has been introduced -// in C++17. -template -using is_pod = std::conjunction, std::is_trivial>; - -template -constexpr bool is_pod_v = is_pod::value; - namespace guts { #if defined(__HIP__) diff --git a/c10/util/Deprecated.h b/c10/util/Deprecated.h index 88440a0242eb4..3237074feff8c 100644 --- a/c10/util/Deprecated.h +++ b/c10/util/Deprecated.h @@ -1,102 +1,2 @@ #pragma once - -/** - * This file provides portable macros for marking declarations - * as deprecated. You should generally use C10_DEPRECATED, - * except when marking 'using' declarations as deprecated, - * in which case you should use C10_DEFINE_DEPRECATED_USING - * (due to portability concerns). - */ - -// Sample usage: -// -// C10_DEPRECATED void bad_func(); -// struct C10_DEPRECATED BadStruct { -// ... -// }; - -// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses -// the "__declspec(deprecated)" implementation and not the C++14 -// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on -// MSVC, but ran into issues with some older MSVC versions. -#if (defined(__cplusplus) && __cplusplus >= 201402L) -#define C10_DEPRECATED [[deprecated]] -#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] -#elif defined(__GNUC__) -#define C10_DEPRECATED __attribute__((deprecated)) -// TODO Is there some way to implement this? -#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) - -#elif defined(_MSC_VER) -#define C10_DEPRECATED __declspec(deprecated) -#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) -#else -#warning "You need to implement C10_DEPRECATED for this compiler" -#define C10_DEPRECATED -#endif - -// Sample usage: -// -// C10_DEFINE_DEPRECATED_USING(BadType, int) -// -// which is the portable version of -// -// using BadType [[deprecated]] = int; - -// technically [[deprecated]] syntax is from c++14 standard, but it works in -// many compilers. -#if defined(__has_cpp_attribute) -#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__) -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName [[deprecated]] = TypeThingy; -#endif -#endif - -#if defined(_MSC_VER) -#if defined(__CUDACC__) -// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows; -// you get the error: -// -// error: attribute does not apply to any entity -// -// So we just turn the macro off in this case. -#if defined(C10_DEFINE_DEPRECATED_USING) -#undef C10_DEFINE_DEPRECATED_USING -#endif -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName = TypeThingy; -#else -// [[deprecated]] does work in windows without nvcc, though msc doesn't support -// `__has_cpp_attribute` when c++14 is supported, otherwise -// __declspec(deprecated) is used as the alternative. 
-#ifndef C10_DEFINE_DEPRECATED_USING -#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName [[deprecated]] = TypeThingy; -#else -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName = __declspec(deprecated) TypeThingy; -#endif -#endif -#endif -#endif - -#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__) -// nvcc has a bug where it doesn't understand __attribute__((deprecated)) -// declarations even when the host compiler supports it. We'll only use this gcc -// attribute when not cuda, and when using a GCC compiler that doesn't support -// the c++14 syntax we checked for above (available in __GNUC__ >= 5) -#if !defined(__CUDACC__) -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName __attribute__((deprecated)) = TypeThingy; -#else -// using cuda + gcc < 5, neither deprecated syntax is available so turning off. -#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ - using TypeName = TypeThingy; -#endif -#endif - -#if !defined(C10_DEFINE_DEPRECATED_USING) -#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" -#define C10_DEFINE_DEPRECATED_USING -#endif +#include diff --git a/c10/util/generic_math.h b/c10/util/generic_math.h index 493c03cb42e64..8770977840cb2 100644 --- a/c10/util/generic_math.h +++ b/c10/util/generic_math.h @@ -58,6 +58,12 @@ inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b) template inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) { + if (C10_UNLIKELY( + std::is_signed::value && + a == std::numeric_limits::min() && b == scalar_t(-1))) { + return a; + } + if (c10::signs_differ(a, b)) { // Subtracts one from the results of truncation division if the // divisor and dividend have different sign(bit)s and the remainder of diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 0c8f55f5061ab..f3c4ab0dc7cbc 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -68,6 +68,10 @@ inline bool has_pyobject(uint64_t combined_refcount) { return (combined_refcount & kHasPyObject) != 0; } +inline bool is_uniquely_owned(uint64_t combined_refcount) { + return (combined_refcount & ~detail::kHasPyObject) == detail::kUniqueRef; +} + // The only requirement for refcount increment is that it happens-before // decrement, so no additional memory ordering is needed. inline uint64_t atomic_combined_refcount_increment( @@ -287,9 +291,9 @@ class C10_API intrusive_ptr_target { * These two methods are called when the refcount transitions between one * and two and the object has a PyObject wrapper. */ - virtual void incref_pyobject() const {} - virtual void decref_pyobject() const {} - virtual bool try_incref_pyobject() const { + virtual void incref_pyobject() const noexcept {} + virtual void decref_pyobject() const noexcept {} + virtual bool try_incref_pyobject() const noexcept { return false; } @@ -363,7 +367,7 @@ class intrusive_ptr final { template friend class pybind11::class_; - void retain_() { + void retain_() noexcept { if (target_ != NullType::singleton()) { uint64_t combined = detail::atomic_combined_refcount_increment( target_->combined_refcount_, detail::kReferenceCountOne); @@ -377,9 +381,7 @@ class intrusive_ptr final { // PyObject. In other words, we need to ensure that the PyObject stays // alive now that we have a C++ reference to this object in addition to // the PyObject itself. 
- if (C10_UNLIKELY( - detail::has_pyobject(combined) && - detail::refcount(combined) == 2)) { + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { target_->incref_pyobject(); } } else { @@ -392,51 +394,60 @@ class intrusive_ptr final { void reset_() noexcept { if (target_ != NullType::singleton()) { - if (is_uniquely_owned()) { - // Both counts are 1, so there are no weak references and - // we are releasing the last strong reference. No other - // threads can observe the effects of this target_ deletion - // call (e.g. calling use_count()) without a data race. - target_->combined_refcount_.store(0, std::memory_order_relaxed); - delete target_; + reset_not_null_(target_); + } + } + + // C10_NOINLINE to keep binary size a bit smaller. We pass TTarget* here + // to avoid an extra pointer dereference in the call from reset_(). + C10_NOINLINE static void reset_not_null_(TTarget* target) noexcept { + if (detail::is_uniquely_owned( + target->combined_refcount_.load(std::memory_order_acquire))) { + // Both counts are 1, so there are no weak references and + // we are releasing the last strong reference. No other + // threads can observe the effects of this target deletion + // call (e.g. calling use_count()) without a data race. + target->combined_refcount_.store(0, std::memory_order_relaxed); + delete target; + return; + } + + auto combined_refcount = detail::atomic_combined_refcount_decrement( + target->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined_refcount); + bool has_pyobject = detail::has_pyobject(combined_refcount); + if (new_refcount == 0) { + if (detail::weakcount(combined_refcount) == 1) { + delete target; return; } - - auto combined_refcount = detail::atomic_combined_refcount_decrement( - target_->combined_refcount_, detail::kReferenceCountOne); - uint32_t new_refcount = detail::refcount(combined_refcount); - bool has_pyobject = detail::has_pyobject(combined_refcount); - if (new_refcount == 0) { - bool should_delete = detail::weakcount(combined_refcount) == 1; - // See comment above about weakcount. As long as refcount>0, - // weakcount is one larger than the actual number of weak references. - // So we need to decrement it here. - if (!should_delete) { - // justification for const_cast: release_resources is basically a - // destructor and a destructor always mutates the object, even for - // const objects. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast*>(target_) - ->release_resources(); - should_delete = detail::atomic_weakcount_decrement( - target_->combined_refcount_) == 0; - } - if (should_delete) { - delete target_; - } - } else if constexpr (detail::TargetTraits::can_have_pyobject) { - // If the refcount transitioned from 2 to 1, we need to decref the - // PyObject. In other words, we don't want to keep the PyObject alive if - // there are no C++ references to this object other than the PyObject - // itself. - if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) { - target_->decref_pyobject(); - } - } else { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !has_pyobject, - "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + // See comment above about weakcount. As long as refcount>0, + // weakcount is one larger than the actual number of weak references. + // So we need to decrement it here. 
+ release_resources_and_decrement_weakrefs_(target); + } else if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 2 to 1, we need to decref the + // PyObject. In other words, we don't want to keep the PyObject alive if + // there are no C++ references to this object other than the PyObject + // itself. + if (has_pyobject && new_refcount == 1) { + target->decref_pyobject(); } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !has_pyobject, + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } + } + + C10_NOINLINE static void release_resources_and_decrement_weakrefs_( + TTarget* target) noexcept { + // justification for const_cast: release_resources is basically a + // destructor and a destructor always mutates the object, even for + // const objects. + const_cast*>(target)->release_resources(); + if (detail::atomic_weakcount_decrement(target->combined_refcount_) == 0) { + delete target; } } @@ -607,9 +618,8 @@ class intrusive_ptr final { */ bool is_uniquely_owned() const noexcept { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton()); - uint64_t combined = - target_->combined_refcount_.load(std::memory_order_acquire); - return (combined & ~detail::kHasPyObject) == detail::kUniqueRef; + return detail::is_uniquely_owned( + target_->combined_refcount_.load(std::memory_order_acquire)); } /** @@ -1174,9 +1184,7 @@ inline void incref(intrusive_ptr_target* self) { self->combined_refcount_, detail::kReferenceCountOne); #ifndef C10_MOBILE - if (C10_UNLIKELY( - detail::has_pyobject(combined) && - detail::refcount(combined) == 2)) { + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { self->incref_pyobject(); } #else diff --git a/c10/util/safe_numerics.h b/c10/util/safe_numerics.h index 32ffca52e4864..bfdb968ff96ab 100644 --- a/c10/util/safe_numerics.h +++ b/c10/util/safe_numerics.h @@ -3,6 +3,7 @@ #include #include +#include // GCC has __builtin_mul_overflow from before it supported __has_builtin #ifdef _MSC_VER @@ -15,31 +16,45 @@ namespace c10 { -C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { +template , int> = 0> +C10_ALWAYS_INLINE bool add_overflows(T a, T b, T* out) { #if C10_HAS_BUILTIN_OVERFLOW() return __builtin_add_overflow(a, b, out); #else - unsigned long long tmp; -#if defined(_M_IX86) || defined(_M_X64) - auto carry = _addcarry_u64(0, a, b, &tmp); -#else - tmp = a + b; - unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); - auto carry = vector >> 63; -#endif - *out = tmp; - return carry; + if constexpr (std::is_signed_v) { + // For signed types, detect overflow by checking sign changes + volatile T tmp = a + b; + *out = tmp; + + // If both operands have the same sign, check if result changed sign + // unexpectedly. 
+ if ((a > 0) == (b > 0)) { + if ((a > 0) && (tmp <= 0)) { + return true; // Positive overflow + } + if ((a < 0) && (tmp >= 0)) { + return true; // Negative overflow + } + } + return false; + } else { + // For unsigned types, overflow causes wrap-around + volatile T tmp = a + b; + *out = tmp; + return (tmp < a || tmp < b); + } #endif } -template +C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { + return add_overflows(a, b, out); +} + +template , int> = 0> C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) { #if C10_HAS_BUILTIN_OVERFLOW() return __builtin_mul_overflow(a, b, out); #else - static_assert( - std::is_integral_v, "mul_overflows only supports integral types"); - if constexpr (std::is_signed_v) { // For signed types, use the division-based check volatile T tmp = a * b; diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 3bd9eff0fee63..d7eeb10caba1b 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -893,11 +893,13 @@ class DeviceCachingAllocator { } bool release_cached_blocks(MempoolId_t mempool_id) { + bool streams_synced = false; if (mempool_id.first == 0 && mempool_id.second == 0 && captures_underway.empty()) { synchronize_and_free_events(); // See Note [Safe to Free Blocks on BlockPool] c10::xpu::syncStreamsOnDevice(device_index); + streams_synced = true; release_blocks(large_blocks); release_blocks(small_blocks); @@ -916,6 +918,12 @@ class DeviceCachingAllocator { continue; } } + + if (!streams_synced) { + // See Note [Safe to Free Blocks on BlockPool] + c10::xpu::syncStreamsOnDevice(device_index); + streams_synced = true; + } TORCH_INTERNAL_ASSERT(it->second->use_count == 0); release_blocks(it->second->small_blocks); release_blocks(it->second->large_blocks); @@ -1219,6 +1227,63 @@ class DeviceCachingAllocator { allowed_memory_maximum = static_cast(fraction * device_total); set_fraction = true; } + + void createOrIncrefPool( + MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr) { + std::scoped_lock lock(mutex); + create_or_incref_pool(mempool_id, allocator); + } + + int getPoolUseCount(MempoolId_t mempool_id) { + std::scoped_lock lock(mutex); + auto it = graph_pools.find(mempool_id); + if (it == graph_pools.end()) { + return 0; + } + return it->second->use_count; + } + + // Called by XPUGraph::capture_begin + void beginAllocateToPool( + MempoolId_t mempool_id, + std::function filter) { + std::lock_guard lock(mutex); + create_or_incref_pool(mempool_id); + auto not_found = std::all_of( + captures_underway.begin(), + captures_underway.end(), + [&](const auto& entry) { return entry.first != mempool_id; }); + TORCH_CHECK( + not_found, "beginAllocateToPool: already recording to mempool_id"); + captures_underway.emplace_back(mempool_id, std::move(filter)); + } + + // Called by XPUGraph::capture_end + void endAllocateToPool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + + auto it = std::find_if( + captures_underway.begin(), + captures_underway.end(), + [&](const auto& entry) { return entry.first == mempool_id; }); + TORCH_INTERNAL_ASSERT( + it != captures_underway.end(), + "endAllocatePool: not currently recording to mempool_id"); + captures_underway.erase(it); + } + + // Called by XPUGraph::reset and MemPool::~MemPool() + void releasePool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + auto pp = get_private_pool(mempool_id); + auto uc = --(pp->use_count); + TORCH_INTERNAL_ASSERT(uc >= 0); + if (uc == 0) { + bool inserted = 
graph_pools_freeable.insert({mempool_id, pp}).second; + TORCH_INTERNAL_ASSERT(inserted); + } + } }; static void local_raw_delete(void* ptr); @@ -1408,6 +1473,39 @@ class XPUAllocator : public DeviceAllocator { ". Please set within (0, 1]."); device_allocators[device]->setMemoryFraction(fraction); } + + void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + XPUAllocator* allocator) { + assertValidDevice(device); + device_allocators[device]->createOrIncrefPool( + std::move(mempool_id), allocator); + } + + void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + assertValidDevice(device); + device_allocators[device]->beginAllocateToPool( + std::move(mempool_id), std::move(filter)); + } + + void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + device_allocators[device]->endAllocateToPool(mempool_id); + } + + void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + device_allocators[device]->releasePool(std::move(mempool_id)); + } + + int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + return device_allocators[device]->getPoolUseCount(std::move(mempool_id)); + } }; static XPUAllocator allocator; @@ -1464,6 +1562,92 @@ void setMemoryFraction(double fraction, DeviceIndex device) { return allocator.setMemoryFraction(fraction, device); } +void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + XPUAllocator* allocator_ptr) { + return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr); +} + +void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + return allocator.beginAllocateToPool(device, mempool_id, std::move(filter)); +} + +void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.endAllocateToPool(device, mempool_id); +} + +void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.releasePool(device, mempool_id); +} + +int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.getPoolUseCount(device, mempool_id); +} + REGISTER_ALLOCATOR(kXPU, &allocator) } // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +// uid_ is incremented when a user creates a MemPool, +// +// uuid_ is incremented when XPUGraph creates a MemPool +// as a result of a user not providing a pool. 
+ +std::atomic MemPool::uid_{1}; +std::atomic MemPool::uuid_{1}; + +MemPool::MemPool( + XPUCachingAllocator::XPUAllocator* allocator, + bool is_user_created, + bool use_on_oom) + : allocator_(allocator), is_user_created_(is_user_created) { + if (is_user_created_) { + id_ = {0, uid_++}; + } else { + id_ = {uuid_++, 0}; + } + device_ = c10::xpu::current_device(); + XPUCachingAllocator::createOrIncrefPool(device_, id_, allocator); + if (use_on_oom) { + // XPU doesn't support use_on_oom yet + TORCH_WARN( + "XPUCachingAllocator::MemPool: use_on_oom is not supported on XPU"); + } +} + +MemPool::~MemPool() { + TORCH_INTERNAL_ASSERT(use_count() == 1); + XPUCachingAllocator::releasePool(device_, id_); + c10::xpu::XPUCachingAllocator::emptyCache(id_); // release cached blocks +} + +MempoolId_t MemPool::id() { + return id_; +} + +XPUCachingAllocator::XPUAllocator* MemPool::allocator() { + return allocator_; +} + +int MemPool::use_count() { + return XPUCachingAllocator::getPoolUseCount(device_, id_); +} + +c10::DeviceIndex MemPool::device() { + return device_; +} + +MempoolId_t MemPool::graph_pool_handle(bool is_user_created) { + if (is_user_created) { + return {0, uid_++}; + } + return {uuid_++, 0}; +} + +} // namespace c10::xpu diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index bbb20a5b2ecdf..c55de309032e0 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -33,4 +33,59 @@ C10_XPU_API double getMemoryFraction(DeviceIndex device); C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); +class XPUAllocator; + +C10_XPU_API void createOrIncrefPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr); + +C10_XPU_API void beginAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + std::function filter); + +C10_XPU_API void endAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API void releasePool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API int getPoolUseCount( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + } // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +using c10::CaptureId_t; +using c10::MempoolId_t; +struct C10_XPU_API MemPool { + MemPool( + XPUCachingAllocator::XPUAllocator* allocator = nullptr, + bool is_user_created = true, + bool use_on_oom = false); + MemPool(const MemPool&) = delete; + MemPool(MemPool&&) = default; + MemPool& operator=(const MemPool&) = delete; + MemPool& operator=(MemPool&&) = default; + ~MemPool(); + + MempoolId_t id(); + XPUCachingAllocator::XPUAllocator* allocator(); + int use_count(); + c10::DeviceIndex device(); + static MempoolId_t graph_pool_handle(bool is_user_created = true); + + private: + static std::atomic uid_; + static std::atomic uuid_; + XPUCachingAllocator::XPUAllocator* allocator_; + bool is_user_created_; + MempoolId_t id_; + c10::DeviceIndex device_; +}; +} // namespace c10::xpu diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index bac1fa7daac01..5faad21f9f6cd 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -113,6 +113,12 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a") endif() endif() + # We will need to gate against CUDA version, because sm_110a is available on CUDA 13.0+ + if("${_arch}" STREQUAL "110a" AND CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + if(_existing_arch_flags MATCHES ".*compute_110.*") + list(APPEND _file_compile_flags 
"-gencode;arch=compute_110a,code=sm_110a") + endif() + endif() if("${_arch}" STREQUAL "120a") if(_existing_arch_flags MATCHES ".*compute_120.*") list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") @@ -132,13 +138,13 @@ if(INTERN_BUILD_ATEN_OPS) _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu" - "89;90a;100a;103a;120a;121a") + "89;90a;100a;103a;110a;120a;121a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" "90a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" - "90a;100a;103a") + "90a;100a;103a;110a") endif() diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 2018d5ec9370b..0349b09119cae 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -85,7 +85,7 @@ IF(NOT MKLDNN_FOUND) ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) IF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") - IF(CPU_POWER) + IF(CPU_POWER OR CPU_RISCV) SET(DNNL_EXPERIMENTAL_UKERNEL OFF CACHE BOOL "" FORCE) ELSE() MESSAGE("-- Will build oneDNN UKERNEL") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 018bca837a5a8..7ecaff5109f42 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -83,8 +83,18 @@ find_package_and_print_version(HIP 1.0 MODULE) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) find_package_and_print_version(hip REQUIRED CONFIG) + if(HIP_VERSION) + # Check if HIP_VERSION contains a dash (e.g., "7.1.25421-32f9fa6ca5") + # and strip everything after it to get clean numeric version + string(FIND "${HIP_VERSION}" "-" DASH_POS) + if(NOT DASH_POS EQUAL -1) + string(SUBSTRING "${HIP_VERSION}" 0 ${DASH_POS} HIP_VERSION_CLEAN) + set(HIP_VERSION "${HIP_VERSION_CLEAN}") + endif() + message("HIP version: ${HIP_VERSION}") +endif() - # The rocm-core package was only introduced in ROCm 6.4, so we make it optional. +# The rocm-core package was only introduced in ROCm 6.4, so we make it optional. find_package(rocm-core CONFIG) # Some old consumer HIP SDKs do not distribute rocm_version.h, so we allow diff --git a/docs/source/community/governance.rst b/docs/source/community/governance.rst index cea24593dca83..ebfadf4e0f69b 100644 --- a/docs/source/community/governance.rst +++ b/docs/source/community/governance.rst @@ -132,7 +132,7 @@ The Process for Nomination * Each module has its own process. Please contact module maintainers for more information. However, if there is no process identified, you can file a request to the core - maintainers by submitting `this form `__. + maintainers by submitting `this form `__. Core maintainers are meeting every three months. * If you are submitting a request to the core maintainers, the information in your request must include the following items: diff --git a/docs/source/distributed.md b/docs/source/distributed.md index ca1fe3b5e9099..6840bbb893bf7 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -987,6 +987,24 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the {func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs. 
+
+### torch.distributed.debug HTTP Server
+
+The `torch.distributed.debug` module provides an HTTP server that can be used to debug distributed applications. The server can
+be started by calling {func}`torch.distributed.debug.start_debug_server`. This
+allows users to collect data across all workers at runtime.
+
+```{eval-rst}
+.. automodule:: torch.distributed.debug
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :special-members: __init__
+    :member-order: bysource
+
+```
+
+
 ## Logging
 
 In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log
diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst
index 015d1d9ffda1a..ba39d80700f28 100644
--- a/docs/source/nn.functional.rst
+++ b/docs/source/nn.functional.rst
@@ -227,5 +227,6 @@ Low-Precision functions
     ScalingType
     SwizzleType
+    grouped_mm
     scaled_mm
     scaled_grouped_mm
diff --git a/docs/source/torch.compiler_cudagraph_trees.md b/docs/source/torch.compiler_cudagraph_trees.md
index eb137625ea746..f220086f82dc2 100644
--- a/docs/source/torch.compiler_cudagraph_trees.md
+++ b/docs/source/torch.compiler_cudagraph_trees.md
@@ -319,7 +319,7 @@ Trees, we don’t want to add unintended dependencies between iterations that wo
 to prematurely free memory from a prior invocation. Our heuristics are in inference we start a new iteration
 on each invocation for torch.compile, and in training we do the same so long as there is not a pending backward that
 has not been invoked. If those heuristics are wrong, you can mark the start of a new iteration with
-[torch.compiler.mark_step_begin()](https://pytorch.org/docs/stable/generated/torch.compiler.cudagraph_mark_step_begin.html), or clone
+[torch.compiler.cudagraph_mark_step_begin()](https://pytorch.org/docs/stable/generated/torch.compiler.cudagraph_mark_step_begin.html), or clone
 tensors of a prior iteration (outside of torch.compile) before you begin the next run.
 
 ### Comparisons
diff --git a/functorch/benchmarks/pointwise_scorecard.py b/functorch/benchmarks/pointwise_scorecard.py
index 5f46c0a74fc5d..9b910735f7d1e 100644
--- a/functorch/benchmarks/pointwise_scorecard.py
+++ b/functorch/benchmarks/pointwise_scorecard.py
@@ -233,7 +233,7 @@ def micros(s):
     args = shape()[:nargs]
     try:
-        if shape == medium_transpose:
+        if shape is medium_transpose:
             raise RuntimeError("pointwise_operator hangs on medium_transpose")
         pw_op = pointwise_operator(operator)
         torch.testing.assert_close(operator(*args), pw_op(*args))
@@ -264,7 +264,7 @@ def micros(s):
         )
     )
     try:
-        if shape == medium_transpose:
+        if shape is medium_transpose:
             raise RuntimeError("pointwise_operator hangs on medium_transpose")
         if (operator, shape) in nope:
             raise RuntimeError("pointwise_operator fails on medium_transpose")
diff --git a/pylintrc b/pylintrc
new file mode 100644
index 0000000000000..3cb94acaa33f4
--- /dev/null
+++ b/pylintrc
@@ -0,0 +1,5 @@
+[MESSAGES CONTROL]
+
+# Disable the message, report, category or checker with the given id(s).
+disable=all +enable=W0143 diff --git a/pyproject.toml b/pyproject.toml index 9986c6a9b7b6b..d9927122352f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -391,4 +391,5 @@ package = 'torch' ".spin/cmds.py:regenerate_version", ".spin/cmds.py:regenerate_type_stubs", ".spin/cmds.py:regenerate_clangtidy_files", + ".spin/cmds.py:regenerate_github_workflows", ] diff --git a/scripts/install_torchinductor_tpu_deps.sh b/scripts/install_torchinductor_tpu_deps.sh new file mode 100755 index 0000000000000..1fe374a8b7bfb --- /dev/null +++ b/scripts/install_torchinductor_tpu_deps.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# +# Install dependencies for TorchInductor on TPU. + +# Install dependencies from requirements.txt first +pip install -r requirements.txt + +# Install JAX nightly builds and other TPU dependencies +pip install --pre -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html jax==0.8.0.dev20251013 jaxlib==0.8.0.dev20251013 libtpu==0.0.25.dev20251012+nightly tpu-info==0.6.0 setuptools==78.1.0 # @lint-ignore diff --git a/test/complex_tensor/test_complex_tensor.py b/test/complex_tensor/test_complex_tensor.py new file mode 100644 index 0000000000000..dbb14d93f972a --- /dev/null +++ b/test/complex_tensor/test_complex_tensor.py @@ -0,0 +1,238 @@ +# Owner(s): ["module: complex"] +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist + + +# Support both when imported from elsewhere or directly as a file +try: + from .utils import ( + COMPLEX_DTYPES, + Descriptor, + force_test_op_db, + get_overload_packet_from_name, + implemented_op_db, + TestCase, + Variant, + ) +except ImportError: + from utils import ( + COMPLEX_DTYPES, + Descriptor, + force_test_op_db, + get_overload_packet_from_name, + implemented_op_db, + TestCase, + Variant, + ) + +from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + OpDTypes, + ops, +) +from torch.testing._internal.common_utils import ( + run_tests, + TestGradients, + unMarkDynamoStrictTest, +) + + +if TYPE_CHECKING: + from torch.testing._internal.opinfo.core import OpInfo + +aten = torch.ops.aten + +SKIPS = { + Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output", + Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output", + Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.reciprocal, variant=Variant.GradCheck + ): "Numerical inconsistency", + Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.log, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency", 
+ Descriptor( + op=aten.true_divide, variant=Variant.GradCheck + ): "Numerical inconsistency", + Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.tanh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.any, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.all, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.allclose, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.conj_physical, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten._conj_physical, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.cumprod, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.index_add, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.diagonal_scatter, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.flip, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.masked_fill, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.masked_scatter, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.rsub, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.ne, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.squeeze, variant=Variant.Distributed + ): "does not have 
a sharding strategy registered", + Descriptor( + op=aten.index_select, variant=Variant.Distributed + ): "Sharding propagation failed", + Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support", +} + +EXTRA_KWARGS = { + Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 2e-5, + "atol": 5e-5, + }, + Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 1e-4, + "atol": 1e-5, + }, + Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 2e-2, + "atol": 2e-6, + }, + Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-5, + "atol": 5e-5, + }, + Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 1e-4, + "atol": 1e-5, + }, + Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-2, + "atol": 2e-6, + }, + Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-6, + "atol": 1e-2, + }, +} + + +class TestComplexTensor(TestCase): + _default_dtype_check_enabled = True + + @ops( + implemented_op_db, + dtypes=OpDTypes.supported, + allowed_dtypes=list(COMPLEX_DTYPES), + ) + def test_consistency(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Op) + + @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_maybe_error(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Op) + + +@unMarkDynamoStrictTest +class TestComplexBwdGradients(TestGradients): + _default_dtype_check_enabled = True + + @ops( + implemented_op_db, + dtypes=OpDTypes.supported_backward, + allowed_dtypes=[torch.complex128], + ) + def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None: + test_info = Descriptor( + op=get_overload_packet_from_name(op.name), + device_type=torch.device(device).type, + dtype=dtype, + variant=Variant.GradCheck, + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + if dtype not in op.supported_backward_dtypes(torch.device(device).type): + self.skipTest(f"Skipped! 
{dtype=} is not in supported backward dtypes!") + + with ComplexTensorMode(): + op.gradcheck_fast_mode = False + self._grad_test_helper(device, dtype, op, op.get_op()) + + +instantiate_device_type_tests(TestComplexTensor, globals()) +instantiate_device_type_tests(TestComplexBwdGradients, globals()) + + +if dist.is_available(): + from torch.testing._internal.common_distributed import MultiProcessTestCase + + @unMarkDynamoStrictTest + class TestComplexDistributed(TestCase, MultiProcessTestCase): + @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_distributed(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Distributed) + + instantiate_device_type_tests(TestComplexDistributed, globals()) + +if __name__ == "__main__": + run_tests() diff --git a/test/complex_tensor/utils.py b/test/complex_tensor/utils.py new file mode 100644 index 0000000000000..d2a1e1d312264 --- /dev/null +++ b/test/complex_tensor/utils.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, fields +from enum import auto, Enum +from typing import Any, TYPE_CHECKING + +import torch +import torch.distributed as dist +from torch._subclasses.complex_tensor._ops.common import ( + _as_complex_tensor, + _as_interleaved, + _get_op_name, + COMPLEX_OPS_TABLE, + COMPLEX_TO_REAL, + FORCE_TEST_LIST, + OpOverloadPacket, +) +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import TestCase as PytorchTestCase +from torch.utils._pytree import tree_flatten + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.distributed.tensor import DTensor + from torch.testing._internal.opinfo.core import OpInfo + +COMPLEX_DTYPES = set(COMPLEX_TO_REAL) + + +class Variant(Enum): + Op = auto() + GradCheck = auto() + Distributed = auto() + + +def _as_local(arg: DTensor | Any) -> torch.Tensor | Any: + if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)): + return arg + + return arg.full_tensor() + + +def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any: + if not isinstance(arg, torch.Tensor): + return arg + + return dist.tensor.DTensor.from_local(_as_complex_tensor(arg)) + + +TRANSFORM_FUNCS = { + Variant.Op: _as_complex_tensor, + Variant.Distributed: _as_complex_dtensor, +} + + +@dataclass(frozen=True, kw_only=True) +class Descriptor: + op: OpOverloadPacket + variant: Variant | None + device_type: str | None = field(default=None) + dtype: torch.dtype | None = field(default=None) + + def matches(self, other: Descriptor) -> bool: + fields1 = fields(self) + fields2 = fields(other) + if fields1 != fields2: + return False + + for f in fields1: + f1 = getattr(self, f.name) + f2 = getattr(other, f.name) + if f1 is not None and f2 is not None and f1 != f2: + return False + + return True + + +class TestCase(PytorchTestCase): + def assertSameResult( + self, + expected: Callable[[], Any], + actual: Callable[[], Any], + *args, + **kwargs, + ) -> None: + try: + result_e = expected() + exception_e = None + except Exception as e: # noqa: BLE001 + result_e = None + exception_e = e + + try: + result_a = actual() + exception_a = None + except Exception as e: # noqa: BLE001 + result_a = None + exception_a = e + + if (exception_e is None) != (exception_a is None): + if exception_a is not None and exception_e is None: + raise exception_a + self.assertIs( + type(exception_e), + type(exception_a), + f"\n{exception_e=}\n{exception_a=}", + ) + + if exception_e is 
None: + flattened_e, spec_e = tree_flatten(result_e) + flattened_a, spec_a = tree_flatten(result_a) + + self.assertEqual( + spec_e, + spec_a, + "Both functions must return a result with the same tree structure.", + ) + for value_e, value_a in zip(flattened_e, flattened_a, strict=True): + value_e = _as_interleaved(_as_local(value_e)) + value_a = _as_interleaved(_as_local(value_a)) + + self.assertEqual(value_e, value_a, *args, **kwargs) + + def check_consistency( + self, device: str, dtype, op: OpInfo, variant: Variant + ) -> None: + try: + from .test_complex_tensor import EXTRA_KWARGS, SKIPS + except ImportError: + from test_complex_tensor import EXTRA_KWARGS, SKIPS + test_info = Descriptor( + op=get_overload_packet_from_name(op.name), + device_type=torch.device(device).type, + dtype=dtype, + variant=variant, + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + kwargs = {} + for extra_info, extra_kw in EXTRA_KWARGS.items(): + if extra_info.matches(test_info): + kwargs = extra_kw + break + sample_inputs = op.sample_inputs(device, dtype) + transform_fn = TRANSFORM_FUNCS[variant] + + for sample_input in sample_inputs: + + def expected(sample_input=sample_input): + return op(sample_input.input, *sample_input.args, **sample_input.kwargs) + + subclass_sample = sample_input.transform(transform_fn) + + def actual(subclass_sample=subclass_sample): + return op( + subclass_sample.input, + *subclass_sample.args, + **subclass_sample.kwargs, + ) + + self.assertSameResult(expected, actual, **kwargs) + + +aten = torch.ops.aten + +complex_op_db = tuple( + filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db) +) + + +def get_overload_packet_from_name(name: str) -> OpOverloadPacket: + for domain_name in torch.ops: + op_namespace = getattr(torch.ops, domain_name) + op: OpOverloadPacket | None = getattr(op_namespace, name, None) + if op is not None: + return op + + raise RuntimeError(f"No op with {name=} found.") + + +force_test_names = set(map(_get_op_name, FORCE_TEST_LIST)) +implemented_op_names = ( + set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names +) +implemented_op_db = tuple( + filter(lambda op: op.name in implemented_op_names, complex_op_db) +) +force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db)) + +tested_op_names = {op.name for op in implemented_op_db} | { + op.name for op in force_test_op_db +} +non_tested_ops = { + op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names +} + + +# TODO (hameerabbasi): There are a number of ops that don't have any associated +# OpInfos. We still need to write tests for those ops. +if len(non_tested_ops) != 0: + import textwrap + import warnings + + list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops])) + warnings.warn( + "Not all implemented ops are tested. 
List of ops missing tests:" + f"\n{textwrap.indent(list_missing_ops, ' ')}", + UserWarning, + stacklevel=2, + ) diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 483814a0326d2..4146819e2f1a1 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -16,8 +16,10 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_layout.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_memoryformat.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_metaprogramming.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_scalartype.cpp diff --git a/test/cpp/aoti_abi_check/test_layout.cpp b/test/cpp/aoti_abi_check/test_layout.cpp new file mode 100644 index 0000000000000..7bb45a6897434 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_layout.cpp @@ -0,0 +1,20 @@ +#include + +#include + +TEST(TestLayout, TestLayout) { + using torch::headeronly::Layout; + constexpr Layout expected_layouts[] = { + torch::headeronly::kStrided, + torch::headeronly::kSparse, + torch::headeronly::kSparseCsr, + torch::headeronly::kMkldnn, + torch::headeronly::kSparseCsc, + torch::headeronly::kSparseBsr, + torch::headeronly::kSparseBsc, + torch::headeronly::kJagged, + }; + for (int8_t i = 0; i < static_cast(Layout::NumOptions); i++) { + EXPECT_EQ(static_cast(i), expected_layouts[i]); + } +} diff --git a/test/cpp/aoti_abi_check/test_memoryformat.cpp b/test/cpp/aoti_abi_check/test_memoryformat.cpp new file mode 100644 index 0000000000000..b0a584b15e299 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_memoryformat.cpp @@ -0,0 +1,23 @@ +#include + +#include + +TEST(TestMemoryFormat, TestMemoryFormat) { + using torch::headeronly::MemoryFormat; + constexpr MemoryFormat expected_memory_formats[] = { + MemoryFormat::Contiguous, + MemoryFormat::Preserve, + MemoryFormat::ChannelsLast, + MemoryFormat::ChannelsLast3d, + }; + for (int8_t i = 0; i < static_cast(MemoryFormat::NumOptions); i++) { + EXPECT_EQ(static_cast(i), expected_memory_formats[i]); + } +} + +TEST(TestMemoryFormat, get_contiguous_memory_format) { + using torch::headeronly::get_contiguous_memory_format; + using torch::headeronly::MemoryFormat; + + EXPECT_EQ(get_contiguous_memory_format(), MemoryFormat::Contiguous); +} diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index a58ac596d7cb2..5a5ea4a69f7c7 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -1669,12 +1669,18 @@ TEST(NonDeterminismBackwardsCompatibility, BackwardsCompatibility) { "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint(int low, int high, int[] size, *, int? 
dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.generator(Tensor self, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor(Tensor self, Tensor high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_dtype(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_generator_dtype(Tensor self, int low, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? 
pin_memory) -> Tensor"}; for (const std::string& op : nondeterministic_ops) { const c10::FunctionSchema& schema = torch::jit::parseSchema(op); diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu index 7773210a089ee..f8d87f60d9a2e 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu @@ -3,7 +3,11 @@ #include "tensor_accessor_kernel.h" +#ifdef USE_ROCM +#include +#else #include +#endif #include #include #include diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp index 6278dca9f281d..4b17b113135e6 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_empty.cpp @@ -10,14 +10,16 @@ using torch::stable::Tensor; Tensor my_empty( torch::headeronly::HeaderOnlyArrayRef size, std::optional dtype, + std::optional layout, std::optional device, - std::optional pin_memory) { - return empty(size, dtype, device, pin_memory); + std::optional pin_memory, + std::optional memory_format) { + return empty(size, dtype, layout, device, pin_memory, memory_format); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { m.def( - "my_empty(int[] size, ScalarType? dtype=None, Device? device=None, bool? pin_memory=None) -> Tensor"); + "my_empty(int[] size, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp new file mode 100644 index 0000000000000..c560fb0a60af9 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +using torch::stable::Tensor; + +torch::headeronly::HeaderOnlyArrayRef my_shape(Tensor t) { + return t.sizes(); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_shape(Tensor t) -> int[]"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl("my_shape", TORCH_BOX(&my_shape)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp new file mode 100644 index 0000000000000..1b97f60882b0f --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_string_op.cpp @@ -0,0 +1,32 @@ +#include +#include +#include + +#include +#include + +using torch::stable::Tensor; + +std::tuple, int64_t> my_string_op(Tensor t, std::string_view accessor, std::string passthru) { + int64_t res; + if (accessor == "dim") { + res = t.dim(); + } else if (accessor == "size") { + res = t.size(0); + } else if (accessor == "stride") { + res = t.stride(0); + } else { + STD_TORCH_CHECK(false, "Unsupported accessor value: ", std::string(accessor).c_str()) + } + + auto vec = std::vector({std::string(accessor), std::to_string(res), passthru}); + return std::make_tuple(vec, res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_string_op(Tensor t, str accessor, str passthru) -> (str[], int)"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("my_string_op", TORCH_BOX(&my_string_op)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu new file mode 100644 index 0000000000000..439cb8e24ddb0 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/test_cublas_handle.cu @@ -0,0 +1,16 @@ +#include +#include + +void* my_get_curr_cuda_blas_handle() { + void* ret_handle; + TORCH_ERROR_CODE_CHECK(torch_get_current_cuda_blas_handle(&ret_handle)); + return ret_handle; +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_get_curr_cuda_blas_handle() -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic_2_10, CompositeExplicitAutograd, m) { + m.impl("my_get_curr_cuda_blas_handle", TORCH_BOX(&my_get_curr_cuda_blas_handle)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index db1a4fd43033c..b68839dc565c7 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -156,20 +156,24 @@ def test_get_num_threads() -> int: return 
torch.ops.libtorch_agnostic_2_10.test_get_num_threads.default()
 
 
-def my_empty(size, dtype=None, device=None, pin_memory=None) -> Tensor:
+def my_empty(
+    size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None
+) -> Tensor:
     """
-    Creates an empty tensor with the specified size, dtype, device, and pin_memory.
+    Creates an empty tensor with the specified size, dtype, layout, device, pin_memory, and memory_format.
 
     Args:
         size: list[int] - size of the tensor to create
         dtype: ScalarType or None - data type of the tensor
+        layout: Layout or None - layout of the tensor
         device: Device or None - device on which to create the tensor
         pin_memory: bool or None - whether to use pinned memory
+        memory_format: MemoryFormat or None - memory format of the tensor
 
     Returns: Tensor - an uninitialized tensor with the specified properties
     """
     return torch.ops.libtorch_agnostic_2_10.my_empty.default(
-        size, dtype, device, pin_memory
+        size, dtype, layout, device, pin_memory, memory_format
     )
 
 
@@ -199,6 +203,18 @@ def my_view(t, size) -> Tensor:
     return torch.ops.libtorch_agnostic_2_10.my_view.default(t, size)
 
 
+def my_shape(t) -> tuple[int]:
+    """
+    Returns the shape of the input tensor.
+
+    Args:
+        t: Tensor - input tensor
+
+    Returns: tuple - shape of the input tensor.
+    """
+    return torch.ops.libtorch_agnostic_2_10.my_shape.default(t)
+
+
 def get_any_data_ptr(t, mutable) -> int:
     """
     Return data pointer value of the tensor.
@@ -223,3 +239,29 @@ def get_template_any_data_ptr(t, dtype, mutable) -> int:
     return torch.ops.libtorch_agnostic_2_10.get_template_any_data_ptr.default(
         t, dtype, mutable
     )
+
+
+def my_get_curr_cuda_blas_handle() -> int:
+    """
+    Return the current cuBlasHandle_t pointer value.
+    """
+    return torch.ops.libtorch_agnostic_2_10.my_get_curr_cuda_blas_handle.default()
+
+
+def my_string_op(t, accessor, passthru) -> tuple[list[str], int]:
+    """
+    The purpose of this op is to test inputting and outputting strings in a
+    stable custom op. This particular op takes in a Tensor, a string denoting
+    which tensor metadata API to call, and a pass through string to return a
+    string list and the value of the tensor metadata.
+
+    If accessor is "size" or "stride", query along the 0th dim.
+ + Args: + t: Tensor - input tensor to query + accessor: str - which property to access ("dim", "size", or "stride") + passthru: str - a string that gets returned as the last element of the list + + Returns: tuple - (list of [accessor, value, passthru] as strings, value) + """ + return torch.ops.libtorch_agnostic_2_10.my_string_op.default(t, accessor, passthru) diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index ff2aeff5e932b..7bc37ba238139 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -45,7 +45,7 @@ def get_extension(): # allow including if torch.cuda.is_available(): extra_compile_args["cxx"].append("-DLAE_USE_CUDA") - extra_compile_args["nvcc"] = ["-O2"] + extra_compile_args["nvcc"] = ["-O2", "-DUSE_CUDA"] extension = CUDAExtension sources.extend(CSRC_DIR.glob("**/*.cu")) diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py index a094c57f8e614..05027a41b6715 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py @@ -22,9 +22,15 @@ from pathlib import Path from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase -from torch.utils.cpp_extension import CUDA_HOME, include_paths as torch_include_paths +from torch.utils.cpp_extension import ( + CUDA_HOME, + include_paths as torch_include_paths, + ROCM_HOME, +) +GPU_HOME = CUDA_HOME or ROCM_HOME + # TODO: Fix this error in Windows: # numba.cuda.cudadrv.driver:driver.py:384 Call to cuInit results in CUDA_ERROR_NO_DEVICE if not IS_WINDOWS: @@ -42,8 +48,8 @@ def setUpClass(cls): f"-I{path}" for path in torch_include_paths(device_type="cpu") ] cls.cuda_includes = [] - if CUDA_HOME: - cuda_include_path = os.path.join(CUDA_HOME, "include") + if GPU_HOME: + cuda_include_path = os.path.join(GPU_HOME, "include") if os.path.exists(cuda_include_path): cls.cuda_includes = [f"-I{cuda_include_path}"] @@ -105,13 +111,13 @@ def _compile_cu_file( Compile a CUDA file with TORCH_TARGET_VERSION=2.9.0. Returns (success, error_message). 
""" - if not CUDA_HOME: - return False, "CUDA_HOME not set" + if not GPU_HOME: + return False, "one of CUDA_HOME and ROCM_HOME should be set but is not" torch_version_2_9 = "0x0209000000000000" cmd = [ - os.path.join(CUDA_HOME, "bin", "nvcc"), + os.path.join(GPU_HOME, "bin", "nvcc" if CUDA_HOME else "hipcc"), "-c", "-std=c++17", f"-DTORCH_TARGET_VERSION={torch_version_2_9}", @@ -120,6 +126,9 @@ def _compile_cu_file( *self.cuda_includes, ] + if ROCM_HOME: + cmd.extend(["-DUSE_ROCM=1"]) + cmd.extend([str(source_file), "-o", str(output_file)]) result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu index 88c19d0ebf062..1f549630262a6 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu @@ -1,6 +1,10 @@ #include "kernel.h" +#ifdef USE_ROCM +#include +#else #include +#endif #include #include #include diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp index 3d35b677cd208..3a6f2945d903c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -1,8 +1,275 @@ #include "OpenRegDeviceAllocator.h" +#include "OpenRegFunctions.h" + +#include +#include + +using namespace c10::CachingAllocator; namespace c10::openreg { -static OpenRegDeviceAllocator global_openreg_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); +constexpr size_t kAggregate = static_cast(StatType::AGGREGATE); + + +DeviceMemoryAllocator::DeviceMemoryAllocator(c10::DeviceIndex device_index) + : device_index_(device_index) {} + +void* DeviceMemoryAllocator::malloc(size_t nbytes) { + if (nbytes == 0) { + return nullptr; + } + + std::lock_guard lock(mutex_); + + void* data = nullptr; + auto ret = orMalloc(&data, nbytes); + + TORCH_CHECK( + ret == orSuccess && data != nullptr, + "Failed to allocate ", + nbytes, + " bytes on openreg device ", + device_index_, + ". ", + "Allocated: ", + stats_.allocated_bytes[0].current, + " bytes, ", + "Reserved: ", + stats_.reserved_bytes[0].current, + " bytes"); + + // Track allocation size for proper deallocation statistics + allocation_sizes_[data] = nbytes; + + // Update statistics + stats_.allocated_bytes[kAggregate].increase(nbytes); + stats_.reserved_bytes[kAggregate].increase(nbytes); + stats_.num_device_alloc++; + + return data; +} + +void DeviceMemoryAllocator::free(void* ptr) { + if (!ptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto ret = orFree(ptr); + + if (ret == orSuccess) { + auto it = allocation_sizes_.find(ptr); + if (it != allocation_sizes_.end()) { + size_t nbytes = it->second; + + stats_.allocated_bytes[kAggregate].decrease(nbytes); + stats_.reserved_bytes[kAggregate].decrease(nbytes); + stats_.num_device_free++; + + allocation_sizes_.erase(it); + } else { + TORCH_WARN( + "Successfully freed OpenReg memory pointer ", + ptr, + " on device ", + device_index_, + " that was not tracked by the allocator. 
" + "Statistics may be inaccurate."); + } + } else { + // orFree failed + auto it = allocation_sizes_.find(ptr); + if (it != allocation_sizes_.end()) { + TORCH_WARN( + "orFree failed for tracked pointer ", + ptr, + " with size ", + it->second, + " bytes on device ", + device_index_, + ". Return code: ", + ret, + ". Keeping tracking record - this may indicate a double-free or invalid pointer."); + } else { + TORCH_WARN( + "orFree failed for untracked pointer ", + ptr, + " on device ", + device_index_, + ". Return code: ", + ret, + ". This likely indicates a double-free or invalid pointer."); + } + } +} + +c10::CachingDeviceAllocator::DeviceStats DeviceMemoryAllocator::getStats() { + std::lock_guard lock(mutex_); + return stats_; +} + +void DeviceMemoryAllocator::resetAccumulatedStats() { + std::lock_guard lock(mutex_); + + // Reset accumulated statistics for all StatTypes + for (const auto stat_type : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats_.allocated_bytes[stat_type].reset_accumulated(); + stats_.reserved_bytes[stat_type].reset_accumulated(); + stats_.active_bytes[stat_type].reset_accumulated(); + stats_.inactive_split_bytes[stat_type].reset_accumulated(); + stats_.requested_bytes[stat_type].reset_accumulated(); + } + + stats_.num_alloc_retries = 0; + stats_.num_ooms = 0; + stats_.num_sync_all_streams = 0; + stats_.num_device_alloc = 0; + stats_.num_device_free = 0; +} + +void DeviceMemoryAllocator::resetPeakStats() { + std::lock_guard lock(mutex_); + + // Reset peak statistics for all StatTypes + for (const auto stat_type : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats_.allocated_bytes[stat_type].reset_peak(); + stats_.reserved_bytes[stat_type].reset_peak(); + stats_.active_bytes[stat_type].reset_peak(); + stats_.inactive_split_bytes[stat_type].reset_peak(); + stats_.requested_bytes[stat_type].reset_peak(); + } + + stats_.oversize_allocations.reset_peak(); + stats_.oversize_segments.reset_peak(); +} + +namespace { + +OpenRegDeviceAllocator g_allocator; + +void deleteOpenRegMemory(void* ptr) { + g_allocator.freeMemory(ptr); +} + +} + +OpenRegDeviceAllocator::OpenRegDeviceAllocator() { + std::lock_guard lock(mutex_); + const auto device_count = c10::openreg::device_count(); + device_allocators_.resize(device_count); + for (const auto i : c10::irange(device_count)) { + device_allocators_[i] = std::make_unique(i); + } +} + + +at::DataPtr OpenRegDeviceAllocator::allocate(size_t nbytes) { + int current_device_index = -1; + auto ret = orGetDevice(¤t_device_index); + TORCH_CHECK(ret == orSuccess, "Failed to get current OpenReg device"); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + + void* data = nullptr; + if (nbytes > 0) { + // Allocate memory via device-specific allocator + data = device_allocators_[current_device_index]->malloc(nbytes); + + // Track which device owns this pointer + std::lock_guard lock(mutex_); + allocated_blocks_[data] = current_device_index; + } + + return {data, data, &deleteOpenRegMemory, curr_device}; +} + +at::DeleterFnPtr OpenRegDeviceAllocator::raw_deleter() const { + return &deleteOpenRegMemory; +} + +void OpenRegDeviceAllocator::copy_data( + void* dest, + const void* src, + std::size_t count) const { + auto ret = orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + TORCH_CHECK( + ret == orSuccess, "Failed to copy ", count, " bytes on openreg device"); +} + +bool OpenRegDeviceAllocator::initialized() { + std::lock_guard lock(mutex_); + return !device_allocators_.empty(); +} + +void 
OpenRegDeviceAllocator::freeMemory(void* ptr) { + if (!ptr) { + return; + } + + // Try to find which device owns this pointer + c10::DeviceIndex device_index = -1; + bool found_in_map = false; + + { + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(ptr); + if (it != allocated_blocks_.end()) { + device_index = it->second; + allocated_blocks_.erase(it); + found_in_map = true; + } + } + + if (found_in_map) { + // Pointer was tracked - free via device-specific allocator with stats + device_allocators_[device_index]->free(ptr); + } else { + // Pointer not tracked - might be already freed by storage or other path + // Try to free it directly via orFree without updating statistics + auto ret = orFree(ptr); + + // Only warn if orFree actually failed (not just "not found") + // In OpenReg's case, orFree returns orErrorUnknown if pointer not in registry + // which is expected for already-freed memory + if (ret != orSuccess && ret != orErrorUnknown) { + TORCH_WARN( + "orFree failed for untracked OpenReg memory pointer ", + ptr, + ". Error code: ", ret); + } + } +} + +c10::CachingDeviceAllocator::DeviceStats OpenRegDeviceAllocator:: + getDeviceStats(c10::DeviceIndex device) { + return device_allocators_[device]->getStats(); +} + +void OpenRegDeviceAllocator::resetAccumulatedStats(c10::DeviceIndex device) { + device_allocators_[device]->resetAccumulatedStats(); +} + +void OpenRegDeviceAllocator::resetPeakStats(c10::DeviceIndex device) { + device_allocators_[device]->resetPeakStats(); +} + +void OpenRegDeviceAllocator::emptyCache(MempoolId_t mempool_id) { + // OpenReg doesn't implement caching yet + // TODO: When caching is implemented, release all free blocks here +} + +void OpenRegDeviceAllocator::recordStream( + const DataPtr& ptr, + c10::Stream stream) { + // OpenReg doesn't track stream usage yet + // TODO: When stream support is added, track which streams are using this pointer +} +// ============ Global Registration ============ + +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &g_allocator); } // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h index c9aea4a913427..777926e02b18c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h @@ -1,43 +1,78 @@ -#include +#pragma once #include +#include #include +#include #include +#include +#include +#include +#include + namespace c10::openreg { -struct OpenRegDeviceAllocator final : at::Allocator { - OpenRegDeviceAllocator() = default; - - static void ReportAndDelete(void* ptr) { - if (!ptr) { - return; - } - orFreeHost(ptr); - } - - at::DataPtr allocate(size_t nbytes) override { - int current_device_index = -1; - orGetDevice(¤t_device_index); - - auto curr_device = - c10::Device(c10::DeviceType::PrivateUse1, current_device_index); - void* data = nullptr; - if (nbytes > 0) { - orMalloc(&data, nbytes); - TORCH_CHECK( - data, "Failed to allocator ", nbytes, " bytes on openreg device."); - } - return {data, data, &ReportAndDelete, curr_device}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - orMemcpy(dest, src, count, orMemcpyDeviceToDevice); - } + +class 
DeviceMemoryAllocator { + public: + explicit DeviceMemoryAllocator(c10::DeviceIndex device_index); + + DeviceMemoryAllocator(const DeviceMemoryAllocator&) = delete; + DeviceMemoryAllocator& operator=(const DeviceMemoryAllocator&) = delete; + + void* malloc(size_t nbytes); + + void free(void* ptr); + + c10::CachingDeviceAllocator::DeviceStats getStats(); + + void resetAccumulatedStats(); + + void resetPeakStats(); + + private: + c10::DeviceIndex device_index_; + + c10::CachingDeviceAllocator::DeviceStats stats_; + + std::unordered_map allocation_sizes_; + + std::recursive_mutex mutex_; +}; + + +class OpenRegDeviceAllocator final : public c10::DeviceAllocator { + public: + OpenRegDeviceAllocator(); + + at::DataPtr allocate(size_t nbytes) override; + at::DeleterFnPtr raw_deleter() const override; + void copy_data(void* dest, const void* src, std::size_t count) const final; + + + bool initialized() override; + void emptyCache(MempoolId_t mempool_id = {0, 0}) override; + void recordStream(const DataPtr& ptr, c10::Stream stream) override; + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) override; + void resetAccumulatedStats(c10::DeviceIndex device) override; + void resetPeakStats(c10::DeviceIndex device) override; + + + void freeMemory(void* ptr); + + private: + + // Per-device allocators + std::vector> device_allocators_; + + // Global mapping from pointer to device index + std::recursive_mutex mutex_; + ska::flat_hash_map allocated_blocks_; }; -} // namespace c10::openreg + + + +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py index 3d67e16a0f503..b4a64eedc5bfc 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py @@ -1,9 +1,392 @@ # Owner(s): ["module: PrivateUse1"] +import gc +import time + import torch + +import torch_openreg # noqa: F401 from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +class TestDeviceAllocator(TestCase): + """Test cases for OpenRegDeviceAllocator functionality.""" + + def setUp(self): + """Reset memory state before each test.""" + # Force garbage collection to ensure clean state + gc.collect() + # Note: We can't directly reset allocator stats without C++ API, + # but we can ensure tensors are properly released + + def test_basic_allocation(self): + """Test basic memory allocation with various sizes.""" + # Small allocation + x = torch.empty(100, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.numel(), 100) + # Large allocation + z = torch.empty(10000, device="openreg") + self.assertEqual(z.device.type, "openreg") + self.assertEqual(z.numel(), 10000) + # Multi-dimensional allocation + w = torch.empty(10, 20, 30, device="openreg") + self.assertEqual(w.device.type, "openreg") + self.assertEqual(w.shape, torch.Size([10, 20, 30])) + + def test_memory_lifecycle(self): + """Test complete memory allocation and deallocation lifecycle.""" + # Allocate tensor + x = torch.empty(1000, device="openreg") + self.assertEqual(x.device.type, "openreg") + + # Explicitly delete tensor + del x + gc.collect() + + # Allocate again to ensure memory was freed + y = torch.empty(1000, device="openreg") + self.assertEqual(y.device.type, "openreg") + del y + gc.collect() + + def test_tensor_copy_operations(self): + """Test memory 
operations during tensor copies.""" + # CPU to OpenReg + cpu_tensor = torch.randn(100) + openreg_tensor = cpu_tensor.to("openreg") + self.assertEqual(openreg_tensor.device.type, "openreg") + self.assertEqual(cpu_tensor.shape, openreg_tensor.shape) + + # OpenReg to CPU + back_to_cpu = openreg_tensor.to("cpu") + self.assertEqual(back_to_cpu.device.type, "cpu") + self.assertTrue(torch.allclose(cpu_tensor, back_to_cpu)) + + # OpenReg to OpenReg (clone) + cloned = openreg_tensor.clone() + self.assertEqual(cloned.device.type, "openreg") + self.assertTrue(torch.allclose(openreg_tensor.cpu(), cloned.cpu())) + + def test_inplace_operations(self): + """Test memory stability during inplace operations.""" + x = torch.ones(100, device="openreg") + original_data_ptr = x.data_ptr() + + # Inplace addition + x.add_(1) + self.assertEqual(x.data_ptr(), original_data_ptr) + self.assertTrue(torch.all(x == 2)) + + # Inplace multiplication + x.mul_(2) + self.assertEqual(x.data_ptr(), original_data_ptr) + self.assertTrue(torch.all(x == 4)) + + def test_view_operations(self): + """Test that views share memory correctly.""" + x = torch.randn(100, device="openreg") + original_data_ptr = x.data_ptr() + + # Reshape view + y = x.view(10, 10) + self.assertEqual(y.data_ptr(), original_data_ptr) + self.assertEqual(y.shape, torch.Size([10, 10])) + + # Slice view + z = x[10:20] + # Slices may have different data_ptr but should share storage + self.assertEqual(z.numel(), 10) + + def test_different_dtypes(self): + """Test allocation with different data types.""" + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64] + + for dtype in dtypes: + x = torch.empty(100, dtype=dtype, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.dtype, dtype) + self.assertEqual(x.numel(), 100) + + def test_tensor_resize(self): + """Test tensor resizing operations.""" + x = torch.empty(100, device="openreg") + _ = x.data_ptr() + + # Resize to smaller size (should reuse storage) + x.resize_(50) + self.assertEqual(x.numel(), 50) + # Storage should still be available + + # Resize to original size + x.resize_(100) + self.assertEqual(x.numel(), 100) + + def test_empty_cache_operation(self): + """Test empty cache functionality.""" + # Allocate some tensors + x = torch.empty(1000, device="openreg") + y = torch.empty(2000, device="openreg") + + # Delete tensors + del x, y + gc.collect() + + # Note: OpenRegDeviceAllocator.emptyCache is currently a no-op + # This test ensures it doesn't crash + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + def test_memory_format_allocation(self): + """Test allocation with different memory formats.""" + # Channels last format + x = torch.empty(2, 3, 4, 4, device="openreg", memory_format=torch.channels_last) + self.assertEqual(x.device.type, "openreg") + self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) + + # Contiguous format (default) + y = torch.empty( + 2, 3, 4, 4, device="openreg", memory_format=torch.contiguous_format + ) + self.assertEqual(y.device.type, "openreg") + self.assertTrue(y.is_contiguous()) + + def test_large_allocation(self): + """Test large memory allocation.""" + # Allocate a large tensor (10MB approximately) + size = 10 * 1024 * 1024 // 4 # 10MB in float32 + x = torch.empty(size, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.numel(), size) + + def test_sequential_allocations_and_deallocations(self): + """Test sequential allocation and deallocation patterns.""" + for i in range(10): 
+ x = torch.empty(1000 + i * 100, device="openreg") + self.assertEqual(x.device.type, "openreg") + # Let tensor go out of scope + del x + gc.collect() + + def test_allocation_with_requires_grad(self): + """Test allocation of tensors with gradient tracking.""" + x = torch.empty(100, device="openreg", requires_grad=True) + self.assertEqual(x.device.type, "openreg") + self.assertTrue(x.requires_grad) + + y = torch.randn(100, device="openreg", requires_grad=True) + self.assertEqual(y.device.type, "openreg") + self.assertTrue(y.requires_grad) + + def test_storage_operations(self): + """Test storage-level operations.""" + x = torch.randn(100, device="openreg") + storage = x.storage() + + # Verify storage is on correct device + self.assertTrue(storage.device.type == "openreg") + + # Verify storage size + self.assertGreaterEqual(storage.size(), x.numel()) + + def test_tensor_from_blob(self): + """Test creating tensors that reference existing memory.""" + x = torch.randn(100, device="openreg") + + # Create a view that references the same data + y = x.view_as(x) + + # They should share the same underlying storage + self.assertEqual(x.data_ptr(), y.data_ptr()) + + # Modifying one should affect the other + x.fill_(5.0) + self.assertTrue(torch.all(y == 5.0)) + + +class TestMemoryLeaks(TestCase): + """Test cases for detecting memory leaks in OpenRegDeviceAllocator.""" + + def setUp(self): + """Reset memory state before each test.""" + gc.collect() + time.sleep(0.1) # Allow time for cleanup + + def test_no_leak_simple_allocations(self): + """Test that simple allocations don't leak memory.""" + # Warm-up + for _ in range(10): + x = torch.empty(1000, device="openreg") + del x + gc.collect() + time.sleep(0.1) + + # Perform many allocations and deallocations + iterations = 1000 + for i in range(iterations): + x = torch.empty(1000, device="openreg") + del x + + if i % 100 == 0: + gc.collect() + + # Final cleanup + gc.collect() + time.sleep(0.1) + + # If there were leaks, this would have accumulated significant memory + # The test passes if no exception/crash occurred + + def test_no_leak_varying_sizes(self): + """Test that allocations of varying sizes don't leak.""" + iterations = 500 + sizes = [100, 500, 1000, 5000, 10000] + + for i in range(iterations): + size = sizes[i % len(sizes)] + x = torch.empty(size, device="openreg") + del x + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_copies(self): + """Test that tensor copies don't leak memory.""" + iterations = 300 + + for i in range(iterations): + # Create tensor + x = torch.randn(500, device="openreg") + + # Copy to CPU + cpu_copy = x.cpu() + + # Copy back to device + device_copy = cpu_copy.to("openreg") + + # Clone + cloned = device_copy.clone() + + # Delete all + del x, cpu_copy, device_copy, cloned + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_views(self): + """Test that tensor views don't leak memory.""" + iterations = 500 + + for i in range(iterations): + x = torch.randn(1000, device="openreg") + + # Create various views + view1 = x.view(10, 100) + view2 = x[100:200] + view3 = x.reshape(20, 50) + + # Delete views and original + del view1, view2, view3, x + + if i % 100 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_inplace_operations(self): + """Test that inplace operations don't leak memory.""" + iterations = 500 + + for i in range(iterations): + x = torch.ones(1000, device="openreg") + + # Multiple inplace operations + 
x.add_(1) + x.mul_(2) + x.div_(2) + x.sub_(1) + + del x + + if i % 100 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_gradients(self): + """Test that tensors with gradients don't leak.""" + iterations = 300 + + for i in range(iterations): + x = torch.randn(100, device="openreg", requires_grad=True) + y = torch.randn(100, device="openreg", requires_grad=True) + + # Operation that creates computation graph + z = x + y + + # Delete all + del x, y, z + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_repeated_large_allocations(self): + """Test repeated large allocations for memory leaks.""" + # Large tensor size (50MB) + size = 50 * 1024 * 1024 // 4 + iterations = 50 + + for i in range(iterations): + x = torch.empty(size, device="openreg") + del x + gc.collect() + time.sleep(0.05) # Allow time for cleanup + + # Final cleanup + gc.collect() + time.sleep(0.1) + + def test_leak_detection_with_statistics(self): + """Test memory leak detection using allocation patterns.""" + # This test verifies that after many alloc/dealloc cycles, + # the allocator properly frees memory + + num_cycles = 10 + allocations_per_cycle = 100 + + for cycle in range(num_cycles): + tensors = [] + + # Allocate many tensors + for i in range(allocations_per_cycle): + t = torch.empty(1000, device="openreg") + tensors.append(t) + + # Verify all allocated + self.assertEqual(len(tensors), allocations_per_cycle) + + # Delete all + tensors.clear() + gc.collect() + time.sleep(0.05) + + # Final verification - if there were leaks, memory would be exhausted + # The test passes if we can still allocate + final_tensor = torch.empty(10000, device="openreg") + self.assertEqual(final_tensor.device.type, "openreg") + del final_tensor + + class TestPinMemory(TestCase): @skipIfTorchDynamo("unsupported aten.is_pinned.default") def test_pin_memory(self): @@ -27,5 +410,110 @@ def test_pin_memory(self): self.assertTrue(pinned_untyped_storage.is_pinned("openreg")) +class TestMultiDeviceAllocation(TestCase): + """Test basic multi-device allocation functionality.""" + + def setUp(self): + self.device_count = torch.openreg.device_count() + self.assertEqual(self.device_count, 2, "This test requires 2 OpenReg devices") + gc.collect() + + def tearDown(self): + """Restore device 0 to avoid affecting subsequent tests.""" + torch.openreg.set_device(0) + gc.collect() + + def test_allocation_on_device_1(self): + torch.openreg.set_device(1) + x = torch.empty(100, device="openreg:1") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.device.index, 1) + + def test_simultaneous_device_allocations(self): + """Test allocations on both devices simultaneously.""" + x = torch.empty(100, device="openreg:0") + y = torch.empty(200, device="openreg:1") + + self.assertEqual(x.device.index, 0) + self.assertEqual(y.device.index, 1) + self.assertNotEqual(x.data_ptr(), y.data_ptr()) + + def test_memory_isolation_between_devices(self): + """Test that memory allocations are isolated between devices.""" + + tensors_dev0 = [torch.empty(1000, device="openreg:0") for _ in range(10)] + tensors_dev1 = [torch.empty(1000, device="openreg:1") for _ in range(10)] + + # Verify all device 0 tensors are on device 0 + for t in tensors_dev0: + self.assertEqual(t.device.index, 0) + + # Verify all device 1 tensors are on device 1 + for t in tensors_dev1: + self.assertEqual(t.device.index, 1) + + # Pointers should be different + ptrs_dev0 = {t.data_ptr() for t in tensors_dev0} + ptrs_dev1 = {t.data_ptr() 
for t in tensors_dev1} + self.assertEqual( + len(ptrs_dev0 & ptrs_dev1), 0, "Devices should not share pointers" + ) + + def test_alternating_device_allocations(self): + """Test alternating allocations between devices.""" + tensors = [] + for i in range(20): + device_idx = i % 2 + t = torch.empty(100 + i, device=f"openreg:{device_idx}") + self.assertEqual(t.device.index, device_idx) + tensors.append(t) + + # Verify all tensors retained correct device assignment + for i, t in enumerate(tensors): + expected_device = i % 2 + self.assertEqual(t.device.index, expected_device) + + +class TestCrossDeviceOperations(TestCase): + """Test cross-device tensor operations.""" + + def setUp(self): + self.device_count = torch.openreg.device_count() + self.assertEqual(self.device_count, 2) + gc.collect() + + def tearDown(self): + """Restore device 0 to avoid affecting subsequent tests.""" + torch.openreg.set_device(0) + gc.collect() + + def test_tensor_to_different_device(self): + """Test moving tensor from one device to another.""" + # Create on device 0 + x = torch.randn(100, device="openreg:0") + self.assertEqual(x.device.index, 0) + + # Move to device 1 + y = x.to("openreg:1") + self.assertEqual(y.device.index, 1) + self.assertNotEqual(x.data_ptr(), y.data_ptr()) + + # Values should be the same + self.assertTrue(torch.allclose(x.cpu(), y.cpu())) + + def test_bidirectional_device_transfer(self): + """Test transferring tensor back and forth between devices.""" + original = torch.randn(100, device="openreg:0") + original_cpu = original.cpu() + + # 0 -> 1 + on_dev1 = original.to("openreg:1") + self.assertTrue(torch.allclose(original_cpu, on_dev1.cpu())) + + # 1 -> 0 + back_to_dev0 = on_dev1.to("openreg:0") + self.assertTrue(torch.allclose(original_cpu, back_to_dev0.cpu())) + + if __name__ == "__main__": run_tests() diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp index b7501c81d7b7c..f8fc5946fd6e8 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/device_tests.cpp @@ -16,12 +16,22 @@ TEST_F(DeviceTest, GetDeviceCountValid) { EXPECT_EQ(count, 2); } +TEST_F(DeviceTest, GetDeviceCountNullptr) { + // orGetDeviceCount should reject null output pointers. + EXPECT_EQ(orGetDeviceCount(nullptr), orErrorUnknown); +} + TEST_F(DeviceTest, GetDeviceValid) { int device = -1; EXPECT_EQ(orGetDevice(&device), orSuccess); EXPECT_EQ(device, 0); } +TEST_F(DeviceTest, GetDeviceNullptr) { + // Defensive path: null output pointer must return an error. + EXPECT_EQ(orGetDevice(nullptr), orErrorUnknown); +} + TEST_F(DeviceTest, SetDeviceValid) { EXPECT_EQ(orSetDevice(1), orSuccess); @@ -38,4 +48,9 @@ TEST_F(DeviceTest, SetDeviceInvalidNegative) { EXPECT_EQ(orSetDevice(-1), orErrorUnknown); } +TEST_F(DeviceTest, SetDeviceInvalidTooLarge) { + // Device indices are 0-based and strictly less than DEVICE_COUNT (2). 
+ EXPECT_EQ(orSetDevice(2), orErrorUnknown); +} + } // namespace diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp index 416c50a863435..f45bd2690d41e 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/event_tests.cpp @@ -29,6 +29,13 @@ TEST_F(EventTest, EventCreateWithFlagsTiming) { EXPECT_EQ(orEventDestroy(event), orSuccess); } +TEST_F(EventTest, EventCreationNullptr) { + // Creation APIs must fail fast on null handles to mirror CUDA semantics. + EXPECT_EQ(orEventCreate(nullptr), orErrorUnknown); + EXPECT_EQ( + orEventCreateWithFlags(nullptr, orEventEnableTiming), orErrorUnknown); +} + TEST_F(EventTest, EventRecordAndSynchronize) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -44,6 +51,23 @@ TEST_F(EventTest, EventRecordAndSynchronize) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(EventTest, EventRecordInvalidArgs) { + orEvent_t event = nullptr; + EXPECT_EQ(orEventCreate(&event), orSuccess); + + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + // Record/sync/destroy should validate both stream and event pointers. + EXPECT_EQ(orEventRecord(nullptr, stream), orErrorUnknown); + EXPECT_EQ(orEventRecord(event, nullptr), orErrorUnknown); + EXPECT_EQ(orEventSynchronize(nullptr), orErrorUnknown); + EXPECT_EQ(orEventDestroy(nullptr), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(event), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(EventTest, EventElapsedTime) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -70,6 +94,60 @@ TEST_F(EventTest, EventElapsedTime) { EXPECT_EQ(orEventDestroy(end), orSuccess); } +// TODO: recording events to a stream is not allowed +// if the stream and the event are not on the same device +// Uncomment this test case after the issue is fixed. +// see #167819 +TEST_F(EventTest, DISABLED_EventElapsedTimeDifferentDevicesFails) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + orEvent_t start = nullptr; + orEvent_t end = nullptr; + EXPECT_EQ(orEventCreateWithFlags(&start, orEventEnableTiming), orSuccess); + + EXPECT_EQ(orEventRecord(start, stream), orSuccess); + + // Switch device before creating the end event to force a mismatch. 
+ EXPECT_EQ(orSetDevice(1), orSuccess); + EXPECT_EQ(orEventCreateWithFlags(&end, orEventEnableTiming), orSuccess); + EXPECT_EQ(orSetDevice(0), orSuccess); + + EXPECT_EQ(orEventRecord(end, stream), orSuccess); + EXPECT_EQ(orEventSynchronize(start), orSuccess); + EXPECT_EQ(orEventSynchronize(end), orSuccess); + + float elapsed_ms = 0.0f; + EXPECT_EQ(orEventElapsedTime(&elapsed_ms, start, end), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(start), orSuccess); + EXPECT_EQ(orEventDestroy(end), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + +TEST_F(EventTest, EventElapsedTimeRequiresTimingFlag) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + orEvent_t start = nullptr; + orEvent_t end = nullptr; + EXPECT_EQ(orEventCreate(&start), orSuccess); + EXPECT_EQ(orEventCreate(&end), orSuccess); + + EXPECT_EQ(orEventRecord(start, stream), orSuccess); + EXPECT_EQ(orEventRecord(end, stream), orSuccess); + EXPECT_EQ(orEventSynchronize(start), orSuccess); + EXPECT_EQ(orEventSynchronize(end), orSuccess); + + // Without timing-enabled events, querying elapsed time must fail. + float elapsed_ms = 0.0f; + EXPECT_EQ(orEventElapsedTime(&elapsed_ms, start, end), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(start), orSuccess); + EXPECT_EQ(orEventDestroy(end), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(EventTest, StreamWaitEvent) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -85,4 +163,19 @@ TEST_F(EventTest, StreamWaitEvent) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(EventTest, StreamWaitEventInvalidArgs) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + orEvent_t event = nullptr; + EXPECT_EQ(orEventCreate(&event), orSuccess); + + // Validate both stream and event inputs for wait calls. + EXPECT_EQ(orStreamWaitEvent(nullptr, event, 0), orErrorUnknown); + EXPECT_EQ(orStreamWaitEvent(stream, nullptr, 0), orErrorUnknown); + + EXPECT_EQ(orEventDestroy(event), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + } // namespace diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp index e36ad4c0da3ee..3a5ccb54ad85a 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/memory_tests.cpp @@ -26,6 +26,12 @@ TEST_F(MemoryTest, AllocateAndFreeHost) { EXPECT_EQ(orFreeHost(ptr), orSuccess); } +TEST_F(MemoryTest, FreeNullptrIsNoop) { + // Freeing a nullptr should behave like CUDA: treated as a no-op success. + EXPECT_EQ(orFree(nullptr), orSuccess); + EXPECT_EQ(orFreeHost(nullptr), orSuccess); +} + TEST_F(MemoryTest, AllocateNullptr) { EXPECT_EQ(orMalloc(nullptr, 4096), orErrorUnknown); EXPECT_EQ(orMallocHost(nullptr, 4096), orErrorUnknown); @@ -86,6 +92,48 @@ TEST_F(MemoryTest, MemcpyInvalidKind) { EXPECT_EQ(orFree(dev_ptr), orSuccess); } +TEST_F(MemoryTest, MemcpyInvalidCombinations) { + void *dev_src = nullptr, *dev_dst = nullptr; + EXPECT_EQ(orMalloc(&dev_src, 8), orSuccess); + EXPECT_EQ(orMalloc(&dev_dst, 8), orSuccess); + + char host_buf[8] = {}; + + // Deliberately pass mismatched kinds to ensure validation coverage. 
+ EXPECT_EQ( + orMemcpy(host_buf, dev_src, 4, orMemcpyHostToDevice), orErrorUnknown); + EXPECT_EQ( + orMemcpy(dev_dst, host_buf, 4, orMemcpyDeviceToHost), orErrorUnknown); + EXPECT_EQ( + orMemcpy(dev_dst, dev_src, 4, orMemcpyHostToDevice), orErrorUnknown); + + EXPECT_EQ(orFree(dev_src), orSuccess); + EXPECT_EQ(orFree(dev_dst), orSuccess); +} + +TEST_F(MemoryTest, MemcpyAsyncHostToDevice) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + const char host_src[] = "async"; + char host_dst[6] = {}; + void* dev_ptr = nullptr; + EXPECT_EQ(orMalloc(&dev_ptr, sizeof(host_src)), orSuccess); + + // Async copies should complete once the stream is synchronized. + EXPECT_EQ( + orMemcpyAsync(dev_ptr, host_src, sizeof(host_src), orMemcpyHostToDevice, stream), + orSuccess); + EXPECT_EQ(orStreamSynchronize(stream), orSuccess); + EXPECT_EQ(orMemcpy( + host_dst, dev_ptr, sizeof(host_src), orMemcpyDeviceToHost), + orSuccess); + EXPECT_STREQ(host_dst, host_src); + + EXPECT_EQ(orFree(dev_ptr), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(MemoryTest, PointerAttributes) { void* dev_ptr = nullptr; EXPECT_EQ(orMalloc(&dev_ptr, 32), orSuccess); @@ -102,6 +150,14 @@ TEST_F(MemoryTest, PointerAttributes) { EXPECT_EQ(orFree(dev_ptr), orSuccess); } +TEST_F(MemoryTest, PointerAttributesInvalidArgs) { + // Attribute queries must fail on null inputs to avoid dereferencing. + char buffer[8] = {}; + orPointerAttributes attr{}; + EXPECT_EQ(orPointerGetAttributes(nullptr, buffer), orErrorUnknown); + EXPECT_EQ(orPointerGetAttributes(&attr, nullptr), orErrorUnknown); +} + TEST_F(MemoryTest, ProtectUnprotectDevice) { void* dev_ptr = nullptr; EXPECT_EQ(orMalloc(&dev_ptr, 64), orSuccess); @@ -112,4 +168,24 @@ TEST_F(MemoryTest, ProtectUnprotectDevice) { EXPECT_EQ(orFree(dev_ptr), orSuccess); } +TEST_F(MemoryTest, ProtectReferenceCounting) { + void* dev_ptr = nullptr; + EXPECT_EQ(orMalloc(&dev_ptr, 64), orSuccess); + + // Call unprotect/protect twice to exercise the refcount transitions. + EXPECT_EQ(orMemoryUnprotect(dev_ptr), orSuccess); + EXPECT_EQ(orMemoryUnprotect(dev_ptr), orSuccess); + EXPECT_EQ(orMemoryProtect(dev_ptr), orSuccess); + EXPECT_EQ(orMemoryProtect(dev_ptr), orSuccess); + + EXPECT_EQ(orFree(dev_ptr), orSuccess); +} + +TEST_F(MemoryTest, DoubleFreeFails) { + void* dev_ptr = nullptr; + EXPECT_EQ(orMalloc(&dev_ptr, 32), orSuccess); + EXPECT_EQ(orFree(dev_ptr), orSuccess); + EXPECT_EQ(orFree(dev_ptr), orErrorUnknown); +} + } // namespace diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp index e91abaa1e7fe9..fbf5cb900a811 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/tests/stream_tests.cpp @@ -21,6 +21,11 @@ TEST_F(StreamTest, StreamCreateAndDestroy) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(StreamTest, StreamCreateNullptr) { + // Creation API should reject null double-pointer inputs. 
+ EXPECT_EQ(orStreamCreate(nullptr), orErrorUnknown); +} + TEST_F(StreamTest, StreamCreateWithInvalidPriority) { orStream_t stream = nullptr; int min_p, max_p; @@ -30,6 +35,36 @@ TEST_F(StreamTest, StreamCreateWithInvalidPriority) { EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, max_p + 1), orErrorUnknown); } +TEST_F(StreamTest, StreamCreateWithPriorityValidBounds) { + orStream_t stream = nullptr; + int min_p, max_p; + orDeviceGetStreamPriorityRange(&min_p, &max_p); + + // Lowest priority should be accepted. + EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, min_p), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); + + // Highest priority should also be accepted. + EXPECT_EQ(orStreamCreateWithPriority(&stream, 0, max_p), orSuccess); + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + +TEST_F(StreamTest, StreamDestroyNullptr) { + // Destroying nullptr should follow CUDA error behavior. + EXPECT_EQ(orStreamDestroy(nullptr), orErrorUnknown); +} + +TEST_F(StreamTest, StreamGetPriority) { + orStream_t stream = nullptr; + EXPECT_EQ(orStreamCreate(&stream), orSuccess); + + int priority = -1; + EXPECT_EQ(orStreamGetPriority(stream, &priority), orSuccess); + EXPECT_EQ(priority, 0); + + EXPECT_EQ(orStreamDestroy(stream), orSuccess); +} + TEST_F(StreamTest, StreamTaskExecution) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -43,6 +78,11 @@ TEST_F(StreamTest, StreamTaskExecution) { EXPECT_EQ(orStreamDestroy(stream), orSuccess); } +TEST_F(StreamTest, AddTaskToStreamNullptr) { + // Queueing work should fail fast if the stream handle is invalid. + EXPECT_EQ(openreg::addTaskToStream(nullptr, [] {}), orErrorUnknown); +} + TEST_F(StreamTest, StreamQuery) { orStream_t stream = nullptr; EXPECT_EQ(orStreamCreate(&stream), orSuccess); @@ -76,4 +116,18 @@ TEST_F(StreamTest, DeviceSynchronize) { EXPECT_EQ(orStreamDestroy(stream2), orSuccess); } +TEST_F(StreamTest, DeviceSynchronizeWithNoStreams) { + // Even without registered streams, device sync should succeed. + EXPECT_EQ(orDeviceSynchronize(), orSuccess); +} + +TEST_F(StreamTest, StreamPriorityRange) { + int min_p = -1; + int max_p = -1; + // OpenReg currently exposes only one priority level; verify the fixed range. 
+ EXPECT_EQ(orDeviceGetStreamPriorityRange(&min_p, &max_p), orSuccess); + EXPECT_EQ(min_p, 0); + EXPECT_EQ(max_p, 0); +} + } // namespace diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 48ede590cecbf..dfb9b6b37f593 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( install_cpp_extension, IS_WINDOWS, + parametrize, run_tests, skipIfTorchDynamo, TestCase, @@ -618,7 +619,11 @@ def test_get_num_threads(self, device): self.assertEqual(num_threads, expected_num_threads) @skipIfTorchVersionLessThan(2, 10) - def test_my_empty(self, device): + @parametrize("layout", [None, torch.strided, torch.sparse_coo]) + @parametrize( + "memory_format", [None, torch.channels_last, torch.contiguous_format] + ) + def test_my_empty(self, device, layout, memory_format): import libtorch_agnostic_2_10 as libtorch_agnostic deterministic = torch.are_deterministic_algorithms_enabled() @@ -626,35 +631,80 @@ def test_my_empty(self, device): # set use_deterministic_algorithms to fill uninitialized memory torch.use_deterministic_algorithms(True) - size = [2, 3] - result = libtorch_agnostic.ops.my_empty(size, None, None, None) - expected = torch.empty(size) - self.assertEqual(result, expected, exact_device=True) + # Use 4D size for channels_last, 2D otherwise + size = [2, 3, 4, 5] if memory_format == torch.channels_last else [2, 3] + + # sparse_coo layout doesn't support memory_format parameter + if layout == torch.sparse_coo and memory_format is not None: + return + + # Test default parameters + result = libtorch_agnostic.ops.my_empty( + size, None, layout, None, None, memory_format + ) + expected = torch.empty(size, layout=layout, memory_format=memory_format) + self.assertEqual(result, expected, exact_device=True, exact_layout=True) + # Test with dtype result_float = libtorch_agnostic.ops.my_empty( - size, torch.float32, None, None + size, torch.float32, layout, None, None, memory_format + ) + expected_float = torch.empty( + size, + dtype=torch.float32, + layout=layout, + memory_format=memory_format, + ) + self.assertEqual( + result_float, expected_float, exact_device=True, exact_layout=True ) - expected_float = torch.empty(size, dtype=torch.float32) - self.assertEqual(result_float, expected_float, exact_device=True) + # Test with dtype and device result_with_device = libtorch_agnostic.ops.my_empty( - size, torch.float64, device, None + size, torch.float64, layout, device, None, memory_format ) expected_with_device = torch.empty( - size, dtype=torch.float64, device=device + size, + dtype=torch.float64, + layout=layout, + device=device, + memory_format=memory_format, ) self.assertEqual( - result_with_device, expected_with_device, exact_device=True + result_with_device, + expected_with_device, + exact_device=True, + exact_layout=True, ) - if device == "cuda": + # Verify layout if specified + if layout is not None: + self.assertEqual(result_with_device.layout, layout) + + # Verify memory format if specified + if memory_format == torch.channels_last: + self.assertTrue( + result_with_device.is_contiguous( + memory_format=torch.channels_last + ) + ) + elif memory_format == torch.contiguous_format: + self.assertTrue(result_with_device.is_contiguous()) + + # Test pin_memory on CUDA (only once, not for every parameter combination) + if device == "cuda" and layout is None and memory_format is None: result_pinned = 
libtorch_agnostic.ops.my_empty( - size, torch.float32, "cpu", True + [2, 3], torch.float32, None, "cpu", True, None ) expected_pinned = torch.empty( - size, dtype=torch.float32, device="cpu", pin_memory=True + [2, 3], dtype=torch.float32, device="cpu", pin_memory=True + ) + self.assertEqual( + result_pinned, + expected_pinned, + exact_device=True, + exact_layout=True, ) - self.assertEqual(result_pinned, expected_pinned) self.assertTrue(result_pinned.is_pinned()) finally: torch.use_deterministic_algorithms(deterministic) @@ -711,6 +761,15 @@ def test_my_view(self, device): expected_flat = t.view([-1]) self.assertEqual(result_flat, expected_flat) + @skipIfTorchVersionLessThan(2, 10) + def test_my_shape(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + expected = (3, 5) + t = torch.rand(*expected, device=device) + shape = libtorch_agnostic.ops.my_shape(t) + self.assertEqual(shape, expected) + def test_mv_tensor_accessor(self, device): import libtorch_agnostic_2_9 as libtorch_agnostic @@ -766,6 +825,40 @@ def test_get_template_any_data_ptr(self, device): t, rdtype, mutable ) + @skipIfTorchVersionLessThan(2, 10) + @onlyCUDA + def test_my_get_curr_cuda_blas_handle(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + res = libtorch_agnostic.ops.my_get_curr_cuda_blas_handle() + expected = torch.cuda.current_blas_handle() + self.assertEqual(res, expected) + + @skipIfTorchVersionLessThan(2, 10) + def test_my_string_op(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + t = torch.empty(3, 4, 5, device=device) + + dim_vec, result_dim = libtorch_agnostic.ops.my_string_op(t, "dim", "ice") + self.assertEqual(dim_vec, ["dim", str(t.dim()), "ice"]) + self.assertEqual(result_dim, t.dim()) + + size_vec, result_size = libtorch_agnostic.ops.my_string_op( + t, "size", "cream" + ) + self.assertEqual(size_vec, ["size", str(t.size(0)), "cream"]) + self.assertEqual(result_size, t.size(0)) + + stride_vec, result_stride = libtorch_agnostic.ops.my_string_op( + t, "stride", "cake" + ) + self.assertEqual(stride_vec, ["stride", str(t.stride(0)), "cake"]) + self.assertEqual(result_stride, t.stride(0)) + + with self.assertRaisesRegex(RuntimeError, "Unsupported accessor value: "): + libtorch_agnostic.ops.my_string_op(t, "invalid", "") + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index ad3064608960d..076c4de69f44f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -428,7 +428,14 @@ def test_manual_reshard_with_reshard_after_forward_false(self): @xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1571 def test_set_reduce_scatter_divide_factor(self): self.run_subtests( - {"divide_factor": [self.world_size * 2, self.world_size]}, + { + "divide_factor": [self.world_size * 2, self.world_size], + "mesh_shape": [ + (self.world_size,), + (self.world_size // 2, 2), + (self.world_size, 1), + ], + }, self._test_set_reduce_scatter_divide_factor, ) self.run_subtests( @@ -436,18 +443,31 @@ def test_set_reduce_scatter_divide_factor(self): self._test_set_reduce_scatter_divide_factor_mixed_prevision, ) - def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): + def _test_set_reduce_scatter_divide_factor( + self, divide_factor: float, mesh_shape: tuple[int] | tuple[int, 
int] + ): torch.manual_seed(42) model_args = ModelArgs(dropout_p=0.0, weight_tying=False) model = Transformer(model_args) ref_model = copy.deepcopy(model).to(device_type) ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + mesh_dim_names = ("outer",) if len(mesh_shape) == 1 else ("outer", "inner") + mesh = init_device_mesh( + device_type.type, mesh_shape, mesh_dim_names=mesh_dim_names + ) for module in model.modules(): if isinstance(module, TransformerBlock): - fully_shard(module, reshard_after_forward=False) - model = fully_shard(model, reshard_after_forward=False) + fully_shard(module, reshard_after_forward=False, mesh=mesh) + model = fully_shard(model, reshard_after_forward=False, mesh=mesh) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) - model.set_reduce_scatter_divide_factor(divide_factor) + model.set_gradient_divide_factor(divide_factor) + + # Get ref_model params which should have the specific division factor applied + block_params = set() + for ref_mod in ref_model.modules(): + if isinstance(ref_mod, TransformerBlock): + block_params.update(ref_mod.parameters()) + non_block_params = set(ref_model.parameters()) - block_params torch.manual_seed(42 + self.rank) inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type) @@ -456,16 +476,18 @@ def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): ref_loss = ref_model(inp).sum() ref_loss.backward() for param in ref_model.parameters(): - param.grad.mul_(1.0 / divide_factor) + factor = divide_factor if param in non_block_params else self.world_size + param.grad.mul_(1.0 / factor) dist.all_reduce(param.grad) loss = model(inp).sum() loss.backward() ref_optim.step() optim.step() - ref_optim.zero_grad() - optim.zero_grad() self.assertEqual(ref_loss, loss) + # Check parity before calling zero_grad so that grads are also checked check_sharded_parity(self, ref_model, model) + ref_optim.zero_grad() + optim.zero_grad() def _test_set_reduce_scatter_divide_factor_mixed_prevision( self, divide_factor: float @@ -484,7 +506,7 @@ def _test_set_reduce_scatter_divide_factor_mixed_prevision( fully_shard(mlp, mp_policy=mp_policy) model = fully_shard(model, mp_policy=mp_policy) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) - model.set_reduce_scatter_divide_factor(divide_factor) + model.set_gradient_divide_factor(divide_factor) torch.manual_seed(42 + self.rank) inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype) diff --git a/test/distributed/_composable/test_replicate_training.py b/test/distributed/_composable/test_replicate_training.py index 076a5e3760ff5..3dc908a8b1afe 100644 --- a/test/distributed/_composable/test_replicate_training.py +++ b/test/distributed/_composable/test_replicate_training.py @@ -678,7 +678,9 @@ def _test_train_parity_with_activation_checkpointing( test_device_type: str, ): assert checkpoint_impl in ("composable", "utils", "wrapper") - testing_compile = replicate != torch.distributed._composable.replicate_with_fsdp + testing_compile = ( + replicate is not torch.distributed._composable.replicate_with_fsdp + ) if testing_compile and checkpoint_impl == "composable": return torch.manual_seed(42) diff --git a/test/distributed/_shard/sharding_spec/test_sharding_spec.py b/test/distributed/_shard/sharding_spec/test_sharding_spec.py index 73018c1025619..37ad69075068f 100644 --- a/test/distributed/_shard/sharding_spec/test_sharding_spec.py +++ b/test/distributed/_shard/sharding_spec/test_sharding_spec.py @@ -490,6 +490,69 @@ def test_check_overlapping(self): 
with self.assertRaisesRegex(ValueError, "overlap"): validate_non_overlapping_shards_metadata(shards) + shards = [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[0, 5], + shard_sizes=[5, 5], + placement="cuda:1", + ), + ShardMetadata( + shard_offsets=[5, 0], + shard_sizes=[5, 5], + placement="cuda:2", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_sizes=[5, 5], + placement="cuda:3", + ), + ] + validate_non_overlapping_shards_metadata(shards) + + shards = [ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 5], + shard_sizes=[5, 5], + placement="cuda:1", + ), + ] + validate_non_overlapping_shards_metadata(shards) + + shards = [ + ShardMetadata( + shard_offsets=[0, 0, 0], + shard_sizes=[5, 5, 5], + placement="cuda:0", + ), + ShardMetadata( + shard_offsets=[5, 0, 0], + shard_sizes=[5, 5, 5], + placement="cuda:1", + ), + ShardMetadata( + shard_offsets=[10, 0, 0], + shard_sizes=[5, 5, 5], + placement="cuda:2", + ), + ShardMetadata( + shard_offsets=[10, 3, 0], + shard_sizes=[5, 5, 5], + placement="cuda:3", + ), + ] + with self.assertRaisesRegex(ValueError, "overlap"): + validate_non_overlapping_shards_metadata(shards) + # Custom ShardingSpec, an simple example to do grid sharding @dataclass diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index 9b31cf3b1755b..d6168a71f752b 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -6,6 +6,7 @@ import pickle import socket import tempfile +import unittest from contextlib import contextmanager from urllib3.connection import HTTPConnection @@ -15,7 +16,13 @@ TORCH_WORKER_SERVER_SOCKET, worker_main, ) -from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase +from torch.monitor import _WaitCounter +from torch.testing._internal.common_utils import ( + IS_FBCODE, + requires_cuda, + run_tests, + TestCase, +) class UnixHTTPConnection(HTTPConnection): @@ -216,6 +223,68 @@ def test_get_handler_names(self) -> None: names = _get_handler_names() self.assertIn("ping", names) + @unittest.skipIf(IS_FBCODE, "disabled in FBCODE") + def test_wait_counter_values(self) -> None: + """ + Test that WaitCounter values are properly tracked and returned by the handler. + + Note: This test may trigger an ASAN heap-use-after-free error during process + shutdown due to static destruction order issues with boost regex in the logging + framework. The test assertions pass successfully before this shutdown error occurs. + """ + with local_worker_server() as pool: + # Create and use a WaitCounter with a specific name + counter_name = "test_counter" + counter = _WaitCounter(counter_name) + + # Use the counter multiple times to generate metrics + # Note: Using minimal/no sleep to avoid timing issues + for i in range(3): + with counter.guard(): + pass # Minimal work + + # Query the wait counter values + resp = pool.request("POST", "/handler/wait_counter_values") + self.assertEqual(resp.status, 200) + + # Parse the JSON response + data = json.loads(resp.data) + # Should be a dictionary + self.assertIsInstance(data, dict) + + # Verify our test counter appears in the response + self.assertIn( + counter_name, + data, + f"Counter '{counter_name}' not found in response. 
Available counters: {list(data.keys())}", + ) + + # Verify the counter has expected metrics + counter_data = data[counter_name] + self.assertIn("active_count", counter_data) + self.assertIn("total_calls", counter_data) + self.assertIn("total_time_us", counter_data) + self.assertIn("max_time_us", counter_data) + + # Verify the counter was called 3 times + self.assertEqual( + counter_data["total_calls"], + 3, + f"Expected 3 calls, got {counter_data['total_calls']}", + ) + + # Verify active_count is 0 (no active waiters) + self.assertEqual( + counter_data["active_count"], + 0, + f"Expected 0 active, got {counter_data['active_count']}", + ) + + # total_time_us and max_time_us may be 0 or very small for fast operations + # Just verify they exist and are non-negative + self.assertGreaterEqual(counter_data["total_time_us"], 0) + self.assertGreaterEqual(counter_data["max_time_us"], 0) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py index 0acb530f441fc..8cc5698cd19aa 100644 --- a/test/distributed/fsdp/test_checkpoint_wrapper.py +++ b/test/distributed/fsdp/test_checkpoint_wrapper.py @@ -73,7 +73,7 @@ def forward(self, a, b, c=None, d=None, **kwargs): ]: with self.subTest(wrapper=wrapper): model = wrapper(MyModel()) - if wrapper == offload_wrapper: + if wrapper is offload_wrapper: self.assertTrue(isinstance(model, OffloadWrapper)) else: self.assertTrue(isinstance(model, CheckpointWrapper)) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 00479cf0935b9..67f8e1af9abbd 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -30,11 +30,13 @@ ) sys.exit(0) - -_DISTRIBUTED_STATE_DICT_IMPLS = { +# NB: this iterable needs to be ordered as otherwise different ranks may run with +# conflicting settings when e.g., @parametrize(_DISTRIBUTED_STATE_DICT_IMPLS) is +# used to decorate tests +_DISTRIBUTED_STATE_DICT_IMPLS = ( StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT, -} +) class TestDistributedCheckpoint(FSDPTest): diff --git a/test/distributed/tensor/debug/test_comm_mode.py b/test/distributed/tensor/debug/test_comm_mode.py index c87164750c684..d122a9f716fcd 100644 --- a/test/distributed/tensor/debug/test_comm_mode.py +++ b/test/distributed/tensor/debug/test_comm_mode.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.distributed.tensor import DeviceMesh, DTensor, Shard from torch.distributed.tensor.debug import CommDebugMode -from torch.testing._internal.common_distributed import requires_nccl +from torch.testing._internal.common_distributed import requires_accelerator_dist_backend from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed.fake_pg import FakeStore @@ -14,6 +14,9 @@ c10d_functional = torch.ops.c10d_functional c10d_ops = torch.ops.c10d +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) class TestCommMode(TestCase): @@ -28,7 +31,7 @@ def setUp(self): dist.init_process_group( backend="fake", rank=1, world_size=self.world_size, store=store ) - self.device_type = "cuda" if torch.cuda.is_available() else "cpu" + self.device_type = device_type self.world_pg = dist.distributed_c10d._get_default_group() def checksAssert(self, comm_mode, key, expected_value,
expected_total_value): @@ -111,12 +114,12 @@ def f(x, y): self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1) self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_comm_mode_with_c10d(self): - if not torch.cuda.is_available(): + if not torch.accelerator.is_available(): return - inp = torch.rand(2, 8, 16).cuda() + inp = torch.rand(2, 8, 16).to(device_type) all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) comm_mode = CommDebugMode() diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 1ba4bccf3696d..c0625d37c6dad 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -76,7 +76,7 @@ def test_debug_mode_mm(self): _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] - (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P(sum) aten::sum(dt$6: f32[8, 32]| S(0)) aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) @@ -179,8 +179,8 @@ def test_debug_mode_backward(self): (dt: f32[8, 8]| S(0)) aten::sum(dt: f32[8, 8]| S(0)) aten::sum(t: f32[1, 8]) - torch._tensor.backward(dt: f32[]| P, gradient=None, retain_graph=None, create_graph=False, inputs=None) - aten::ones_like(dt: f32[]| P, pin_memory=False, memory_format=torch.preserve_format) + torch._tensor.backward(dt: f32[]| P(sum), gradient=None, retain_graph=None, create_graph=False, inputs=None) + aten::ones_like(dt: f32[]| P(sum), pin_memory=False, memory_format=torch.preserve_format) aten::ones_like(t: f32[], pin_memory=False, memory_format=torch.preserve_format) aten::expand(dt: f32[]| R, [8, 8]) aten::expand(t: f32[], [8, 8]) @@ -189,9 +189,9 @@ def test_debug_mode_backward(self): aten::clone(t: f32[8, 1]) aten::_to_copy(t: f32[8, 1], dtype=torch.float32, layout=torch.strided, device=cpu) redistribute_input(t: f32[8, 8], trace: R->S(0)) - aten::detach(t: f32[8, 1]) aten::split.Tensor(t: f32[8, 8], 1) aten::clone(t: f32[1, 8]) + aten::detach(t: f32[8, 1]) aten::_to_copy(t: f32[1, 8], dtype=torch.float32, layout=torch.strided, device=cpu) aten::detach(t: f32[1, 8])""", ) @@ -215,12 +215,8 @@ def test_debug_mode_densor_redistribution_trace(self): debug_mode.debug_string(), """\ aten::mm(dt: f32[128, 8]| S(0)[0]S(0)[1], dt: f32[8, 128]| S(1)[0]S(1)[1]) - redistribute_input(0, S(0)[0]S(0)[1] -> S(0)R) - redistribute_input(t: f32[16, 8], trace: S(0)[0]S(0)[1]->S(0)R) - _c10d_functional::all_gather_into_tensor(t: f32[16, 8], 2, 3) - _c10d_functional::wait_tensor(t: f32[32, 8]) - redistribute_input(1, S(1)[0]S(1)[1] -> RS(1)) - redistribute_input(t: f32[8, 16], trace: S(1)[0]S(1)[1]->S(1)R->RR->RS(1)) + redistribute_input(1, S(1)[0]S(1)[1] -> RR) + redistribute_input(t: f32[8, 16], trace: S(1)[0]S(1)[1]->S(1)R->RR) _c10d_functional::all_gather_into_tensor(t: f32[8, 16], 2, 3) _c10d_functional::wait_tensor(t: f32[16, 16]) aten::chunk(t: f32[16, 16], 2) @@ -229,11 +225,9 @@ def test_debug_mode_densor_redistribution_trace(self): _c10d_functional::wait_tensor(t: f32[32, 32]) aten::chunk(t: f32[32, 32], 4) aten::cat(['t: f32[8, 32]', 't: f32[8, 32]', 't: f32[8, 32]', 't: f32[8, 32]'], 1) - aten::chunk(t: f32[8, 128], 2, 1) - aten::clone(t: f32[8, 64]) - aten::mm(t: f32[32, 8], t: f32[8, 64]) - aten::sum(dt: f32[128, 
128]| S(0)S(1)) - aten::sum(t: f32[32, 64])""", + aten::mm(t: f32[16, 8], t: f32[8, 128]) + aten::sum(dt: f32[128, 128]| S(0)[0]S(0)[1]) + aten::sum(t: f32[16, 128])""", ) def test_debug_mode_einsum(self): @@ -253,38 +247,38 @@ def test_debug_mode_einsum(self): self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| PR, dt: f32[8, 4, 4]| RP) - aten::unsqueeze(dt: f32[16, 6, 8]| PR, 3) + torch.functional.einsum(bld,dnh->blnh, dt: f32[16, 6, 8]| P(sum)R, dt: f32[8, 4, 4]| RP(sum)) + aten::unsqueeze(dt: f32[16, 6, 8]| P(sum)R, 3) aten::unsqueeze(t: f32[16, 6, 8], 3) - aten::unsqueeze(dt: f32[16, 6, 8, 1]| PR, 4) + aten::unsqueeze(dt: f32[16, 6, 8, 1]| P(sum)R, 4) aten::unsqueeze(t: f32[16, 6, 8, 1], 4) - aten::permute(dt: f32[16, 6, 8, 1, 1]| PR, [0, 1, 3, 4, 2]) + aten::permute(dt: f32[16, 6, 8, 1, 1]| P(sum)R, [0, 1, 3, 4, 2]) aten::permute(t: f32[16, 6, 8, 1, 1], [0, 1, 3, 4, 2]) - aten::unsqueeze(dt: f32[8, 4, 4]| RP, 3) + aten::unsqueeze(dt: f32[8, 4, 4]| RP(sum), 3) aten::unsqueeze(t: f32[8, 4, 4], 3) - aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP, 4) + aten::unsqueeze(dt: f32[8, 4, 4, 1]| RP(sum), 4) aten::unsqueeze(t: f32[8, 4, 4, 1], 4) - aten::permute(dt: f32[8, 4, 4, 1, 1]| RP, [3, 4, 1, 2, 0]) + aten::permute(dt: f32[8, 4, 4, 1, 1]| RP(sum), [3, 4, 1, 2, 0]) aten::permute(t: f32[8, 4, 4, 1, 1], [3, 4, 1, 2, 0]) - aten::permute(dt: f32[16, 6, 1, 1, 8]| PR, [0, 1, 4, 2, 3]) + aten::permute(dt: f32[16, 6, 1, 1, 8]| P(sum)R, [0, 1, 4, 2, 3]) aten::permute(t: f32[16, 6, 1, 1, 8], [0, 1, 4, 2, 3]) - aten::view(dt: f32[16, 6, 8, 1, 1]| PR, [1, 96, 8]) + aten::view(dt: f32[16, 6, 8, 1, 1]| P(sum)R, [1, 96, 8]) aten::view(t: f32[16, 6, 8, 1, 1], [1, 96, 8]) - aten::permute(dt: f32[1, 1, 4, 4, 8]| RP, [4, 2, 3, 0, 1]) + aten::permute(dt: f32[1, 1, 4, 4, 8]| RP(sum), [4, 2, 3, 0, 1]) aten::permute(t: f32[1, 1, 4, 4, 8], [4, 2, 3, 0, 1]) - aten::view(dt: f32[8, 4, 4, 1, 1]| RP, [1, 8, 16]) + aten::view(dt: f32[8, 4, 4, 1, 1]| RP(sum), [1, 8, 16]) aten::view(t: f32[8, 4, 4, 1, 1], [1, 8, 16]) - aten::bmm(dt: f32[1, 96, 8]| PR, dt: f32[1, 8, 16]| RP) - redistribute_input(0, PR -> S(2)[0]S(2)[1]) - redistribute_input(t: f32[1, 96, 8], trace: PR->S(2)R->S(2)[0]S(2)[1]) + aten::bmm(dt: f32[1, 96, 8]| P(sum)R, dt: f32[1, 8, 16]| RP(sum)) + redistribute_input(0, P(sum)R -> S(2)[0]S(2)[1]) + redistribute_input(t: f32[1, 96, 8], trace: P(sum)R->S(2)R->S(2)[0]S(2)[1]) aten::chunk(t: f32[1, 96, 8], 4, 2) aten::cat(['t: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]', 't: f32[1, 96, 2]']) _c10d_functional::reduce_scatter_tensor(t: f32[4, 96, 2], sum, 4, 1) _c10d_functional::wait_tensor(t: f32[1, 96, 2]) aten::chunk(t: f32[1, 96, 2], 2, 2) aten::clone(t: f32[1, 96, 1]) - redistribute_input(1, RP -> S(1)[0]S(1)[1]) - redistribute_input(t: f32[1, 8, 16], trace: RP->S(1)P->S(1)[0]S(1)[1]) + redistribute_input(1, RP(sum) -> S(1)[0]S(1)[1]) + redistribute_input(t: f32[1, 8, 16], trace: RP(sum)->S(1)P(sum)->S(1)[0]S(1)[1]) aten::chunk(t: f32[1, 8, 16], 4, 1) aten::clone(t: f32[1, 2, 16]) aten::chunk(t: f32[1, 2, 16], 2, 1) @@ -292,11 +286,11 @@ def test_debug_mode_einsum(self): _c10d_functional::reduce_scatter_tensor(t: f32[2, 1, 16], sum, 2, 3) _c10d_functional::wait_tensor(t: f32[1, 1, 16]) aten::bmm(t: f32[1, 96, 1], t: f32[1, 1, 16]) - aten::view(dt: f32[1, 96, 16]| PP, [16, 6, 1, 4, 4]) + aten::view(dt: f32[1, 96, 16]| P(sum)P(sum), [16, 6, 1, 4, 4]) aten::view(t: f32[1, 96, 16], [16, 6, 1, 4, 4]) - aten::permute(dt: f32[16, 6, 1, 4, 4]| PP, 
[0, 1, 3, 4, 2]) + aten::permute(dt: f32[16, 6, 1, 4, 4]| P(sum)P(sum), [0, 1, 3, 4, 2]) aten::permute(t: f32[16, 6, 1, 4, 4], [0, 1, 3, 4, 2]) - aten::view(dt: f32[16, 6, 4, 4, 1]| PP, [16, 6, 4, 4]) + aten::view(dt: f32[16, 6, 4, 4, 1]| P(sum)P(sum), [16, 6, 4, 4]) aten::view(t: f32[16, 6, 4, 4, 1], [16, 6, 4, 4])""", ) @@ -582,6 +576,32 @@ def test_check_structure_mismatches(self): with self.assertRaisesRegex(ValueError, "Log lengths don't match"): DebugMode.check_hash_mismatches(dm1.logs, dm3.logs) + @unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(0).total_memory < 2**26, + "Being conservative, test peak memory is 25MB?", + ) + def test_tensor_hash_waits_on_collective(self): + # test that hashing collectives gives correct results + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + local_tensor = torch.ones(2**18, device=self.device_type) + dt = DTensor.from_local(local_tensor, mesh, [Shard(0)], run_check=False) + + with DebugMode() as debug_mode, DebugMode.log_tensor_hashes(): + dt.redistribute(mesh, [Replicate()]) + + # Find all_gather hash + all_gather_logs = [ + op + for op in debug_mode.logs + if isinstance(op, _OpCall) + and op.op == torch.ops._c10d_functional.all_gather_into_tensor.default + ] + self.assertEqual(len(all_gather_logs), 1) + actual_hash = all_gather_logs[0].log["hash"] + self.assertEqual(actual_hash, float(local_tensor.numel() * self.world_size)) + def test_pretty_print_dtensor_make_fx(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) diff --git a/test/distributed/tensor/experimental/test_register_sharding.py b/test/distributed/tensor/experimental/test_register_sharding.py index 1cfd7af243b58..0f2f208689608 100644 --- a/test/distributed/tensor/experimental/test_register_sharding.py +++ b/test/distributed/tensor/experimental/test_register_sharding.py @@ -116,6 +116,34 @@ def custom_argmax_sharding(x, dim, keepdim): self.assertTrue(dist_y.placements[0].is_shard(dim=0)) self.assertEqual(dist_y.full_tensor(), local_y) + @with_comms + def test_register_sharding_for_tensor_kwargs(self): + mesh = self.build_device_mesh() + x = torch.randn(4, 4, device=self.device_type) + x_dt = distribute_tensor(x, mesh, [Replicate()]) + + @register_sharding(aten.min.dim_min) + def min_dim_strategy(x, dim, keepdim, min, min_indices): + all_replicate = ( + [Replicate(), Replicate()], + [Replicate(), None, None, Replicate(), Replicate()], + ) + return [all_replicate] + + value = torch.randn(4, 1, device=self.device_type) + indices = torch.randn(4, 1, device=self.device_type).long() + value_dt = distribute_tensor(value, mesh, [Replicate()]) + indices_dt = distribute_tensor(indices, mesh, [Replicate()]) + + result = torch.min(x_dt, dim=1, keepdim=True, out=(value_dt, indices_dt)) + + self.assertIsInstance(result[0], DTensor) + self.assertIsInstance(result[1], DTensor) + + expected_values, expected_indices = torch.min(x, dim=1, keepdim=True) + self.assertEqual(result[0].full_tensor(), expected_values) + self.assertEqual(result[1].full_tensor(), expected_indices) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_api.py b/test/distributed/tensor/test_api.py index e1790f4829907..12897ee822e87 100644 --- a/test/distributed/tensor/test_api.py +++ b/test/distributed/tensor/test_api.py @@ -79,7 +79,13 @@ def test_distribute_tensor_rank(self): dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec) self.assertEqual(dist_tensor.placements[0].dim, 1) - 
placement_combs = [[Shard(0)], [Shard(1)], [Replicate()]] + placement_combs = [ + [Shard(0)], + [Shard(1)], + [Replicate()], + [Partial(reduce_op="sum")], + [Partial(reduce_op="avg")], + ] if not self.is_local_tensor_enabled: # test src_data_rank == 1 @@ -125,6 +131,10 @@ def test_distribute_tensor_errors(self): shard_spec = [Shard(0)] distribute_tensor(tensor_to_distribute, device_mesh, shard_spec) + with self.assertRaisesRegex(ValueError, "conversion is not supported"): + new_spec = [Replicate(), Partial(reduce_op="prod")] + distribute_tensor(tensor_to_distribute, device_mesh, new_spec) + with self.assertRaisesRegex(RuntimeError, "distribute leaf tensor"): shard_spec = [Shard(0)] global_tensor = torch.randn(*tensor_shape, requires_grad=True) diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index 0a06bd66df5e8..ed1cc60802e70 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -14,6 +14,7 @@ Shard, ) from torch.nn import functional as F +from torch.testing._internal.common_cuda import with_tf32_off from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -230,6 +231,7 @@ def test_conv3d(self): out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)]) self.assertEqual(out_dt, out) + @with_tf32_off @with_comms def test_conv2d_no_bias_compile(self): """Test Conv2d with bias=False in compile mode (Issue #167091) @@ -262,7 +264,7 @@ def conv_fn(x, w): self.assertEqual(result_compiled.shape, torch.Size([1, 8, 5, 5])) # Verify numerical correctness - torch.testing.assert_close(result_compiled.to_local(), result_eager.to_local()) + self.assertEqual(result_compiled.to_local(), result_eager.to_local()) @with_comms def test_conv2d_no_bias_backward(self): diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index e99734c6b8437..c47ff79091493 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -658,11 +658,11 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): @with_comms def test_dtensor_device_mesh_device_conversion(self): - # construct a cuda device mesh + # construct a gpu device mesh mesh = self.build_device_mesh() - # construct from a cpu local tensor with cuda device mesh - # should automatically convert the dist tensor to cuda + # construct from a cpu local tensor with gpu device mesh + # should automatically convert the dist tensor to gpu placements = [Shard(0)] local_tensor = torch.randn(3, 3) dist_tensor = DTensor.from_local(local_tensor, mesh, placements) @@ -711,7 +711,7 @@ def test_dtensor_api_device_mesh_context_manager(self): @with_comms def test_dtensor_2d_mesh(self): mesh_tensor = torch.arange(self.world_size).reshape(2, 4) - # construct a cuda device mesh + # construct a gpu device mesh mesh = DeviceMesh(self.device_type, mesh_tensor) # construct a dist tensor on 2d device mesh and test if works @@ -733,7 +733,7 @@ def test_dtensor_2d_mesh(self): @with_comms def test_device_mesh_nd(self): - # construct a cuda device mesh + # construct a gpu device mesh mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) # construct a dist tensor on 3d device mesh and test if works @@ -1064,8 +1064,8 @@ def _create_tensor(self, size): # Keep everything deterministic. 
torch.manual_seed(0) tensor = torch.rand(size) - if self.device_type == "cuda": - return tensor.cuda() + if self.device_type != "cpu": + return tensor.to(self.device_type) else: return tensor diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index ddba3150b05fb..e58b6dda658f3 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -39,6 +39,7 @@ RowwiseParallel, ) from torch.distributed.tensor.placement_types import _StridedShard +from torch.testing._internal.common_device_type import skipXPUIf from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import ( @@ -47,8 +48,6 @@ run_tests, skipIfHpu, skipIfTorchDynamo, - TEST_CUDA, - TEST_HPU, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -64,6 +63,54 @@ dev_type = torch.device(get_devtype()) +class PytreeTuple: + """ + Tuple-like values that are treated as leaves of a PyTree. + """ + + def __init__(self, *values): + self._values = tuple(values) + + def __repr__(self): + pr = repr(self._values)[1:-1] + return f"{type(self).__name__}({pr})" + + def __getitem__(self, i): + return self._values[i] + + def __iter__(self): + return iter(self._values) + + def __len__(self): + return len(self._values) + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return self._values == other._values + elif isinstance(other, tuple): + return self._values == other + return False + + def __hash__(self) -> int: + return hash(self._values) + + def __add__(self, other): + if isinstance(other, (self.__class__, tuple)): + return self.__class__(*self, *other) + raise NotImplementedError(type(other)) + + def __radd__(self, other): + if isinstance(other, (self.__class__, tuple)): + return self.__class__(*other, *self) + raise NotImplementedError(type(other)) + + def index(self, value): + return self._values.index(value) + + def count(self, value): + return self._values.count(value) + + class SimpleModel(nn.Module): def __init__(self, device): super().__init__() @@ -95,6 +142,10 @@ def extract_graph(fx_g, _, graph_cell): partition_fn=min_cut_rematerialization_partition, ) +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) + def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh): """ @@ -141,7 +192,7 @@ def tearDown(self): @property def device_type(self) -> str: - return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu" + return device_type @property def world_size(self) -> int: @@ -160,9 +211,9 @@ def fn(x): res = fn(x) res.to_local().sum().backward() - @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available") def test_dtensor_basic_export(self): - mesh = DeviceMesh("cuda", torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) param = torch.randn(4, 4) param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False) @@ -188,10 +239,10 @@ def forward(self, x): ) self.assertExpectedInline( str(ep.graph_module.code).strip(), - """\ + f"""\ def forward(self, b_buffer, x): _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); 
_assert_tensor_metadata_default = None - to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None + to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None view_as = torch.ops.aten.view_as.default(to, to); to = None dtensor___init__0 = self.dtensor___init__0 dtensor_const_func_spec0 = self.dtensor_const_func_spec0 @@ -206,10 +257,10 @@ def forward(self, b_buffer, x): # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), - """\ + f"""\ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None - _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None + _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None @@ -377,6 +428,7 @@ def fn(x): self.assertEqual(res, ref) @skipIfHpu + @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981") def test_dtensor_dynamic_loss_parallel_log_softmax(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -763,6 +815,37 @@ def fn(x): # this fails with an inductor stride assert out_dt.to_local().sum().backward() + def test_dynamo_to_local_grad_placements_sequence(self): + placements = PytreeTuple([Shard(0)]) + + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + def fn(x): + return dt.to_local(grad_placements=placements) + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + + out_ref = fn(dt) + out_test = fn_opt(dt) + self.assertEqual(out_ref, out_test) + + def test_dynamo_to_local_grad_placements_sequence_intermediate(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + def fn(x): + placements = PytreeTuple([Shard(0)]) + return dt.to_local(grad_placements=placements) + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + + out_ref = fn(dt) + out_test = fn_opt(dt) + self.assertEqual(out_ref, out_test) + def test_dynamo_to_local_kwargs(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -815,13 +898,13 @@ def fn(x, y, z): out = layer_norm.permute(0, 2, 1) return out - x = torch.randn(4, 2, 4, requires_grad=True, device="cuda") + x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type) x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False) - y = torch.randn(4, requires_grad=True, device="cuda") + y = torch.randn(4, requires_grad=True, device=self.device_type) y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False) - z = torch.randn(4, requires_grad=True, device="cuda") + z = torch.randn(4, requires_grad=True, 
device=self.device_type) z_dt = DTensor.from_local(z, mesh, [Replicate()], run_check=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) @@ -919,7 +1002,7 @@ def test_dtensor_dynamo_device_mesh_attrs(self): # pass in tensor as inputs/outputs, create DTensor and run redistribute # (allgather collective) inside the fn def fn(x_dt): - if x_dt.device_mesh.device_type == "cuda": + if x_dt.device_mesh.device_type == f"{self.device_type}": return x_dt + 1 else: return x_dt + 2 @@ -1051,7 +1134,7 @@ def forward(self, input): model = FakeTransformer().to(self.device_type) - tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) + tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",)) # apply sequence parallel parallel_plan = { diff --git a/test/distributed/tensor/test_dtensor_dispatch_overhead.py b/test/distributed/tensor/test_dtensor_dispatch_overhead.py index 7d08725205e60..ab9b578b80f93 100644 --- a/test/distributed/tensor/test_dtensor_dispatch_overhead.py +++ b/test/distributed/tensor/test_dtensor_dispatch_overhead.py @@ -10,6 +10,7 @@ import torch from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import distribute_tensor, DTensor, Shard +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -65,6 +66,7 @@ class DistOpDispatchOverHead(DTensorTestBase): def world_size(self) -> int: return 4 + @skip_if_lt_x_gpu(4) @with_comms def test_dtensor_add_op_dispatch_overhead(self): if torch.cuda.is_available(): diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index bd75668ab4856..4a88cf9a6e0b1 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import contextlib -import unittest import torch import torch.distributed as dist @@ -357,7 +356,6 @@ def test_export_parallelize_module_with_dtensor_input( # aot_export_joint_with_descriptors on strict-exported exported_program.module() # is producing a joint graph with backward region missing - @unittest.expectedFailure def test_strict_export_parallelize_module_with_dtensor_input(self): self._run_test(strict_export_and_aot_export_joint_with_descriptors) diff --git a/test/distributed/tensor/test_dtensor_ops.py b/test/distributed/tensor/test_dtensor_ops.py index df51152a90307..c3eb791fd0e41 100644 --- a/test/distributed/tensor/test_dtensor_ops.py +++ b/test/distributed/tensor/test_dtensor_ops.py @@ -435,8 +435,6 @@ def repurpose_ops(op_db, base_test_name, derived_test_name): xfail("signal.windows.nuttall"), xfail("signal.windows.kaiser"), xfail("stack"), - xfail("std"), - xfail("std", "unbiased"), xfail("std_mean"), xfail("std_mean", "unbiased"), xfail("stft"), @@ -727,7 +725,7 @@ def run_mean(self): self.assertEqual(full_tensor, tensor.mean(dim=reduce_dim)) if is_evenly_shardable: - self.assertTrue("P->R" in debug_mode.debug_string()) + self.assertTrue("P(avg)->R" in debug_mode.debug_string()) else: self.assertTrue("S(0)->R" in debug_mode.debug_string()) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 56321806477b9..2922c5ff85960 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -1037,6 +1037,31 @@ def 
test_matching_partial_reduction_ops(self): self.assertTrue(out_with_redistribute.placements[0].is_replicate()) self.assertEqual(out_without_redistribute, out_with_redistribute) + @skip_if_lt_x_gpu(4) + @with_comms + def test_std(self): + mesh = DeviceMesh(self.device_type, torch.arange(4).reshape(2, 2)) + rank = self.rank + comm_mode = CommDebugMode() + + global_tensor = map_local_for_rank( + rank, + lambda rank: torch.tensor( + [[-20.0, -18.0, -12.0, 0.0], [-20.0, -18.0, -8.0, 4.0]] + ), + ) + + dt = distribute_tensor(global_tensor, mesh, [Shard(0), Shard(1)]) + + with comm_mode: + res = dt.std(dim=1) + expected_answer = torch.tensor([9.0, 11.0]) + + self.assertEqual(comm_mode.get_total_counts(), 1) + self.assertEqual(comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1) + self.assertEqual(res.placements, [Shard(0), Replicate()]) + self.assertEqual(res.full_tensor(), expected_answer) + DistMathOpsTestWithLocalTensor = create_local_tensor_test_class( DistMathOpsTest, diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index 6e3dd23c44210..65c8cc6f36af4 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -549,7 +549,7 @@ def test_tensordot_shampoo(self): ], ) def test_grouped_mm(self, kwargs): - # TODO: torch._grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) + # TODO: torch.nn.functional.grouped_mm can take inputs of dimension (2D, 3D) x (2D, 3D) # More tests need to be added. device_mesh = self.build_device_mesh() comm_mode = CommDebugMode() @@ -574,8 +574,8 @@ def test_grouped_mm(self, kwargs): ) offs = torch.tensor([16, 64], device=self.device_type, dtype=torch.int32) - h = torch._grouped_mm(inp, w1, offs=offs) - out = torch._grouped_mm(h, w2, offs=offs) + h = F.grouped_mm(inp, w1, offs=offs) + out = F.grouped_mm(h, w2, offs=offs) dist_inp = distribute_tensor(inp, device_mesh, kwargs["inp_placements"]) # colwise sharded @@ -585,8 +585,8 @@ def test_grouped_mm(self, kwargs): dist_offs = distribute_tensor(offs, device_mesh, [Replicate()]) with comm_mode: - dist_h = torch._grouped_mm(dist_inp, dist_w1, offs=dist_offs) - dist_out = torch._grouped_mm(dist_h, dist_w2, offs=dist_offs) + dist_h = F.grouped_mm(dist_inp, dist_w1, offs=dist_offs) + dist_out = F.grouped_mm(dist_h, dist_w2, offs=dist_offs) self.assertEqual( comm_mode.get_total_counts(), kwargs["expected_comm_counts_fwd"] ) diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 139f5fb61fac8..e1d3f96e9e5f4 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -30,11 +30,13 @@ EinsumDims, gen_einsum_strategies, ) -from torch.distributed.tensor._ops.utils import ( - register_op_strategy, - replicate_op_strategy, +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import replicate_op_strategy +from torch.distributed.tensor.debug import ( + _clear_fast_path_sharding_prop_cache, + _clear_python_sharding_prop_cache, + CommDebugMode, ) -from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -378,6 +380,37 @@ def test_bmm_strategies(self): ) self.assertFalse(output_sharding.needs_redistribute) + def test_redistribute_cost_with_order(self): + mesh_2d = DeviceMesh( + 
self.device_type, torch.arange(self.world_size).reshape(2, 2) + ) + + # Source: Shard on dim 0 across all three mesh dimensions + source_placement = (Shard(0), Shard(0)) + + # Target: Replicate on first mesh dimension, shard on others + # This requires 2 allgathers, one on dim=0 and one on dim=1 + replicate_mesh_dim0 = (Replicate(), Shard(0)) + + # Target: Replicate on second mesh dimension, shard on others + # This requires 1 allgather on dim=1 + replicate_mesh_dim1 = (Shard(0), Replicate()) + + global_tensor = torch.randn(4, 4) + global_tensor_meta = extract_tensor_meta(global_tensor) + + source_spec = DTensorSpec(mesh_2d, source_placement, global_tensor_meta) + target_spec_dim0 = DTensorSpec(mesh_2d, replicate_mesh_dim0, global_tensor_meta) + target_spec_dim1 = DTensorSpec(mesh_2d, replicate_mesh_dim1, global_tensor_meta) + + # Calculate costs for allgather on each mesh dimension + cost_mesh_dim0 = redistribute_cost(source_spec, target_spec_dim0) + cost_mesh_dim1 = redistribute_cost(source_spec, target_spec_dim1) + + # Cost increases with earlier mesh dimensions due to the way + # mesh dimensions are ordered (outer to inner in device hierarchy) + self.assertGreater(cost_mesh_dim0, cost_mesh_dim1) + # -------------Test op strategy registration------------- # custom op without List[Tensor] as input @@ -479,7 +512,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): del propagator.op_to_schema_info[op_overload] else: propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema - propagator.propagate_op_sharding.cache.cache_clear() + _clear_fast_path_sharding_prop_cache() + _clear_python_sharding_prop_cache() def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool: @@ -645,6 +679,28 @@ def test_call_with_different_nontensor_args(self): self.assertEqual(out1.full_tensor(), out2.full_tensor()) +class TestStrategyOperation(DTensorTestBase): + @property + def world_size(self): + return 2 + + @with_comms + def test_cache_clean(self): + mesh = self.build_device_mesh() + test_op = torch.ops.mylib.numpy_sin + x = torch.randn(2, device=self.device_type) + y = torch.randn(2, device=self.device_type) + x_dt = distribute_tensor(x, mesh, [Shard(0)]) + y_dt = distribute_tensor(y, mesh, [Shard(0)]) + with op_strategy_context(test_op.default, replicate_op_strategy): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + with self.assertRaisesRegex( + NotImplementedError, + f"Operator {test_op.default} does not have a sharding strategy registered", + ): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + + DistTensorReplicateStrategyRegistrationTestWithLocalTensor = ( create_local_tensor_test_class( DistTensorReplicateStrategyRegistrationTest, diff --git a/test/distributed/tensor/test_pointwise_ops.py b/test/distributed/tensor/test_pointwise_ops.py index d2c4e7dea06b4..9d35e10f24ba8 100644 --- a/test/distributed/tensor/test_pointwise_ops.py +++ b/test/distributed/tensor/test_pointwise_ops.py @@ -148,6 +148,30 @@ def test_partial_add(self): d_3 = d_1 + d_2 self.assertTrue(d_3._spec.placements[0].is_partial()) + def test_partial_replicate_add(self): + device_mesh = self.build_device_mesh() + comm_mode = CommDebugMode() + + for reduce_op in ("sum", "avg"): + d_1 = DTensor.from_local( + torch.rand(2, 2), + device_mesh, + [Partial(reduce_op=reduce_op)], + ) + d_2 = DTensor.from_local( + torch.rand(2, 1), + device_mesh, + [Replicate()], + run_check=True, + ) + + with comm_mode: + d_3 = d_1 + d_2 + + self.assertEqual(comm_mode.get_total_counts(), 0) + 
self.assertEqual(d_3.placements, (Partial(reduce_op=reduce_op),)) + self.assertEqual(d_3.full_tensor(), d_1.full_tensor() + d_2.full_tensor()) + def test_activations(self): device_mesh = self.build_device_mesh() self._run_sharded_elementwise_ops( @@ -247,6 +271,7 @@ def test_dropout_backward(self): ), ) + @skip_unless_torch_gpu def test_dropout_errors(self): device_mesh = self.build_device_mesh() with self.assertRaisesRegex(RuntimeError, "supported"): diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 61b88ee169e2e..4bcddc198836b 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -6,8 +6,8 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._random as random +from torch.distributed._local_tensor import LocalTensor, maybe_run_for_local_tensor from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.distributed_c10d import broadcast_object_list from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import ( DeviceMesh, @@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, skip_if_lt_x_gpu, skip_unless_torch_gpu, @@ -34,9 +35,12 @@ from torch.utils._typing_utils import not_none -def get_generator_seed_for_device_type(device_type: str) -> int: - device_module = torch.get_device_module(device_type) - return device_module.get_rng_state()[:8].view(torch.int64).item() +def get_generator_seed_for_device_type(device_type: str): + from torch.distributed._local_tensor import ( + get_generator_seed_for_device_type as _get_seed, + ) + + return _get_seed(device_type) class DistTensorRandomInitTest(DTensorTestBase): @@ -134,9 +138,6 @@ def test_meta_tensor_init(self): torch.empty(*size, device="meta"), device_mesh, [Replicate()] ) - # the tensor slice on the current rank - self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024) - # Test 1: enable the distribute region for RNG (by default) self.assertTrue(meta_dtensor.is_meta) # Tensor meta init @@ -150,16 +151,23 @@ def test_meta_tensor_init(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - for other_rank in range(self.world_size): - # the RNG result on each rank are the same because they're replicated - if self.rank != other_rank: - # other rank should have an identical local tensor - other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) - self.assertEqual( - gathered_local_tensors[self_slice, :], - gathered_local_tensors[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(gathered_local_tensors, rank): + # the tensor slice on the current rank + self_slice = slice(1024 * rank, 1024 * rank + 1024) + + # compare with local tensors from other ranks + for other_rank in range(self.world_size): + # the RNG result on each rank are the same because they're replicated + if rank != other_rank: + # other rank should have an identical local tensor + other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) + self.assertEqual( + gathered_local_tensors[self_slice, :], + gathered_local_tensors[other_slice, :], + ) + + compute_rankwise_if_local_tensor(gathered_local_tensors.wait(), self.rank) # Test 2: 
disable the distribute region for RNG self.assertTrue(meta_dtensor.is_meta) @@ -175,15 +183,7 @@ def test_meta_tensor_init(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - for other_rank in range(self.world_size): - # the RNG result on each rank are the same even without the help of DTensor's RNG infra, - # since the default RNG is the same across ranks. - if self.rank != other_rank: - other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) - self.assertEqual( - local_tensor[self_slice, :], local_tensor[other_slice, :] - ) + compute_rankwise_if_local_tensor(local_tensor.wait(), self.rank) @with_comms @skip_unless_torch_gpu @@ -224,13 +224,17 @@ def test_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - for other_rank in range(self.world_size): - if self.rank != other_rank: - self.assertNotEqual( - weight_local, - weight_gather[other_rank : other_rank + 1, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, weight_gather, rank): + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + compute_rankwise_if_local_tensor(weight_local, weight_gather.wait(), self.rank) @with_comms @skip_if_lt_x_gpu(4) @@ -277,13 +281,17 @@ def test_fsdp_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - for other_rank in range(self.world_size): - if self.rank != other_rank: - self.assertNotEqual( - weight_local, - weight_gather[other_rank : other_rank + 1, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, weight_gather, rank): + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + compute_rankwise_if_local_tensor(weight_local, weight_gather.wait(), self.rank) class DistTensorRandomOpTest(DTensorTestBase): @@ -291,9 +299,14 @@ class DistTensorRandomOpTest(DTensorTestBase): @skip_unless_torch_gpu def test_rng_tracker_init(self): torch.manual_seed(self.rank) - object_list = [torch.initial_seed()] - broadcast_object_list(object_list) - seed_from_rank_0 = int(object_list[0]) + seed_local = ( + torch.zeros_like(torch.empty(1), device=self.device_type) + + torch.initial_seed() + ) + torch.distributed.broadcast(seed_local, src=0) + # if localtensor, it should automaticall reconcile after the broadcast + # since all virtual ranks should have rank 0's initial_seed() + seed_from_rank_0 = seed_local device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) # seed synchronization now does NOT happen after the first `distribute_tensor` @@ -344,15 +357,19 @@ def test_manual_seed(self): @with_comms @skip_unless_torch_gpu def test_manual_seed_submesh(self): - # the current rank is not a part of the mesh - single_rank_device_mesh = DeviceMesh( - self.device_type, [(self.rank + 1) % self.world_size] - ) - with self.assertRaisesRegex( - RuntimeError, - "manual_seed requires the current rank to be a part of the device mesh", - ): - manual_seed(self.rank, single_rank_device_mesh) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(rank): + # the current rank is not a part of the mesh + 
single_rank_device_mesh = DeviceMesh( + self.device_type, [(rank + 1) % self.world_size], _rank=rank + ) + with self.assertRaisesRegex( + RuntimeError, + "manual_seed requires the current rank to be a part of the device mesh", + ): + manual_seed(rank, single_rank_device_mesh) + + compute_rankwise_if_local_tensor(self.rank) @with_comms @skip_unless_torch_gpu @@ -394,7 +411,7 @@ def test_pipeline_parallel_manual_seed(self): for other_rank in range(self.world_size): if self.rank != other_rank: self.assertNotEqual( - spmd_dtensor.to_local(), + spmd_dtensor, tensor_gather[2 * other_rank : 2 * (other_rank + 1), :], ) @@ -428,16 +445,20 @@ def test_deterministic_dropout_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have an identical local tensor - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) @with_comms @skip_unless_torch_gpu @@ -454,16 +475,20 @@ def test_deterministic_rand_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have a different local tensor for shard placement - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertNotEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor for replicate placement + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertNotEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) # we should set manual seed to the same value on all SPMD ranks torch.manual_seed(0) @@ -472,16 +497,20 @@ def test_deterministic_rand_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have an identical local tensor for replicate placement - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if 
rank != other_rank: + # other rank should have an identical local tensor for replicate placement + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) @with_comms @skip_if_lt_x_gpu(4) @@ -539,7 +568,12 @@ def test_deterministic_uniform_2d(self): shard_linear_idx = random._rng_tracker._calc_shard_linear_idx( shard_coord, shard_size ) - self.assertEqual(shard_linear_idx, shard_index[self.rank]) + + @maybe_run_for_local_tensor + def check_shard_index(shard_linear_idx, rank): + self.assertEqual(shard_linear_idx, shard_index[rank]) + + check_shard_index(shard_linear_idx, self.rank) # compute local size and offset _, local_shard_offset = compute_local_shape_and_global_offset( @@ -578,16 +612,46 @@ def test_deterministic_uniform_2d(self): # allgather the local tensors full_tensor = dtensor.full_tensor() - # compare local tensor with each other shard - for other_local_shard in local_shard_comb: - other_local_shard_offset, _ = zip(*other_local_shard) - slice_idx = [ - slice(offset, offset + size) for offset, size in other_local_shard - ] - if local_shard_offset == other_local_shard_offset: - self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) - else: - self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) + full_tensor = ( + full_tensor.reconcile() + if isinstance(full_tensor, LocalTensor) + else full_tensor + ) + + @maybe_run_for_local_tensor + def blockwise_iter_if_localtensor(local_tensor, local_shard_offset): + # compare local tensor with each other shard + for other_local_shard in local_shard_comb: + other_local_shard_offset, _ = zip(*other_local_shard) + slice_idx = [ + slice(offset, offset + size) + for offset, size in other_local_shard + ] + if local_shard_offset == other_local_shard_offset: + self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) + else: + self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) + + blockwise_iter_if_localtensor(local_tensor, local_shard_offset) + + def test_philox_state_seed_roundtrip(self): + """ + Test that _PhiloxState seed can be read and re-set without error. + + This test addresses the issue where reading a seed value from the state + (which uses uint64 view) and then re-setting it would fail with: + OverflowError: can't convert negative int to unsigned + + The fix ensures the seed getter uses uint64 view, preventing negative + values from appearing when the high bit is set. 
+ """ + from torch.distributed.tensor._random import _PhiloxState + + state = torch.zeros(16, dtype=torch.uint8, device="cpu") + philox = _PhiloxState(state) + test_seed = 2**63 + 42 # This has the sign bit set when viewed as int64 + philox.seed = test_seed + philox.seed = philox.seed class DistTensorRandomOpsTest3D(DTensorTestBase): @@ -641,22 +705,46 @@ def test_hsdp_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - shard_dim_0_len = self.world_size // 4 - for other_rank in range(self.world_size): - other_rank_dim_0_start = other_rank * shard_dim_0_len - other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len - if self.rank % 4 != other_rank % 4: - self.assertNotEqual( - weight_local, - weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], - ) - else: - self.assertEqual( - weight_local, - weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], - ) + weight_gather = weight_gather.wait() + + weight_gather = ( + weight_gather.reconcile() + if isinstance(weight_gather, LocalTensor) + else weight_gather + ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, rank): + # verify the weights are initialized differently on all ranks + shard_dim_0_len = self.world_size // 4 + for other_rank in range(self.world_size): + other_rank_dim_0_start = other_rank * shard_dim_0_len + other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len + if rank % 4 != other_rank % 4: + self.assertNotEqual( + weight_local, + weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], + ) + else: + self.assertEqual( + weight_local, + weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], + ) + + compute_rankwise_if_local_tensor(weight_local, self.rank) + + +DistTensorRandomInitTestWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomInitTest, +) + +DistTensorRandomOpTestWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomOpTest, +) + +DistTensorRandomOpsTest3DWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomOpsTest3D, +) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index 381660e47927d..ebb2c5f01668f 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -21,14 +21,17 @@ ) from torch.distributed.tensor._collective_utils import shard_dim_alltoall from torch.distributed.tensor._dtensor_spec import ShardOrderEntry +from torch.distributed.tensor._redistribute import ( + _gen_transform_infos, + use_min_cost_redistribution_plan, +) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, - TEST_CUDA, - TEST_HPU, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -509,6 +512,7 @@ def test_redistribute_uneven_sharding(self): dt_full_tensor = dt.full_tensor() self.assertEqual(dt_full_tensor, input_tensor) + @skip_if_lt_x_gpu(4) @with_comms @parametrize("dtype", [torch.float32, torch.cfloat]) def test_redistribute_shard_dim_change(self, dtype): @@ -541,7 +545,7 @@ def test_redistribute_shard_dim_change(self, dtype): local_out_dt = out_dt.to_local() local_expected_dt = expected_dt.to_local() 
self.assertEqual(out_dt.to_local(), expected_dt.to_local()) - if TEST_HPU or TEST_CUDA: + if torch.accelerator.is_available(): self.assertEqual( comm_mode.get_comm_counts()[ torch.ops._dtensor.shard_dim_alltoall @@ -880,6 +884,76 @@ def test_ordered_redistribute(self): ) self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) + @with_comms + def test_force_min_cost_redistribution_plan(self): + """ + Test that the disable_graph_based_transform context manager correctly controls + the redistribution algorithm selection (graph-based vs greedy). + """ + # Set deterministic seed for reproducible tensor generation + torch.manual_seed(21) + mesh = init_device_mesh(self.device_type, (2, 2, 2)) + input_data = torch.randn((8, 8, 8), device=self.device_type) + + # the redistribution path differs if we use graph-based or greedy search solution + src_placement, src_order = ( + [Shard(0), Shard(0), Shard(0)], # All mesh dims shard tensor dim 0 + ( + ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1, 2)), + ), # Device order: 0→1→2 + ) + dst_placement, dst_order = ( + [Shard(1), Shard(1), Shard(1)], # All mesh dims shard tensor dim 1 + ( + ShardOrderEntry(tensor_dim=1, mesh_dims=(0, 1, 2)), + ), # Device order: 0→1→2 + ) + + # Test both graph-based (enable_graph=True) and greedy (enable_graph=False) algorithms + for idx, enable_graph in enumerate([True, False]): + sharded_dt = _distribute_tensor( + input_data.clone(), mesh, src_placement, shard_order=src_order + ) + + with ( + use_min_cost_redistribution_plan(enabled=enable_graph), + DebugMode(record_torchfunction=False) as debug_mode, + ): + sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order) + trace_str = self._extract_redistribute_trace_from_debug_mode( + debug_mode.debug_string() + ) + + # Validate graph-based algorithm trace (idx=0, disable_graph=False) + # Graph-based uses optimal path search (Dijkstra's algorithm) + # Expected path has 6 transformations with strategic intermediate states + # Path: S(0)[0,1,2] → S(0)[0,1]S(2) → S(0)S(2)[1,0] → + # S(1)S(2)[1,0] → S(1)[0,1]S(2) → S(1)[0,1,2] + if idx == 0: + self.assertExpectedInline( + trace_str, + """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]S(2)->S(0)S(2)[1]S(2)[0]->S(1)S(2)[1]S(2)[0]->S(1)[0]S(1)[1]S(2)->S(1)[0]S(1)[1]S(1)[2]""", + ) + # Validate greedy algorithm trace (idx=1, disable_graph=True) + # Greedy uses simple heuristic approach (processes mesh dims sequentially) + # Expected path has 6 transformations but with different intermediate states + # Path: S(0)[0,1,2] → S(0)[0,1]R → S(0)RR → + # S(1)RR → S(1)[0,1]R → S(1)[0,1,2] + elif idx == 1: + self.assertExpectedInline( + trace_str, + """S(0)[0]S(0)[1]S(0)[2]->S(0)[0]S(0)[1]R->S(0)RR->S(1)RR->S(1)[0]S(1)[1]R->S(1)[0]S(1)[1]S(1)[2]""", + ) + expected_dt = _distribute_tensor( + input_data.clone(), mesh, dst_placement, shard_order=dst_order + ) + self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) + + # Clear the transformation cache between iterations. 
Without this, + # the second iteration would use cached paths from the first, + # causing the trace validation to fail because: + _gen_transform_infos.cache_clear() + @with_comms def test_generate_shard_orders(self): """Check if `generate_shard_orders` generates unique sharding combinations""" diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 80968fb52e904..a6cdc64a2e1dd 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -2,6 +2,7 @@ # Owner(s): ["oncall: distributed"] import itertools +import unittest import torch from torch.distributed.tensor import ( @@ -13,6 +14,7 @@ Replicate, Shard, ) +from torch.distributed.tensor._sharding_prop import ShardingPropagator from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests, skipIfRocm @@ -296,8 +298,8 @@ def test_zeros_like(self): self.assertEqual(dist_tensor.dtype, torch.float32) self.assertEqual(zeros_like_dt.dtype, torch.bfloat16) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_stack(self): mesh_2d = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 2) @@ -334,6 +336,34 @@ def test_stack(self): torch.stack([global_input, global_input], dim=1), ) + @with_comms + def test_stack_cache(self): + device_mesh = self.build_device_mesh() + + shape = (4, 8) + placements = [Replicate()] + dtensor_list = [] + for _ in range(3): + local_tensor = torch.randn(shape) + dt = DTensor.from_local(local_tensor, device_mesh, placements) + dtensor_list.append(dt) + + _ = torch.stack(dtensor_list) + + dtensor_list2 = [] + for _ in range(3): + local_tensor = torch.randn(shape) + dt = DTensor.from_local(local_tensor, device_mesh, placements) + dtensor_list2.append(dt) + + def error(*args, **kwargs): + raise AssertionError + + with unittest.mock.patch.object( + ShardingPropagator, "_propagate_tensor_meta_non_cached", error + ): + _ = torch.stack(dtensor_list2) + @with_comms def test_equal(self): device_mesh = self.build_device_mesh() diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 11b70c8554e52..5f3225d174cb2 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -16,7 +16,6 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._utils import ( _compute_local_shape_and_global_offset, - _explicit_order_placements, compute_global_tensor_info, compute_global_tensor_shape, compute_local_shape_and_global_offset, @@ -46,85 +45,6 @@ class LocalTest(TestCase): - def test_explicit_order_placements(self): - # mesh_shape: ShapeType, placements: Sequence[Placement] - test_cases = [ - { - "mesh_shape": [2, 4], - "placements": [Replicate(), Replicate()], - "ordered": [(0, Replicate()), (1, Replicate())], - }, - { - "mesh_shape": [3, 2], - "placements": [Shard(0), Replicate()], - "ordered": [(0, Shard(0)), (1, Replicate())], - }, - { - "mesh_shape": [2, 4], - "placements": [_StridedShard(0, split_factor=4), Shard(0)], - "ordered": [(1, Shard(0)), (0, Shard(0))], - }, - { - "mesh_shape": [2, 3, 4], - "placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)], - "ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))], - }, - { - "mesh_shape": [2, 3, 4], - "placements": [ - _StridedShard(0, split_factor=12), - _StridedShard(0, split_factor=4), - Shard(0), - ], - 
"ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))], - }, - ] - for test_case in test_cases: - actual = _explicit_order_placements( - test_case["mesh_shape"], test_case["placements"] - ) - expected = test_case["ordered"] - - self.assertEqual( - actual, - expected, - f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}", - ) - - error_cases = [ - { - "mesh_shape": [2, 3, 4], - "placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)], - "exception_type": RuntimeError, - "exception_text": "Can only convert _StridedShard to ordered Shard if split_factor", - }, - { - "mesh_shape": [2, 3, 4], - "placements": [ - _StridedShard(0, split_factor=3), - Shard(0), - Shard(0), - ], - "exception_type": NotImplementedError, - "exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended", - }, - { - "mesh_shape": [2, 3], - "placements": [ - Shard(0), - ], - "exception_type": RuntimeError, - "exception_text": "Expected one placement per mesh dim", - }, - ] - for test_case in error_cases: - with self.assertRaisesRegex( - test_case["exception_type"], test_case["exception_text"] - ): - _explicit_order_placements( - test_case["mesh_shape"], test_case["placements"] - ) - def test_compute_local_shape_and_global_offset_uneven(self): # This case is not only 'uneven' bug also has an empty shard # (e.g. most DP ranks have local shape 18,4096, one has 8,4096, one has 0,4096 @@ -151,6 +71,225 @@ def test_compute_local_shape_and_global_offset_uneven(self): self.assertEqual(local_shape, (expected_shard_size, 4096)) self.assertEqual(global_offset, (expected_shard_offset, 0)) + # S, S uneven without empty + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), Shard(0)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + + dp012_shard_size = 5 + if dp_rank in (0, 1, 2): + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 3 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 2 + else: + assert dp_rank == 3 + tp0_shard_size = 2 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 1 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # S, S uneven with empty + global_shape = (13, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), Shard(0)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + + dp012_shard_size = 4 + if dp_rank in (0, 1, 2): + tp0_shard_size = 2 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = dp012_shard_size * dp_rank + tp0_shard_size + expected_shard_size = 2 + else: + assert dp_rank == 3 + tp0_shard_size = 1 + if tp_rank == 0: + expected_shard_offset = dp012_shard_size * dp_rank + expected_shard_size = 
1 + else: + assert tp_rank == 1 + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # SS, Shard + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [_StridedShard(0, split_factor=TP), Shard(0)] + TP_shard_size = int(global_shape[0] / TP) + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + expected_shard_size = 3 + expected_shard_offset = ( + tp_rank * TP_shard_size + expected_shard_size * dp_rank + ) + if dp_rank == 3: + expected_shard_size = 0 + expected_shard_offset = 18 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # SS, SS + global_shape = (39, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [ + _StridedShard(0, split_factor=3), + _StridedShard(0, split_factor=4), + ] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if dp_rank in (0, 1, 2): + tp0_shard_size = 8 + if tp_rank == 0: + expected_shard_offset = 4 * dp_rank + expected_shard_size = tp0_shard_size + else: + assert tp_rank == 1 + expected_shard_offset = 4 * dp_rank + 2 + expected_shard_size = 4 + else: + assert dp_rank == 3 + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = 4 * dp_rank + expected_shard_size = 3 + else: + assert tp_rank == 1 + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # (Shard, SS) + global_shape = (18, 2) + DP = 4 + TP = 2 + mesh_shape = (DP, TP) + placements = [Shard(0), _StridedShard(0, split_factor=2)] + for my_coordinate in itertools.product(range(DP), range(TP)): + dp_rank, tp_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if dp_rank in (0, 1, 2): + tp0_shard_size = 3 + if tp_rank == 0: + expected_shard_offset = 5 * dp_rank + expected_shard_size = tp0_shard_size + else: + assert tp_rank == 1 + expected_shard_offset = 5 * dp_rank + 2 + expected_shard_size = 2 + else: + assert dp_rank == 3 + if tp_rank == 0: + expected_shard_offset = 5 * dp_rank + expected_shard_size = 2 + else: + assert tp_rank == 1 + expected_shard_offset = 5 * dp_rank + 1 + expected_shard_size = 1 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + + # (Shard, SS, Shard) + global_shape = (39, 2) + mesh0, mesh1, mesh2 = 4, 2, 3 + mesh_shape = (mesh0, mesh1, mesh2) + placements = [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] + for my_coordinate in itertools.product( + range(mesh0), range(mesh1), range(mesh2) + ): + mesh0_rank, mesh1_rank, mesh2_rank = my_coordinate + local_shape, global_offset = _compute_local_shape_and_global_offset( + global_shape, mesh_shape, list(my_coordinate), placements + ) + if mesh0_rank in (0, 1, 2): + if mesh1_rank == 0: + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + expected_shard_size = 2 + elif mesh2_rank == 1: + 
expected_shard_offset = 10 * mesh0_rank + 2 + expected_shard_size = 2 + else: + expected_shard_offset = 10 * mesh0_rank + 6 + expected_shard_size = 2 + else: + assert mesh1_rank == 1 + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + 3 + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 8 + expected_shard_size = 2 + else: + assert mesh2_rank == 2 + expected_shard_size = 0 + expected_shard_offset = global_shape[0] + else: + assert mesh0_rank == 3 + if mesh1_rank == 0: + if mesh2_rank in (0, 1): + expected_shard_offset = 10 * mesh0_rank + 2 * mesh2_rank + expected_shard_size = 2 + else: + assert mesh2_rank == 2 + expected_shard_offset = 10 * mesh0_rank + 6 + expected_shard_size = 1 + else: + assert mesh1_rank == 1 + if mesh2_rank == 0: + expected_shard_offset = 10 * mesh0_rank + 3 + expected_shard_size = 2 + elif mesh2_rank == 1: + expected_shard_offset = 10 * mesh0_rank + 7 + expected_shard_size = 2 + else: + expected_shard_offset = global_shape[0] + expected_shard_size = 0 + self.assertEqual(local_shape, (expected_shard_size, 2)) + self.assertEqual(global_offset, (expected_shard_offset, 0)) + class UtilTest(DTensorTestBase): @property @@ -292,6 +431,78 @@ def test_compute_local_shape_and_global_offset_2D(self): global_tensor[dim0_start:dim0_end, dim1_start:dim1_end], ) + @with_comms + def test_compute_local_shape_and_global_offset_3D(self): + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + mesh_size_0 = 2 + mesh_size_1 = 2 + mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + Shard(0), + Shard(0), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank = global_mesh.get_coordinate() + self.assertEqual(local_shape, [2, 2 * self.world_size]) + self.assertEqual( + global_offset, (4 * mesh0_rank + 8 * mesh1_rank + 2 * mesh2_rank, 0) + ) + + @with_comms + def test_compute_local_shape_and_global_offset_4D(self): + global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) + mesh_size_0 = 1 + mesh_size_1 = 2 + mesh_size_2 = 2 + mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) + global_mesh = init_device_mesh( + self.device_type, + (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), + mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(1), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() + self.assertEqual( + local_shape, (2 * mesh_size_1 * mesh_size_3, 2 * mesh_size_0 * mesh_size_2) + ) + self.assertEqual( + global_offset, + (8 * mesh2_rank + 4 * mesh0_rank, 8 * mesh3_rank + 4 * mesh1_rank), + ) + placements = [ + _StridedShard(0, split_factor=mesh_size_1), + _StridedShard(1, split_factor=mesh_size_3), + Shard(0), + Shard(0), + ] + local_shape, global_offset = compute_local_shape_and_global_offset( + global_tensor_shape, global_mesh, placements + ) + mesh0_rank, mesh1_rank, mesh2_rank, mesh3_rank = global_mesh.get_coordinate() + self.assertEqual( + 
local_shape, (2 * mesh_size_1, 2 * mesh_size_2 * mesh_size_3 * mesh_size_0) + ) + self.assertEqual( + global_offset, + (8 * mesh2_rank + 0 * mesh0_rank + 4 * mesh3_rank, 4 * mesh1_rank), + ) + @with_comms def test_fsdp_tp_meta_compute(self): # FSDP + TP sharding @@ -362,106 +573,6 @@ def test_hsdp_tp_meta_compute(self): self.assertEqual(local_shape, expected_local_shape) self.assertEqual(global_offset, expected_global_offset) - # TODO: remove this test once we support general meta compute on strided sharding - @with_comms - def test_strided_sharding_assumption_in_meta_compute(self): - # current ``compute_local_shape_and_global_offset`` does not allow Shard(i) - # placement to appear after the strided sharding part has ended. This test - # check that ``compute_local_shape_and_global_offset`` does not allow placements - # that violate the assumption and does not forbid the allowed ones. - - # Test 0: 2-D mesh - mesh_size_0 = 2 - mesh_size_1 = self.world_size // mesh_size_0 - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1), - mesh_dim_names=("mesh-0", "mesh-1"), - ) - global_tensor_shape = torch.Size([2 * self.world_size, 2 * self.world_size]) - - for shard_dim in [0, 1]: - placements = [ - _StridedShard(shard_dim, split_factor=mesh_size_1), - Shard(shard_dim), - ] - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # Test 1: 3-D mesh - mesh_size_0 = 2 - mesh_size_1 = 2 - mesh_size_2 = self.world_size // (mesh_size_0 * mesh_size_1) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2"), - ) - - # legal placements: Shard() appear after the strided part but it's on another - # tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(1), - ] - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # illegal placements: Shard() appear after the strided part and it's on the - # same tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - Shard(0), - Shard(0), - ] - with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - - # Test 2: 4-D mesh - mesh_size_0 = 1 - mesh_size_1 = 2 - mesh_size_2 = 2 - mesh_size_3 = self.world_size // (mesh_size_0 * mesh_size_1 * mesh_size_2) - global_mesh = init_device_mesh( - self.device_type, - (mesh_size_0, mesh_size_1, mesh_size_2, mesh_size_3), - mesh_dim_names=("mesh-0", "mesh-1", "mesh-2", "mesh-3"), - ) - # legal placements: Shard() appear after the strided part but it's on another - # tensor dimension. - placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(1), - ] - local_shape, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - expected_local_shape = ( - 2 * mesh_size_1 * mesh_size_3, - 2 * mesh_size_0 * mesh_size_2, - ) - self.assertEqual(local_shape, expected_local_shape) - - # illegal placements: Shard() appear after the strided part and it's on the - # same tensor dimension. 
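# Editor's note (illustrative sketch, not part of the patch): the expected shard sizes
# in the uneven Shard(0)/_StridedShard(0) cases added earlier in this file appear to
# follow torch.chunk-style splitting, i.e. chunks of ceil(N / W) with the remainder
# (and any empty shards) pushed to the last ranks. A tiny standalone check of that
# assumption; `chunk_sizes` is a hypothetical helper that exists only in this note.
def chunk_sizes(dim_size: int, num_ranks: int) -> list[int]:
    full = -(-dim_size // num_ranks)  # ceil division
    sizes, remaining = [], dim_size
    for _ in range(num_ranks):
        take = min(full, max(remaining, 0))
        sizes.append(take)
        remaining -= take
    return sizes

assert chunk_sizes(18, 4) == [5, 5, 5, 3]  # DP split in the (18, 2) cases above
assert chunk_sizes(13, 4) == [4, 4, 4, 1]  # DP split in the (13, 2) "with empty" case
assert chunk_sizes(5, 2) == [3, 2]         # TP split of a DP chunk of size 5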
- placements = [ - _StridedShard(0, split_factor=mesh_size_1), - _StridedShard(1, split_factor=mesh_size_3), - Shard(0), - Shard(0), - ] - with self.assertRaisesRegex(NotImplementedError, "the strided part has ended"): - _, _ = compute_local_shape_and_global_offset( - global_tensor_shape, global_mesh, placements - ) - class UtilSingleDeviceTest(TestCase): def test_compute_global_tensor_info_unsupported_placement(self): diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 426f77e379f8f..0e76da0dbe9c0 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -30,7 +30,7 @@ from torch.testing._internal.inductor_utils import HAS_GPU -def estimate_aten_runtime(fx_node, compute_multiplier=1.0): +def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0): # for tests, assume a matmul can hide a single collective if "c10" in str(fx_node.target): return 1.0 @@ -444,6 +444,62 @@ def func(a): self.assertTrue(same(out, correct)) self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0) + @torch._inductor.config.patch(get_patches()) + def test_custom_estimator_for_non_compute_nodes(self): + """Test that non-compute nodes with custom runtime estimates can trigger collective prefetching.""" + + def custom_estimator_with_relu(fx_node, override_size=None): + """Custom estimator that provides runtime for relu.""" + # Collective ops + if "c10" in str(fx_node.target): + return 1.0 + # Non-compute ops that we want to overlap + elif fx_node.target == aten.relu.default: + return 1.0 # relu has same time as collective + else: + return None + + def func(a, b): + c = torch.relu(a) + d = torch.mm(c, c) + + # Collective that is independent and should be prefetched during relu + ar = _functional_collectives.all_reduce(b, "sum", "0") + + # Use both results + return d * ar + + patches = { + **get_patches(), + "aten_distributed_optimizations.custom_runtime_estimation": custom_estimator_with_relu, + } + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + inputs_a = ( + torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank + ) + inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2 + + with torch._inductor.config.patch(patches): + out, aten_graph_str = run_and_get_aten_graph( + torch.compile(func), inputs_a, inputs_b + ) + + # Verify that all_reduce is prefetched to run concurrently with relu + # The collective should start before relu completes to enable perfect overlap + FileCheck().check("all_reduce").check("relu").check("wait_tensor").run( + aten_graph_str + ) + + correct = func(inputs_a, inputs_b) + self.assertTrue(same(out, correct)) + self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0) + def get_bucket_patches(compute_multiplier=1.0): estimate_aten_runtime_part = functools.partial( @@ -1112,7 +1168,7 @@ def test_multiple_hiding_nodes_bucketing(self): # Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden def estimate_with_half_compute(fx_node, override_size=None): - return estimate_aten_runtime(fx_node, compute_multiplier=0.5) + return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5) def func(a, b, *, ranks): # Two all_gathers that will be hidden by multiple compute operations @@ -1162,6 +1218,56 @@ def func(a, b, *, ranks): correct = func(a, b, 
ranks=ranks) self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch(get_bucket_patches()) + def test_bucketing_with_convert_dtype(self): + """Test that all_gathers with dtype conversion get bucketed and produce correct results.""" + + def func(a, b, c, d, *, ranks): + # Convert inputs to float16 before all_gather + a_fp16 = a.to(torch.float16) + b_fp16 = b.to(torch.float16) + + # Two all_gathers with converted dtypes + ag1 = _functional_collectives.all_gather_tensor(a_fp16, 0, ranks) + ag2 = _functional_collectives.all_gather_tensor(b_fp16, 0, ranks) + + # same dtype + ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + ag4 = _functional_collectives.all_gather_tensor(d, 0, ranks) + + return ag1, ag2, ag3, ag4 + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + a = torch.ones(4, 4, dtype=torch.float32, device=device_type) + b = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 2 + c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3 + d = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 4 + ranks = list(range(self.world_size)) + + func_c = functools.partial(func, ranks=ranks) + compiled = torch.compile(func_c) + out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d) + + # Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together) + FileCheck().check_count( + "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True + ).run(aten_graph_str) + + # Verify convert_element_type ops are removed (dtype conversion handled by _pre_bucket_all_gather) + FileCheck().check_not("torch.ops.prims.convert_element_type").run( + aten_graph_str + ) + + # Verify correctness - this tests that dtype conversion is handled correctly + correct = func(a, b, c, d, ranks=ranks) + self.assertTrue(same(out, correct)) + def get_toy_model(device_type: str): """ diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 2a1cb2b5580cb..0d11725829d26 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -2073,16 +2073,16 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args): # ensure supported devices (cpu, cuda) succeeds during dispatch call tensor = torch.zeros(2, 2, device=torch.device(device)) # multi tensor collectives - if collective == dist.barrier: + if collective is dist.barrier: collective() elif collective in (dist.all_gather, dist.gather): collective([tensor], tensor, *args) - elif collective == dist.scatter: + elif collective is dist.scatter: collective(tensor, [tensor], *args) elif collective in (dist.reduce_scatter, dist.all_to_all): # gloo does not support reduce_scatter or all_to_all if backend != "gloo": - if collective == dist.reduce_scatter: + if collective is dist.reduce_scatter: collective(tensor, [tensor], *args) else: collective([tensor], [tensor], *args) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 512808757c40c..5b1b6c8925806 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -92,6 +92,9 @@ torch.version.cuda is not None or torch.version.hip is not None ) +CUDA_12_AND_ABOVE = torch.cuda.is_available() and ( + torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12 +) _start_time = time.time() _logger = logging.getLogger(__name__) @@ -345,7 
+348,11 @@ def setUp(self): # These tests are expected to throw SIGABRT(6); # But if we are in Sandcastle, `skip_but_pass_in_sandcastle` would return 0. - TEST_NAN_ASSERT_RETURN = 0 if IS_SANDCASTLE else signal.SIGABRT + TEST_NAN_ASSERT_RETURN = ( + 0 + if (IS_SANDCASTLE and not (TEST_MULTIGPU and CUDA_12_AND_ABOVE)) + else signal.SIGABRT + ) self.special_return_code_checks = { self.test_nan_assert_float16.__wrapped__: TEST_NAN_ASSERT_RETURN, self.test_nan_assert_float32.__wrapped__: TEST_NAN_ASSERT_RETURN, @@ -537,10 +544,6 @@ def init_collective_task(t): # reset ENV os.environ["TORCH_NCCL_CUDA_EVENT_CACHE"] = "0" - CUDA_12_AND_ABOVE = torch.cuda.is_available() and ( - torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12 - ) - @requires_nccl() @skip_but_pass_in_sandcastle_if( # skip for cu126 as well due to https://github.com/pytorch/pytorch/issues/153479 diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py new file mode 100644 index 0000000000000..e1612d7639a13 --- /dev/null +++ b/test/distributed/test_debug.py @@ -0,0 +1,57 @@ +# Owner(s): ["oncall: distributed"] + +import os + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +import torch +import torch.distributed as dist +from torch.distributed.debug import start_debug_server, stop_debug_server +from torch.testing._internal.common_utils import run_tests, TestCase + + +session = requests.Session() +retry_strategy = Retry(total=5, backoff_factor=0.5) +adapter = HTTPAdapter(max_retries=retry_strategy) +session.mount("http://", adapter) +session.mount("https://", adapter) + + +class TestDebug(TestCase): + def test_basics(self) -> None: + store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(store.port) + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + port = 25999 + + def fetch(path: str) -> str: + resp = session.get(f"http://localhost:{port}{path}") + resp.raise_for_status() + return resp.text + + start_debug_server(port=port) + + self.assertIn("torch profiler", fetch("/")) + self.assertIn("View 0", fetch("/profile?duration=0.01")) + self.assertIn("test_basics", fetch("/stacks")) + self.assertIn("pg_status", fetch("/fr_trace")) + self.assertIn("Rank 0", fetch("/wait_counters")) + + if torch.cuda.is_available(): + self.assertIn("pg_status", fetch("/fr_trace_nccl")) + + # test errors + resp = session.get(f"http://localhost:{port}/blah") + self.assertEqual(resp.status_code, 404) + self.assertIn("Handler not found: /blah", resp.text) + + stop_debug_server() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index fa081243c2816..a4773a3f8da72 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -373,6 +373,117 @@ def test_all_gather_collective(self): self.assertEqual(tensor_list[1], different_tensors[1]) self.assertEqual(tensor_list[2], different_tensors[2]) + def test_reduce_scatter_tensor_collective(self): + """Test that reduce_scatter_tensor collective operation works correctly with LocalTensor.""" + # Create different tensors for each rank + different_tensors = { + 0: torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + 1: torch.tensor([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]), + 2: torch.tensor([[100.0, 200.0], [300.0, 400.0], [500.0, 600.0]]), + } + + fake_pg = 
torch.distributed.distributed_c10d._get_default_group() + + # Test reduce_scatter_tensor + with LocalTensorMode(self.world_size): + lt_reduce_scatter = LocalTensor(different_tensors) + lt_reduce_scatter_size = lt_reduce_scatter.size() + lt_output_tensor = torch.zeros( + lt_reduce_scatter_size[0] // fake_pg.size(), + *lt_reduce_scatter_size[1:], + dtype=lt_reduce_scatter.dtype, + device=lt_reduce_scatter.device, + ) + + dist.reduce_scatter_tensor( + lt_output_tensor, lt_reduce_scatter, group=fake_pg + ) + + expected_output = LocalTensor( + { + 0: torch.tensor([[111.0, 222.0]]), + 1: torch.tensor([[333.0, 444.0]]), + 2: torch.tensor([[555.0, 666.0]]), + } + ) + print(lt_output_tensor) + self.assertEqual(lt_output_tensor, expected_output) + + def test_all_gather_into_tensor_collective(self): + """Test that all_gather_into_tensor collective operation works correctly with LocalTensor.""" + # Create different tensors for each rank + different_tensors = { + 0: torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + 1: torch.tensor([[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]), + 2: torch.tensor([[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]]), + } + + fake_pg = torch.distributed.distributed_c10d._get_default_group() + + # Test all_gather_into_tensor + with LocalTensorMode(self.world_size): + lt_gather = LocalTensor(different_tensors) + lt_gather_size = lt_gather.size() + lt_output_tensor = torch.zeros( + lt_gather_size[0] * fake_pg.size(), + *lt_gather_size[1:], + dtype=lt_gather.dtype, + device=lt_gather.device, + ) + + dist.all_gather_into_tensor(lt_output_tensor, lt_gather, group=fake_pg) + + expected_output = torch.cat(list(different_tensors.values())) + + self.assertEqual(lt_output_tensor, expected_output) + + def test_all_to_all_single_collective(self): + """Test that all_to_all_single collective operation works correctly with LocalTensor.""" + from torch.distributed._functional_collectives import all_to_all_single + + # Create different tensors for each rank + # Each rank will split its tensor and send parts to other ranks + different_tensors = { + 0: torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + ), # rank 0 sends [0,0], [0,0], [0,0] to ranks 0,1,2 + 1: torch.tensor( + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ), # rank 1 sends [1,1], [1,1], [1,1] to ranks 0,1,2 + 2: torch.tensor( + [2.0, 2.0, 2.0, 2.0, 2.0, 2.0] + ), # rank 2 sends [2,2], [2,2], [2,2] to ranks 0,1,2 + } + + # Each rank splits its input into 3 parts of size 2 each + input_split_sizes = [2, 2, 2] + # Each rank receives 3 parts of size 2 each from all ranks + output_split_sizes = [2, 2, 2] + + with LocalTensorMode(self.world_size): + lt_input = LocalTensor(different_tensors) + + # Test all_to_all_single using functional collectives API + result = all_to_all_single( + lt_input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=torch.distributed.distributed_c10d._get_default_group(), + ) + + result = result.wait() + # Verify result is a LocalTensor + self.assertIsInstance(result, LocalTensor) + + # After all_to_all_single: + # rank 0 receives: [0,0] from rank 0, [1,1] from rank 1, [2,2] from rank 2 = [0,0,1,1,2,2] + # rank 1 receives: [0,0] from rank 0, [1,1] from rank 1, [2,2] from rank 2 = [0,0,1,1,2,2] + # rank 2 receives: [0,0] from rank 0, [1,1] from rank 1, [2,2] from rank 2 = [0,0,1,1,2,2] + expected_output = torch.tensor([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]) + + for rank in different_tensors: + self.assertEqual(result._local_tensors[rank], expected_output) + class 
TestLocalTensorWorld4(LocalTensorTestBase): world_size = 4 diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 3fec9a01f049c..ad30a7df5d43a 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -12,7 +12,6 @@ import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem from torch._inductor.runtime.triton_compat import triton from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem -from torch.testing._internal.common_cuda import SM100OrLater from torch.testing._internal.common_distributed import MultiProcContinuousTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -265,10 +264,6 @@ def my_reduce_kernel( nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation) -@skip_but_pass_in_sandcastle_if( - SM100OrLater, - "Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897", -) @instantiate_parametrized_tests class NVSHMEMTritonTest(MultiProcContinuousTest): def _init_device(self) -> None: diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index de6f2ba612977..c0c4c31cc1a81 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -667,6 +667,94 @@ def func(a, b): str(traced.graph) ) + def test_can_bucket_with_convert_dtype_as_hiding_nodes(self): + """ + Test that all_gathers can bucket when convert_element_type ops ARE the hiding nodes. + + Graph structure: + ag1_start -> convert1 (hides ag1) -> ag1_wait -> ag2_start -> convert2 (hides ag2) -> ag2_wait + + The convert_element_type ops ARE hiding nodes - no matmuls. + This tests that dependencies are transferred correctly when convert nodes are erased. 
+ """ + + def func(a, b, c): + group_name = "0" + group_size = 1 + + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, group_name + ) + b = torch.ops.prims.convert_element_type.default(b, torch.float16) + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, group_size, group_name + ) + ag3 = torch.ops._c10d_functional.all_gather_into_tensor( + c, group_size, group_name + ) + + mm = ag1_out @ ag1_out + + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + ag3_out = torch.ops._c10d_functional.wait_tensor(ag3) + + return ag1_out, ag2_out, ag3_out, mm + + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device, dtype=torch.float32) + b = torch.ones(4, 4, device=self.device, dtype=torch.float32) + c = torch.ones(4, 4, device=self.device, dtype=torch.float32) + + traced = make_fx(func)(a, b, c) + + # Find nodes + ag1, ag2, ag3 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + convert1 = traced.graph.find_nodes( + op="call_function", + target=torch.ops.prims.convert_element_type.default, + )[0] + mm = traced.graph.find_nodes( + op="call_function", + target=torch.ops.aten.mm.default, + )[0] + + hiding_annotations = { + ag1: convert1, + ag2: mm, + ag3: mm, + } + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + bucketer.bucket_collectives() + + graph_str = str(traced.graph) + + f = FileCheck() + f.check_count("%all_gather_into_tensor", 1, exactly=True) + f.check("pre_bucket_all_gather").check("wait_tensor").check( + "%all_gather_into_tensor_out" + ).run(graph_str) + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 0d32a9e4917f5..768555efd1d4c 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -15,7 +15,7 @@ import torch.distributed as dist import torch.nn as nn import torch.utils.checkpoint -from functorch.compile import min_cut_rematerialization_partition +from functorch.compile import default_partition, min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -24,7 +24,7 @@ ) from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu +from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -281,7 +281,14 @@ def runtime_wrapper(*runtime_args): run(export_compiler) - def test_tags_function(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function(self, device, partition_fn): def gn(x, y): return 
torch.sigmoid(torch.matmul(x, y)) @@ -297,11 +304,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_via_global_checkpoint(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_via_global_checkpoint(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -316,17 +334,28 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_with_kwargs(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_with_kwargs(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): return torch.utils.checkpoint.checkpoint( - gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False + gn, torch.sin(x), y, use_reentrant=False ) x = torch.randn(4, 4, device=device, requires_grad=True) @@ -336,11 +365,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_sequential_layers(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_sequential_layers(self, device, partition_fn): def gn(x): x = x.cos() for _ in range(3): @@ -361,11 +401,22 @@ def fn(x): freqs=[2, 18], ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_multiple_checkpoints(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_multiple_checkpoints(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -383,11 +434,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=6, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_module(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_module(self, device, partition_fn): class 
MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -411,11 +473,22 @@ def fn(x): bw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.aten.sigmoid.default ) - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_decomps(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_decomps(self, device, partition_fn): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -443,6 +516,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, + partition_fn=partition_fn, decompositions=lambda: import_module( "torch._inductor.compile_fx" ).select_decomp_table(), @@ -702,7 +776,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_recompute(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn): def context_fn_must_recompute_mm(): must_recompute_list = [ torch.ops.aten.mm.default, @@ -723,9 +804,9 @@ def context_fn_no_recompute_mm(): ), ) - def _test(context_fn, bw_compiler): + def _test(context_fn, bw_compiler, partition_fn): def gn(x): - return torch.sigmoid(torch.matmul(x, x)) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)) def fn(x): return torch.utils.checkpoint.checkpoint( @@ -739,14 +820,14 @@ def fn(x): fw_compiler = functools.partial( count_ops, - freq=1, + freq=2, op=torch.ops.aten.mm.default, ) backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x) @@ -754,17 +835,19 @@ def fn(x): context_fn=context_fn_must_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3) + freq=6, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 2 + 2 * 2 = 6) op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) _test( context_fn=context_fn_no_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=2, # 2 bwd mm ops per fwd matmul + freq=4, # 2 bwd mm ops per fwd matmul op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) def test_sac_with_partial_context_fn(self): @@ -801,7 +884,16 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_not_recompute_gemm( + self, device, partition_fn + ): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -841,15 +933,22 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton 
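# Editor's note (illustrative sketch, not part of the patch): the recurring edit in this
# file adds a `partition_fn` parameter to each checkpointing test and threads it into the
# aot_autograd backend, so every case now runs under both min_cut_rematerialization_partition
# and default_partition. A minimal standalone version of that pattern, with no-op compilers
# standing in for the count_ops assertions the real tests use:
import torch
import torch._dynamo
import torch.utils.checkpoint
from functorch.compile import default_partition, min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

def gn(t):
    return torch.sigmoid(torch.matmul(t, t))

def fn(x):
    return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=False)

x = torch.randn(4, 4, requires_grad=True)
for partition_fn in (min_cut_rematerialization_partition, default_partition):
    torch._dynamo.reset()  # clear caches before switching backends
    backend = aot_autograd(
        fw_compiler=lambda gm, example_inputs: gm,  # the real tests count ops here
        bw_compiler=lambda gm, example_inputs: gm,
        partition_fn=partition_fn,
    )
    torch.compile(fn, backend=backend)(x).sum().backward()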
@unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( - self, device + self, device, partition_fn ): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -889,7 +988,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, disable_functionalization=True, ) self._validate(fn, backend, x, y) @@ -897,7 +996,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_triton_kernel(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn): # Copy of the above test, but make sure that having a triton kernel in the # region does not error. def add_one(x): @@ -957,14 +1063,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_tensor_subclass(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1007,14 +1120,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_custom_rule(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn): def _get_custom_policy(meta): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1072,14 +1192,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_partial_ctx_fn(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn): def selective_checkpointing_context_fn(no_recompute_list): return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) @@ -1118,14 +1245,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, 
+ partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_outplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1163,14 +1297,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_list_ops(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_list_ops(self, device, partition_fn): def selective_checkpointing_context_fn(): # recompute everything no_recompute_list = [] @@ -1206,7 +1347,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1217,7 +1358,14 @@ def fn(x, y): "requires TorchDispatchMode + torch.compile work to complete" ) @requires_cuda_and_triton - def test_compile_selective_checkpoint_inplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1257,7 +1405,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1265,7 +1413,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) - def test_compile_selective_checkpoint_random_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_random_op(self, device, partition_fn): for preserve_rng_state in [True, False]: def selective_checkpointing_context_fn(): @@ -1312,7 +1467,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, @@ -1324,7 +1479,14 @@ def fn(x): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_invalid_context(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_invalid_context(self, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, 
y)) * y @@ -1353,7 +1515,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) with self.assertRaisesRegex( Exception, "must generate a tuple of two `TorchDispatchMode`s" @@ -1362,7 +1524,14 @@ def fn(x, y): @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) - def test_compile_selective_checkpoint_parametrization(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_parametrization(self, partition_fn): def sac_policy(): def _recomp_policy(): def _custom_policy(ctx, func, *args, **kwargs): @@ -1425,7 +1594,9 @@ def reset_parameters(self): bw_compiler = functools.partial( count_ops, freqs=[ - 2, # 1 from mul recompute, 1 from mul backward + # 1 from mul recompute, 1 from mul backward + # w/o CSE, we have one extra mul + 3 if partition_fn is default_partition else 2, 1, ], ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], @@ -1434,7 +1605,7 @@ def reset_parameters(self): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) model = MLPModule() diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 326a1e627b3f4..f2a99dd18e2b1 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -2,6 +2,7 @@ # flake8: noqa: B950 import copy import math +import unittest from dataclasses import dataclass import torch @@ -1543,6 +1544,43 @@ def f(x, y): loss.backward() self.assertEqual(x + y, z) + @unittest.expectedFailure + def test_nonlocal_list_mutation_in_autograd_function(self): + """Test that nonlocal list mutation in autograd.Function forward is handled correctly.""" + + class SimpleAutogradFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x, z): + # Simple computation + o = torch.matmul(x, x) @ x + out = x.sin() + # Mutate the nonlocal list + z.append(out) + return torch.cos(torch.sin(o)), torch.sin(x) + + @staticmethod + def backward(ctx, grad_output1, grad_output2): + # Simple backward + return grad_output1 + grad_output2, None + + def fn(x): + z = [] + + outs = SimpleAutogradFunc.apply(x, z) + out1 = outs[0] + # Check that the extra output pytree handling is done properly + out2 = outs[-1] + + return out1 + out2, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index c706e5f7af025..49f787bd25cd6 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -1811,6 +1811,151 @@ def f3(x): """, ) + @make_logging_test(graph_breaks=True) + def test_try_block_with_graph_break_suppression(self, records): + global inner, middle_with_try, outer + + def inner(x): + result = x + 1 + torch._dynamo.graph_break() + return result + 1 + + def middle_with_try(x): + try: + return inner(x) + except Exception: + pass + return x + + def outer(x): + return middle_with_try(x) + + with torch._dynamo.config.patch(nested_graph_breaks=True, verbose=False): + torch.compile(outer, 
backend="eager")(torch.ones(3)) + + full_messages = [ + r for r in records if "Graph break in user code" in r.getMessage() + ] + suppressed_messages = [ + r + for r in records + if "user stack suppressed due to duplicate" in r.getMessage() + ] + + self.assertEqual( + len(full_messages), + 1, + f"Expected 1 full graph break message, got {len(full_messages)}", + ) + self.assertEqual( + len(suppressed_messages), + 1, + f"Expected at least 1 suppressed message, got {len(suppressed_messages)}", + ) + + self.assertExpectedInline( + munge_exc(full_messages[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_try_block_with_graph_break_suppression + torch.compile(outer, backend="eager")(torch.ones(3)) + File "test_error_messages.py", line N, in outer + return middle_with_try(x) + File "test_error_messages.py", line N, in middle_with_try + return inner(x) + File "test_error_messages.py", line N, in inner + torch._dynamo.graph_break() +""", + ) + + self.assertExpectedInline( + munge_exc( + suppressed_messages[0].getMessage(), suppress_suffix=True, skip=0 + ), + """\ +Graph break (user stack suppressed due to duplicate graph break) in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html""", + ) + + @make_logging_test(graph_breaks=True) + def test_nested_graph_break_different_call_sites_not_suppressed(self, records): + global inner, outer + + def inner(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + @torch.compile(backend="eager") + def outer(x): + x = inner(x + 4) + 8 + return inner(x) + 16 + + with torch._dynamo.config.patch(nested_graph_breaks=True, verbose=False): + outer(torch.ones(3)) + + self.assertEqual( + len(records), + 2, + f"Expected 2 graph break messages (one per call site), got {len(records)}", + ) + + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. 
+ + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_nested_graph_break_different_call_sites_not_suppressed + outer(torch.ones(3)) + File "test_error_messages.py", line N, in outer + x = inner(x + 4) + 8 + File "test_error_messages.py", line N, in inner + torch._dynamo.graph_break() +""", + ) + + self.assertExpectedInline( + munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_nested_graph_break_different_call_sites_not_suppressed + outer(torch.ones(3)) + File "test_error_messages.py", line N, in outer + return inner(x) + 16 + File "test_error_messages.py", line N, in inner + torch._dynamo.graph_break() +""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index bac435cebfdfc..840d4b32ab389 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -153,7 +153,7 @@ def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2): return torch.cos(x * y) + c -class FunctionTests(torch._dynamo.test_case.TestCase): +class FunctionTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): @make_test def test_inline_jit_annotations(x): x = inline_script_if_tracing(x) @@ -4221,7 +4221,7 @@ def forward(self): return self.m() -class DefaultsTests(torch._dynamo.test_case.TestCase): +class DefaultsTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): def test_func_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg @@ -4749,7 +4749,7 @@ def fn(x, ys, zs): x = x.clone() for y, z in zip(ys, zs, strict=True): x += y * z - return x + return x, zip(ys, zs) opt_fn = torch.compile(fn, backend="eager") nopython_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -4758,7 +4758,11 @@ def fn(x, ys, zs): ys = [1.0, 2.0, 3.0] zs = [2.0, 5.0, 8.0] - self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs)) + ref = fn(x, ys, zs) + res = opt_fn(x, ys, zs) + self.assertEqual(ref[0], res[0]) + self.assertEqual(list(ref[1]), list(res[1])) + self.assertIsInstance(res[1], zip) # If nopython, should raise UserError with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index fc9284a3c9542..d0e712ffaa6cf 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -1154,7 +1154,7 @@ def install_subgraph(self, name, subgraph): splits = [ n for n in graph.nodes - if n.op == "call_function" and n.target == torch.split + if n.op == "call_function" and n.target is torch.split ] for split in splits: tracker.node_to_duplicates.pop(split) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 
4e2a292fc69d4..21398490e7b03 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -4304,15 +4304,15 @@ def forward(self, L_x_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None + child: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(_wrap_for_grad); child = None set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - sin: "f32[5]" = child.sin(); child = None + sin: "f32[5]" = _wrap_for_grad.sin(); _wrap_for_grad = None primals_out: "f32[]" = sin.sum(); sin = None results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None @@ -4352,24 +4352,24 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + child_2: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(_wrap_for_grad) set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - child_1: "f32[5]" = child.sin() - child_2: "f32[5]" = child.cos(); child = None + child: "f32[5]" = _wrap_for_grad.sin() + child_1: "f32[5]" = _wrap_for_grad.cos(); _wrap_for_grad = None - _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) - _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) + _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1) + _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, l_v_], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child, child_1], [child_2], [l_v_, l_v_], retain_graph = True, create_graph = True); 
child = child_1 = child_2 = l_v_ = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) """, @@ -4404,28 +4404,28 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + _wrap_for_grad: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_3: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child) + child_2: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(_wrap_for_grad) set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - child_1: "f32[5]" = child.sin() - child_2: "f32[5]" = child.cos(); child = None + child: "f32[5]" = _wrap_for_grad.sin() + child_1: "f32[5]" = _wrap_for_grad.cos(); _wrap_for_grad = None - value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) - value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) + _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1) + _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None - child_4: "f32[5]" = l_v_.sin() + child_3: "f32[5]" = l_v_.sin() - _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None + _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child, child_1], [child_2], [l_v_, child_3], retain_graph = True, create_graph = True); child = child_1 = child_2 = l_v_ = child_3 = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None - return (value, value_1, getitem) + return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) """, ) @@ -4458,18 +4458,18 @@ def forward(self, L_x_: "f32[5]"): _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. 
Please open an issue with your use case."); _saved_tensors_hooks_disable = None _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None - child: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + aux: "f32[5]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None - child_1: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(child); child_1 = None + child: "f32[5]" = torch._functorch.eager_transforms._set_tensor_requires_grad(aux); child = None set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None - sin: "f32[5]" = child.sin() + sin: "f32[5]" = aux.sin() primals_out: "f32[]" = sin.sum(); sin = None - aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None + aux_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = aux_1 = None results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -5014,11 +5014,11 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): aux: "f32[3, 3, 3]" = child.cos() _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None - child_2: "f32[3, 3, 3]" = _autograd_grad[0] - child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None + getitem: "f32[3, 3, 3]" = _autograd_grad[0] + getitem_1: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None - _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None - _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None + _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem, 1); getitem = None + _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem_1, 1); getitem_1 = None output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None @@ -5058,11 +5058,11 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): aux: "f32[3, 3, 3]" = child.cos() _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child, child_1], create_graph = True); child = child_1 = None - child_2: "f32[3, 3, 3]" = _autograd_grad[0] - child_3: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None + getitem: "f32[3, 3, 3]" = _autograd_grad[0] + getitem_1: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None - _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None - _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None + _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem, 1); getitem = None + _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem_1, 1); getitem_1 = None output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 
6348ba5638e05..781e95e0c7c95 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9650,6 +9650,36 @@ def fn(): self.assertEqual(fn_out, compiled_out) self.assertFalse(fn_out) + def test_constant_hasattr_returns_bool(self): + """Test that hasattr on constant values properly returns boolean ConstantVariable.""" + + # Test various constant types + def fn(): + # String constant + s = "hello" + result1 = hasattr(s, "upper") # True + result2 = hasattr(s, "nonexistent") # False + + # Integer constant + i = 42 + result3 = hasattr(i, "bit_length") # True + result4 = hasattr(i, "fake_method") # False + + # Float constant + f = 3.14 + result5 = hasattr(f, "is_integer") # True + result6 = hasattr(f, "missing_attr") # False + + # Use all results to ensure they're compiled + return (result1, result2, result3, result4, result5, result6) + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + fn_out = fn() + compiled_out = compiled_fn() + self.assertEqual(fn_out, compiled_out) + self.assertEqual(fn_out, (True, False, True, False, True, False)) + def test_torch_objects_as_keys(self): remap = {torch.float16: torch.float32} diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 10342f56d55d1..8eefbefe9237f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -968,6 +968,15 @@ class LRUCacheWarningTests(LoggingTestCase): @requires_cuda @make_logging_test(dynamo=logging.DEBUG) def test_lru_cache_warning_issued_during_tracing(self, records): + prev_default = torch._C._get_default_device() + + def _restore_default_device(): + if prev_default == "cpu": + torch.set_default_device(None) + else: + torch.set_default_device(prev_default) + + self.addCleanup(_restore_default_device) torch.set_default_device("cuda") @torch.compile(backend="eager") @@ -8184,6 +8193,130 @@ def fn(x): self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1) + def test_pytree_get_node_type_not_traced(self): + # Test that torch.utils._pytree._get_node_type is not traced into + # and doesn't cause excessive trace time overhead + from torch.utils._pytree import _get_node_type + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + # Call _get_node_type which is used internally by pytree operations + node_type = _get_node_type([x, y]) + assert node_type is list + # Do some work with pytree structures + data = {"a": x, "b": y} + flat, spec = pytree.tree_flatten(data) + result = flat[0] + flat[1] + return result + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_get_node_type_with_namedtuple(self): + # Test that torch.utils._pytree._get_node_type handles namedtuples correctly + # without being traced into, even when is_namedtuple_class is True + from collections import namedtuple + + from torch.utils._pytree import _get_node_type + + Point = namedtuple("Point", ["x", "y"]) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(a, b): + # Create a namedtuple + point = Point(a, b) + # Call _get_node_type with a namedtuple instance + node_type = _get_node_type(point) + assert node_type is namedtuple + # Use pytree operations with namedtuples + flat, spec = pytree.tree_flatten(point) + result = flat[0] + flat[1] + return result + + x = torch.randn(3, 4) + y = 
torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_tree_is_leaf_not_traced(self): + # Test that torch.utils._pytree.tree_is_leaf is not traced into + # when is_leaf parameter is None (the common case) + from torch.utils._pytree import tree_is_leaf + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x, y): + # Test with various types + # Tensors are leaves + is_leaf_tensor = tree_is_leaf(x) + assert is_leaf_tensor is True + + # Lists are not leaves (they're in SUPPORTED_NODES) + is_leaf_list = tree_is_leaf([x, y]) + assert is_leaf_list is False + + # Dicts are not leaves + is_leaf_dict = tree_is_leaf({"a": x, "b": y}) + assert is_leaf_dict is False + + return x + y + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + + def test_pytree_tree_is_leaf_with_namedtuple(self): + # Test that torch.utils._pytree.tree_is_leaf handles namedtuples correctly + from collections import namedtuple + + from torch.utils._pytree import tree_is_leaf + + Point = namedtuple("Point", ["x", "y"]) + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(a, b): + # Namedtuples are not leaves (they're in SUPPORTED_NODES) + point = Point(a, b) + is_leaf_namedtuple = tree_is_leaf(point) + assert is_leaf_namedtuple is False + + # But individual tensors are leaves + is_leaf_tensor = tree_is_leaf(a) + assert is_leaf_tensor is True + + return a + b + + x = torch.randn(3, 4) + y = torch.randn(3, 4) + result = fn(x, y) + expected = x + y + + self.assertTrue(torch.allclose(result, expected)) + # Should compile successfully with fullgraph=True + self.assertEqual(cnt.frame_count, 1) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 3b4aff724eee4..7a40ae926a527 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -27,7 +27,7 @@ def remove_file_comment(gm_str: str) -> str: def print_graph(graph: torch.fx.GraphModule) -> str: - return remove_file_comment(graph.print_readable()) + return remove_file_comment(graph.print_readable(print_output=False)) class TestStreams(torch._dynamo.test_case.TestCase): @@ -585,6 +585,10 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): # Annotation: {'stream': 1} mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None + wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None + # Annotation: {'stream': 0} add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None return (add_3, add_2) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index e1e2b228062f6..33715d2cf861b 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -21,7 +21,7 @@ from torch._inductor.test_case import TestCase from torch._logging._internal import TorchLogsFormatter from torch.nn.parallel import DistributedDataParallel as DDP -from 
torch.testing._internal.common_utils import find_free_port, xfailIfS390X +from torch.testing._internal.common_utils import find_free_port from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -1017,7 +1017,6 @@ def fn(a): logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in chromium_events)) - @xfailIfS390X @requires_tlparse @torch._dynamo.config.patch("compiled_autograd", True) def test_compiled_autograd_attribution(self): diff --git a/test/dynamo/test_tree_map.py b/test/dynamo/test_tree_map.py new file mode 100644 index 0000000000000..0e18d69129d56 --- /dev/null +++ b/test/dynamo/test_tree_map.py @@ -0,0 +1,347 @@ +# Owner(s): ["module: dynamo"] + +import optree + +import torch +import torch._dynamo +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) +from torch.utils import _pytree as pytree + + +try: + import torch.utils._cxx_pytree as cxx_pytree +except ImportError: # pragma: no cover + cxx_pytree = None + + +def _tensor_leaf(*values): + first = values[0].clone() + for other in values[1:]: + first = first + other + return first + + +def _combine_leaves(*values): + first = values[0] + if isinstance(first, torch.Tensor): + return _tensor_leaf(*values) + if first is None: + return None + if isinstance(first, tuple): + # When tuples are marked as leaves, keep the structure from + # the leading tree so that specs remain aligned. + return first + total = first + for other in values[1:]: + total = total + other + return total + + +def _tuple_is_leaf(node): + return isinstance(node, tuple) + + +TREE_MAP_IMPLEMENTATIONS = [ + ("optree", optree.tree_map), + ("pytree_python", pytree.tree_map), +] +if cxx_pytree is not None: + TREE_MAP_IMPLEMENTATIONS.append(("pytree_cxx", cxx_pytree.tree_map)) + + +KWARG_CASES = [ + ("default", {}, None), + ("none_is_leaf", {"none_is_leaf": True}, {"optree"}), + ("is_leaf", {"is_leaf": _tuple_is_leaf}, None), + ("namespace", {"namespace": "torch"}, {"optree"}), + ( + "namespace_and_none_is_leaf", + {"namespace": "torch", "none_is_leaf": True}, + {"optree"}, + ), + ( + "namespace_none_is_leaf_predicate", + {"namespace": "torch", "none_is_leaf": True, "is_leaf": _tuple_is_leaf}, + {"optree"}, + ), +] + + +_NONE_IS_LEAF_UNSET = object() + + +def _build_tree(offset: int) -> dict[str, object]: + base = torch.arange(4, dtype=torch.float32).reshape(2, 2) + offset + nested = base + 5 + return { + "tensor": base, + "list": [ + base + 1, + { + "inner": base + 2, + "none": None, + }, + ], + "tuple": (3 + offset, (nested, None)), + "const_dict": {"leaf": base + 3}, + } + + +def _assert_trees_allclose(test_case: TestCase, ref, res) -> None: + ref_flat, ref_spec = pytree.tree_flatten(ref) + res_flat, res_spec = pytree.tree_flatten(res) + test_case.assertEqual(ref_spec, res_spec) + for expected, actual in zip(ref_flat, res_flat): + if isinstance(expected, torch.Tensor): + test_case.assertTrue(torch.allclose(expected, actual)) + else: + test_case.assertEqual(expected, actual) + + +@instantiate_parametrized_tests +class TreeMapCompileTests(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + def _run_tree_map(self, tree_map_impl, kwargs): + lhs = _build_tree(0) + rhs = _build_tree(7) + + def fn(a, b): + return tree_map_impl(_combine_leaves, a, b, **kwargs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(lhs, rhs) + result = compiled(lhs, rhs) + _assert_trees_allclose(self, expected, result) + + 
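A minimal standalone sketch of the pattern this harness drives: compiling a multi-tree pytree.tree_map with fullgraph=True and comparing it against eager. The hand-built trees and the eager backend below are illustrative assumptions, not the parametrized fixtures themselves.

import torch
from torch.utils import _pytree as pytree

def add_trees(a, b):
    # Multi-tree map: the callable receives one corresponding leaf from each tree.
    return pytree.tree_map(lambda u, v: u + v, a, b)

lhs = {"x": torch.ones(2), "nested": [torch.zeros(2), torch.full((2,), 3.0)]}
rhs = {"x": torch.full((2,), 2.0), "nested": [torch.ones(2), torch.ones(2)]}

compiled = torch.compile(add_trees, backend="eager", fullgraph=True)
expected = add_trees(lhs, rhs)
result = compiled(lhs, rhs)
for exp, res in zip(pytree.tree_leaves(expected), pytree.tree_leaves(result)):
    assert torch.allclose(exp, res)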
@parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize("kwargs_name,kwargs,allowed_impls", KWARG_CASES) + def test_tree_map_variants( + self, + tree_map_name: str, + tree_map_impl, + kwargs_name: str, + kwargs: dict, + allowed_impls, + ) -> None: + if tree_map_name == "pytree_cxx" and cxx_pytree is None: + self.skipTest("torch.utils._cxx_pytree is unavailable") + if allowed_impls is not None and tree_map_name not in allowed_impls: + self.skipTest("kwargs unsupported for implementation") + self._run_tree_map(tree_map_impl, kwargs) + + def test_tree_map_rejects_mismatched_container_types(self) -> None: + def fn(a, b): + return pytree.tree_map(lambda u, v: u + v, a, b) + + lhs = [torch.ones(2), torch.ones(2)] + rhs = (torch.ones(2), torch.ones(2)) + + with self.assertRaises(ValueError): + fn(lhs, rhs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + (ValueError, torch._dynamo.exc.Unsupported), + "Node type mismatch", + ): + compiled(lhs, rhs) + + def test_tree_map_is_leaf_handles_tensor_nodes(self) -> None: + def fn(tree): + return pytree.tree_map( + lambda pair: torch.stack(pair).sum(dim=0), + tree, + is_leaf=lambda node: isinstance(node, tuple), + ) + + tree = [(torch.ones(2), torch.ones(2) * 4)] + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_only_applies_to_tensor_nodes(self) -> None: + tree = {"tensor": torch.ones(2), "int": 3} + + def mapper(node): + if not isinstance(node, torch.Tensor): + raise AssertionError("mapper should only see tensors") + return node + 2 + + def fn(arg): + return pytree.tree_map_only(torch.Tensor, mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_only_multiple_trees_falls_back(self) -> None: + lhs = {"a": torch.ones(2), "b": torch.ones(2) * 2} + rhs = {"a": torch.ones(2) * 3, "b": torch.ones(2) * 4} + + def fn(a, b): + return pytree.tree_map_only(torch.Tensor, lambda x, y: x + y, a, b) + + with self.assertRaisesRegex(TypeError, "callable"): + fn(lhs, rhs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + (TypeError, torch._dynamo.exc.Unsupported), + r"(callable|Unsupported function call)", + ): + compiled(lhs, rhs) + + def test_tree_map_only_handles_multiple_types(self) -> None: + tree = {"int": 7, "tuple": (1, 2), "tensor": torch.ones(2)} + + def mapper(node): + if isinstance(node, int): + return node + 1 + if isinstance(node, tuple): + return tuple(val + 10 for val in node) + raise AssertionError("unexpected node passed to mapper") + + def fn(arg): + return pytree.tree_map_only((int, tuple), mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_is_leaf_non_constant_fallback(self) -> None: + tree = {"a": torch.arange(2.0), "b": torch.arange(2.0) + 1} + + def is_leaf(node): + if isinstance(node, torch.Tensor): + # Depends on runtime tensor value; cannot be folded to a constant. 
+ return (node.sum() > 1).item() + return False + + def mapper(node): + return node * 2 if isinstance(node, torch.Tensor) else node + + def fn(arg): + return pytree.tree_map(mapper, arg, is_leaf=is_leaf) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_only_predicate_selector_skips_fastpath(self) -> None: + tree = {"keep": torch.ones(2), "other": (1, 2)} + + def selector(node): + return isinstance(node, torch.Tensor) and node.shape == (2,) + + def mapper(node): + return node + 5 if isinstance(node, torch.Tensor) else node + + def fn(arg): + return pytree.tree_map_only(selector, mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + _assert_trees_allclose(self, expected, result) + + def test_tree_map_none_nodes_reject_mismatched_siblings(self) -> None: + def fn(a, b): + return optree.tree_map(lambda u, v: (u, v), a, b) + + lhs = {"k": None} + rhs = {"k": torch.ones(2)} + + with self.assertRaisesRegex(ValueError, "Expected None"): + fn(lhs, rhs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + (ValueError, torch._dynamo.exc.Unsupported), + r"(Expected None|expected )", + ): + compiled(lhs, rhs) + + @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + def test_tree_map_none_nodes_default_behavior( + self, tree_map_name: str, tree_map_impl + ) -> None: + if tree_map_name == "optree": + self.skipTest("optree treats None as an internal node by default") + + def fn(a, b): + return tree_map_impl(lambda u, v: (u, v), a, b) + + tree = {"k": None} + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree, tree) + result = compiled(tree, tree) + + self.assertEqual(result["k"], (None, None)) + self.assertEqual(result, expected) + + def test_constantvariable_handles_none_is_leaf_kwarg(self) -> None: + tree = {"none": None} + + def run_case(none_is_leaf_flag): + def fn(arg): + def mapper(node): + if node is None: + return "visited" + return node + + kwargs = {} + if none_is_leaf_flag is not _NONE_IS_LEAF_UNSET: + kwargs["none_is_leaf"] = none_is_leaf_flag + return optree.tree_map(mapper, arg, **kwargs) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + self.assertEqual(result, expected) + return result["none"] + + self.assertEqual(run_case(_NONE_IS_LEAF_UNSET), None) + self.assertEqual(run_case(False), None) + self.assertEqual(run_case(True), "visited") + + def test_constantvariable_handles_python_and_dtype_leaves(self) -> None: + tree = { + "int": 7, + "nested": {"string": "foo", "dtype": torch.float32}, + } + + def fn(arg): + def mapper(node): + if isinstance(node, int): + return node + 1 + if isinstance(node, str): + return node.upper() + if isinstance(node, torch.dtype): + return torch.float64 + return node + + return optree.tree_map(mapper, arg) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(tree) + result = compiled(tree) + self.assertEqual(result["int"], 8) + self.assertEqual(result["nested"]["string"], "FOO") + self.assertIs(result["nested"]["dtype"], torch.float64) + self.assertEqual(result, expected) + + +if __name__ == "__main__": # pragma: no cover + run_tests() diff --git a/test/dynamo_expected_failures/TestArrayCreationCopyArgument.test_striding_not_ok 
b/test/dynamo_expected_failures/TestArrayCreationCopyArgument.test_striding_not_ok deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 5efbb13c25fd2..abfbb7a6004df 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -5,7 +5,7 @@ import unittest import warnings from dataclasses import dataclass -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Tuple import torch import torch._dynamo @@ -17,11 +17,46 @@ from torch.export.graph_signature import OutputKind from torch.testing import FileCheck from torch.testing._internal.common_utils import TEST_CUDA +from torch.utils import _pytree as pytree GLOBAL_LIST = [] +class GlobalContext: + def __init__(self) -> None: + self._summaries: dict[str, MetricValue] = {} + self._tensors: dict[str, Tensor] = {} + + def __flatten__(self): + """Flattens into (leaves, ctx).""" + summary_leaves, summary_spec = pytree.tree_flatten(self._summaries) + tensor_leaves, tensor_spec = pytree.tree_flatten(self._tensors) + leaves = (*summary_leaves, *tensor_leaves) + ctx = (summary_spec, tensor_spec) + return leaves, ctx + + @classmethod + def __unflatten__(cls, leaves, ctx: tuple[pytree.TreeSpec, pytree.TreeSpec]): + """Reconstructs from (leaves, ctx).""" + output = cls() + summary_spec, tensor_spec = ctx + assert len(leaves) == summary_spec.num_leaves + tensor_spec.num_leaves + output._summaries = pytree.tree_unflatten( + leaves[: summary_spec.num_leaves], summary_spec + ) + output._tensors = pytree.tree_unflatten( + leaves[summary_spec.num_leaves :], tensor_spec + ) + return output + + def __enter__(self) -> "GlobalContext": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + pass + + @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") class TestExperiment(TestCase): def test_joint_basic(self) -> None: @@ -582,6 +617,33 @@ def make_inputs(b: int): self.assertIsNotNone(gm.meta["tracing_context"].fake_mode) self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1) + def test_dynamo_graph_capture_ctx_return(self): + class Module(torch.nn.Module): + def forward(self, x): + with GlobalContext() as ctx: + z = x + 1 + ctx._tensors["6"] = x + 2 + return z, ctx + + def make_inputs(): + return (torch.randn(2, 3),) + + try: + pytree.register_pytree_node( + GlobalContext, + lambda x: x.__flatten__(), + GlobalContext.__unflatten__, + ) + mod = Module() + + gm = dynamo_graph_capture_for_export(mod)(*make_inputs()) + test_inputs = make_inputs() + actual_outputs = pytree.tree_leaves(gm(*test_inputs)) + expected_outputs = pytree.tree_leaves(mod(*test_inputs)) + self.assertEqual(actual_outputs, expected_outputs) + finally: + pytree._deregister_pytree_node(GlobalContext) + def test_dynamo_graph_capture_dict_keys_getitem(self): class Module(torch.nn.Module): def forward(self, x): @@ -614,9 +676,9 @@ def forward(self, args_0): _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,)) L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1) l_args_0_ = L_args_0_ - add = l_args_0_ + 1; add = None + add = l_args_0_ + 1 mul = l_args_0_ * 2; l_args_0_ = None - return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul), self._out_spec)""", + return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""", ) self.assertEqual(gm(*test_inputs), foo(*test_inputs)) @@ -652,7 
+714,10 @@ def make_inputs(): return (torch.randn(2, 3),) trace_inputs = make_inputs() - with warnings.catch_warnings(record=True) as w: + with ( + torch._dynamo.config.patch(replay_side_effects=False), + warnings.catch_warnings(record=True) as w, + ): gm = dynamo_graph_capture_for_export(foo)(*trace_inputs) cnt = 0 for entry in w: diff --git a/test/export/test_export.py b/test/export/test_export.py index 204d458e77704..6ebed4f224643 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -968,7 +968,7 @@ def forward(self, x): view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None sdpa_score0 = self.sdpa_score0 sdpa_mask0 = self.sdpa_mask0 - flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None + flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None getitem = flex_attention[0] getitem_1 = flex_attention[1]; getitem_1 = None getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None @@ -16142,8 +16142,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # expected 3*..., but got 8 ep.module()(torch.randn(4, 2)) - @testing.expectedFailureSerDer # T195866111 - @testing.expectedFailureSerDerNonStrict @testing.expectedFailureStrictV2 def test_hints_wrapper(self): strict = True @@ -16483,7 +16481,7 @@ def forward(self, x): # Expect builtin round in the export graph round_nodes = [ - n for n in ep.graph.nodes if n.op == "call_function" and n.target == round + n for n in ep.graph.nodes if n.op == "call_function" and n.target is round ] self.assertEqual(len(round_nodes), 1) diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py index 2ffdfec1e8633..2ac21e56c5c9c 100644 --- a/test/functorch/discover_coverage.py +++ b/test/functorch/discover_coverage.py @@ -356,7 +356,7 @@ def is_decorateinfo_skip_or_xfail(decorateinfo): actual_decorator = decorateinfo.decorators[0] if isinstance(actual_decorator, toleranceOverride): return False - if actual_decorator == unittest.expectedFailure: + if actual_decorator is unittest.expectedFailure: return True # Assume the rest are skips return True diff --git a/test/functorch/test_ac_knapsack.py b/test/functorch/test_ac_knapsack.py index 751a4c4d21859..2d2899e9ca297 100644 --- a/test/functorch/test_ac_knapsack.py +++ b/test/functorch/test_ac_knapsack.py @@ -2,6 +2,10 @@ from torch._functorch._activation_checkpointing.graph_info_provider import ( GraphInfoProvider, ) +from torch._functorch._activation_checkpointing.knapsack import ( + dp_knapsack, + dp_knapsack_sliding_hirschberg, +) from torch._functorch._activation_checkpointing.knapsack_evaluator import ( KnapsackEvaluator, ) @@ -326,5 +330,82 @@ def 
test_get_backward_memory_from_topologically_sorted_graph(self): self.assertEqual(result_item[1], expected_result_item[1]) +class TestActivationCheckpointingKnapsack(TestCase): + def setUp(self): + # (memory, runtime, max_memory, expected_runtime, expected_saved, expected_recomputable) + self.test_cases = [ + ([2, 3, 2, 4, 1], [1, 2, 1, 3, 2], 5, 5.0, [3, 4], [2, 1, 0]), + ([1, 1, 1], [1, 2, 3], 3, 6.0, [0, 1, 2], []), + ([10, 20, 30], [1, 2, 3], 5, 0.0, [], [2, 1, 0]), + ([1, 2, 3], [10, 20, 30], 1, 10.0, [0], [2, 1]), + ([1, 1, 1], [2, 2, 2], 2, 4.0, [0, 1], [2]), + ([0, 2, 3], [5, 2, 3], 5, 10.0, [0, 1, 2], []), + ([1, 2, 3], [0, 2, 3], 3, 3.0, [2], [0, 1]), + ([100, 200, 300], [1000, 2000, 3000], 500, 5000.0, [1, 2], [0]), + ([0.5, 1.5, 2.0], [1.0, 2.0, 3.0], 2.0, 3.0, [1, 0], [2]), + ([], [], 10, 0.0, [], []), + ([1, 2, 3], [1, 2, 3], 0, 0.0, [], [2, 1, 0]), + ([0, 0, 0], [1, 2, 3], 0, 6.0, [0, 1, 2], []), + ([1, 2, 3], [0, 0, 0], 6, 0.0, [], [2, 1, 0]), + ] + + def _run_knapsack_and_check( + self, + func, + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ): + result_runtime, result_saved, result_recomputable = func( + memory, runtime, max_memory + ) + self.assertEqual(result_runtime, expected_runtime) + self.assertEqual(sorted(result_saved), sorted(expected_saved)) + self.assertEqual(sorted(result_recomputable), sorted(expected_recomputable)) + + def test_dp_knapsack(self): + for i, ( + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) in enumerate(self.test_cases): + with self.subTest(f"dp_knapsack_case_{i}"): + self._run_knapsack_and_check( + dp_knapsack, + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) + + def test_dp_knapsack_sliding_hirschberg(self): + for i, ( + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) in enumerate(self.test_cases): + with self.subTest(f"dp_knapsack_sliding_hirschberg_case_{i}"): + self._run_knapsack_and_check( + dp_knapsack_sliding_hirschberg, + memory, + runtime, + max_memory, + expected_runtime, + expected_saved, + expected_recomputable, + ) + + if __name__ == "__main__": run_tests() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 6cae42d8929da..c452f18e95d75 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2640,7 +2640,7 @@ def backward(ctx, grad_output): return grad_output * x, grad_output * x def f(a, b): - return FwBwMutation.apply(a, b) + return FwBwMutation.apply(a, b).sin_().clone() inps = [ torch.ones(3, 3, requires_grad=True), @@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2): add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None - return (mul, add)""", + clone = torch.ops.aten.clone.default(mul) + sin_ = torch.ops.aten.sin_.default(mul); mul = None + clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None + return (clone_1, add, clone)""", ) # important bit: there is 1 mutation in the bw self.assertExpectedInline( bw_graph[0].code.strip(), """\ -def forward(self, add, tangents_1): +def forward(self, add, clone, tangents_1): + cos = torch.ops.aten.cos.default(clone); clone = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None _foreach_mul__2 = 
torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None - mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None - return (mul_1, None)""", + mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None + return (mul_2, None)""", ) def test_fw_bw_mutation_no_functionalization2(self): diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index f83f059663149..bb228fab844fe 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4186,13 +4186,13 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs interleaved_5: "f32[3, 10, 2]" = torch.ops.aten.slice(interleaved_4, 0, 0, 3); interleaved_4 = None - child_17: "f32[3, 10, 2]" = interleaved_1.flip([0]); interleaved_1 = None - child_18: "f32[3, 10, 2]" = interleaved_3.flip([0]); interleaved_3 = None - child_19: "f32[3, 10, 2]" = interleaved_5.flip([0]); interleaved_5 = None + flip_3: "f32[3, 10, 2]" = interleaved_1.flip([0]); interleaved_1 = None + flip_4: "f32[3, 10, 2]" = interleaved_3.flip([0]); interleaved_3 = None + flip_5: "f32[3, 10, 2]" = interleaved_5.flip([0]); interleaved_5 = None - movedim_3: "f32[3, 10, 2]" = torch.movedim(child_17, 0, 0); child_17 = None - movedim_4: "f32[3, 10, 2]" = torch.movedim(child_18, 0, 0); child_18 = None - movedim_5: "f32[3, 10, 2]" = torch.movedim(child_19, 0, 0); child_19 = None + movedim_3: "f32[3, 10, 2]" = torch.movedim(flip_3, 0, 0); flip_3 = None + movedim_4: "f32[3, 10, 2]" = torch.movedim(flip_4, 0, 0); flip_4 = None + movedim_5: "f32[3, 10, 2]" = torch.movedim(flip_5, 0, 0); flip_5 = None return (movedim_3, movedim_4, movedim_5) """, # noqa: B950 ) @@ -8595,7 +8595,7 @@ def forward(self, L_t_: "f32[2, 3]"): getitem_13: "Sym(u20)" = while_loop[5] getitem_14: "Sym(u21)" = while_loop[6] - child: "f32[2, 3]" = while_loop[7]; while_loop = None + getitem_7: "f32[2, 3]" = while_loop[7]; while_loop = None add: "Sym(u15 + 1)" = getitem_8 + 1 add_1: "Sym(u16 + 1)" = getitem_9 + 1 @@ -8604,7 +8604,7 @@ def forward(self, L_t_: "f32[2, 3]"): add_4: "Sym(u19 + 1)" = getitem_12 + 1 add_5: "Sym(u20 + 1)" = getitem_13 + 1 add_6: "Sym(u21 + 1)" = getitem_14 + 1 - add_7: "f32[2, 3]" = child + 1 + add_7: "f32[2, 3]" = getitem_7 + 1 add_8: "f32[2, 3]" = getitem_8 + l_t_; getitem_8 = None add_9: "f32[2, 3]" = getitem_9 + l_t_; getitem_9 = None @@ -8613,7 +8613,7 @@ def forward(self, L_t_: "f32[2, 3]"): add_12: "f32[2, 3]" = getitem_12 + l_t_; getitem_12 = None add_13: "f32[2, 3]" = getitem_13 + l_t_; getitem_13 = None add_14: "f32[2, 3]" = getitem_14 + l_t_; getitem_14 = None - add_15: "f32[2, 3]" = child + l_t_; child = l_t_ = None + add_15: "f32[2, 3]" = getitem_7 + l_t_; getitem_7 = l_t_ = None return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15) class cond_fn_0(torch.nn.Module): diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 0a5d03f9dd1f0..37bb013e5df82 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -978,23 +978,21 @@ def foo(t): fn = foo bdim = 0 for op in reversed(op_list): - if op == vmap: + if op is vmap: fn = op(fn, in_dims=bdim) bdim += 1 else: fn = op(fn) expected = f"{repr(x)}" - level = 0 - for op in op_list: - level += 1 # noqa: SIM113 - if op == grad: - expected = f"GradTrackingTensor(lvl={level}, value={expected})" - elif op == vmap: - bdim -= 1 + for level, op in 
enumerate(op_list): + if op is grad: expected = ( - f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})" + f"GradTrackingTensor(lvl={level + 1}, value={expected})" ) + elif op is vmap: + bdim -= 1 + expected = f"BatchedTensor(lvl={level + 1}, bdim={bdim}, value={expected})" fn(x) buf = buf.replace("\n", "").replace(" ", "") diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 0f893201733d3..ac58b81350cf4 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -6139,9 +6139,9 @@ def f(x): else: input = torch.randn(5) - if transform == vjp: + if transform is vjp: transform = functools.partial(transform, f) - elif transform == jvp: + elif transform is jvp: input = (input,) transform = functools.partial(transform, f, input) else: diff --git a/test/fx/test_fx_split.py b/test/fx/test_fx_split.py index 8d2b120e534ae..ae6880ab70e27 100644 --- a/test/fx/test_fx_split.py +++ b/test/fx/test_fx_split.py @@ -84,7 +84,7 @@ def forward(self, y): # Create custom operator support to mark wrapped_add as supported class CustomOpSupport(op_support.OperatorSupportBase): def is_node_supported(self, submodules, node) -> bool: - return node.target == wrapped_add + return node.target is wrapped_add # Create a simple splitter to test the edge case class TestSplitter(splitter_base._SplitterBase): diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index 0ee60f978127d..e887f90dba227 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -782,7 +782,7 @@ def replacement(a, b, bias): found_repalcement_node = False for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul: + if node.target is wrapped_gemm_bias_mul: found_repalcement_node = True break @@ -847,7 +847,7 @@ def gemm_bias_mul_replacement_with_c(a, b, bias, c): repalcement_node_found = 0 for node in traced.graph.nodes: - if node.target == wrapped_gemm_bias_mul_with_c: + if node.target is wrapped_gemm_bias_mul_with_c: repalcement_node_found += 1 self.assertEqual(repalcement_node_found, 2) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 329f20f81cdb5..00cb0e7b8b21a 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -910,6 +910,35 @@ def forward(self, a: "f32[8]", l_y_: "f32[8]"): """, ) + @unittest.expectedFailure + def test_nonlocal_list_mutation_hidden(self): + """Test that nonlocal list mutation inside nested_compile_region is handled correctly.""" + + @nested_compile_region + def gn(x, z): + o = torch.matmul(x, x) @ x + out = x.sin() + z.append(out) + return torch.cos(torch.sin(o)), torch.sin(x) + + def fn(x): + z = [] + + outs = gn(x, z) + out1 = outs[0] + # Check that the extra output pytree handling is done properly + out2 = outs[-1] + + return out1 + out2, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + @inductor_config.patch("fx_graph_cache", False) def test_view_to_reshape(self): @nested_compile_region diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index a585f2055e89f..7b5f01d236e7f 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -911,8 +911,8 @@ def inputs_fn(): op="call_function", target=torch.ops.aten.mm.default ) 
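Several hunks above replace "==" with "is" when matching FX node targets or transform callables; identity comparison checks for the exact function object and cannot be fooled by overloaded or permissive __eq__ implementations. A short sketch of the pattern (the traced function and ops here are illustrative, not taken from the patch):

import operator
import torch
import torch.fx

def fn(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(fn)
# call_function targets are plain Python callables, so `is` is the precise test.
relu_nodes = [n for n in gm.graph.nodes if n.op == "call_function" and n.target is torch.relu]
add_nodes = [n for n in gm.graph.nodes if n.op == "call_function" and n.target is operator.add]
assert len(relu_nodes) == 1 and len(add_nodes) == 1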
self.assertEqual(len(mm_nodes), 4) - self.assertNotIn("partitioner_tag", mm_nodes[0].meta) - self.assertNotIn("partitioner_tag", mm_nodes[1].meta) + self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward") + self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward") self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0) diff --git a/test/inductor/test_analysis.py b/test/inductor/test_analysis.py index a6946cb7b31a7..147760fe4df67 100644 --- a/test/inductor/test_analysis.py +++ b/test/inductor/test_analysis.py @@ -25,6 +25,7 @@ from torch.testing._internal.common_utils import ( parametrize, run_tests, + skipIfXpu, TEST_WITH_SLOW, TestCase, ) @@ -402,6 +403,9 @@ def verify_triton(comp): (not torch.xpu.is_available()) and (not SM80OrLater), "Requires XPU or CUDA SM80", ) + @skipIfXpu( + msg="Intel triton issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/5491" + ) @skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU") @dtypes(torch.float, torch.float16) @parametrize( diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 5f0447c32264e..fd962c8bea70a 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -67,12 +67,10 @@ IS_MACOS, IS_WINDOWS, MACOS_VERSION, - MI300_ARCH, parametrize, runOnRocm, skipIfMPS, skipIfRocm, - skipIfRocmArch, skipIfWindows, skipIfWindowsXPU, skipIfXpu, @@ -175,11 +173,8 @@ def get_module_ext_type(): class AOTInductorTestsTemplate: - # Temporarily skipping test as pytorch/cpuinfo not able to retrieve cache size for - # AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM Runners @common_utils.parametrize("embed_kernel_binary", [False, True]) @common_utils.parametrize("max_autotune", [False, True]) - @skipIfRocmArch(MI300_ARCH) def test_simple(self, embed_kernel_binary, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") @@ -671,6 +666,49 @@ def forward(self, x): code ) + @requires_gpu + def test_device_moved_constant(self): + # testing both directions + device_movements = [ + (torch.device(type=GPU_TYPE, index=0), torch.device("cpu")), + (torch.device("cpu"), torch.device(type=GPU_TYPE, index=0)), + ] + + class Model(torch.nn.Module): + def __init__(self, from_device): + super().__init__() + self.register_buffer("_buf", torch.randn(6, 7, device=from_device)) + self._param = torch.nn.Parameter( + torch.rand(6, 7, device=from_device), requires_grad=False + ) + + def forward(self, x): + to_device = x.device + moved_buf = self._buf.to(to_device) + moved_param = self._param.to(to_device) + return moved_buf, moved_param + + with config.patch( + { + "aot_inductor.use_runtime_constant_folding": False, + } + ): + for from_device, to_device in device_movements: + model = Model(from_device) + example_inputs = (torch.randn(6, 7, device=to_device),) + _, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + FileCheck().check_not("torch::aot_inductor::ConstantType::Unknown").run( + code + ) + FileCheck().check_count( + "torch::aot_inductor::ConstantType::Buffer", 2, exactly=True + ).run(code) + FileCheck().check_count( + "torch::aot_inductor::ConstantType::Parameter", 2, exactly=True + ).run(code) + def test_subclasses(self): device_to_init = self.device @@ -5169,10 +5207,7 @@ def forward(self, values, offsets): 
) self.assertTrue(same(model(*example_input), actual)) - # Temporarily skipping test as pytorch/cpuinfo not able to retrieve cache size for - # AMD EPYC 9575F 64-Core Processor CPU in gfx942 VM Runners @common_utils.parametrize("max_autotune", [True, False]) - @skipIfRocmArch(MI300_ARCH) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 934f969543b2a..0203604b1ba03 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -4148,6 +4148,31 @@ def foo(x): "def triton_poi_fused_add_", 1, exactly=True ).run(code[0]) + @config.patch("graph_partition", True) + def test_graph_partition_user_defined_triton_kernel_reuse(self): + from torch.testing._internal.triton_utils import add_kernel + + def foo(x, y): + # partition 1 + output1 = torch.empty_like(x) + add_kernel[(4,)](x, y, output1, n_elements=128, BLOCK_SIZE=16) + output1_cpu = output1.cpu() + 1 + # partition 2 should reuse the user-defined kernel + x2 = output1_cpu.to("cuda") + output2 = torch.empty_like(x) + add_kernel[(4,)](x2, y, output2, n_elements=128, BLOCK_SIZE=16) + return output1, output2 + + compiled_foo = torch.compile(foo) + x = torch.randn(128, device="cuda") + y = torch.randn(128, device="cuda") + eager_out = foo(x, y) + compiled_out, code = run_and_get_code(compiled_foo, x, y) + self.assertEqual(eager_out, compiled_out) + FileCheck().check_count( + "async_compile.triton('add_kernel',", 1, exactly=True + ).run(code[0]) + def test_meta_tensor(self): def foobar(x, y): return x * 2, y * 3 diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index c148c69468902..3c50b4d881f8f 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -217,26 +217,38 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8): ) def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): - """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values.""" + """Create test inputs for decompose_k matrix multiplication. + Tensor a: Input matrix of shape (m, k) + Tensor b: Weight matrix of shape (k, n) + Tensor bias: Bias vector of shape (n,) + """ # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256 a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False) b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False) - return a, b + bias = ( + torch.randn(n, device=self.device, dtype=self.dtype, requires_grad=False) + * 0.1 + ) + return a, b, bias @skipIfXpu - def test_decompose_k_custom_op_autotune(self): - """Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale). - - Validates that the custom op encapsulates the entire fused operation with parametric - tuning for k_splits values controlling how the K dimension is decomposed. + def test_decompose_k_custom_op_autotune_dynamic_config_for_input_shape(self): + """Test decompose_k autotuning with with epilogue fusion(matmul+bias+relu+scale) and + dynamic config generation based on matmul input shapes. 
+ + Validates that the custom op encapsulates the entire fused operation (matmul + bias + + relu + scale) with parametric tuning for k_splits values controlling how the K + dimension is decomposed. The config generator receives correct parameter names and + shapes, dynamically generates different k_split configs using get_k_splits for + different input shapes, and produces correct results matching the reference implementation. """ - test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}" + test_op_name = f"test_lib::matmul_relu_epilogue_dynamic_{id(self)}" def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: - """Matrix multiply with k-way decomposition - Python implementation.""" + """Matrix multiply with k-way decomposition.""" m = a.shape[0] n = b.shape[1] k = a.shape[1] @@ -254,7 +266,7 @@ def decompose_k_implementation( return torch.sum(result, dim=0) # [m, n] @torch.library.custom_op(test_op_name, mutates_args=()) - def matmul_relu_epilogue_op( + def matmul_relu_epilogue_dynamic_op( a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: """Matmul with decompose_k + bias + relu + scale (complete epilogue fusion).""" @@ -264,23 +276,28 @@ def matmul_relu_epilogue_op( scaled = activated * 2.0 return scaled - @matmul_relu_epilogue_op.register_fake + @matmul_relu_epilogue_dynamic_op.register_fake def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype) - # Register autotuning with different k_splits values + # Define dynamic config generator using get_k_splits + def generate_k_split_configs( + fake_tensors: dict[str, torch.Tensor], + ) -> list[CustomOpConfig]: + """Generate k_split configs based on input matrix dimensions.""" + from torch._inductor.utils import get_k_splits + + m, k = fake_tensors["a"].shape[-2:] + _, n = fake_tensors["b"].shape[-2:] + + k_splits_list = get_k_splits(m, n, k) + + return [CustomOpConfig(k_splits=k) for k in k_splits_list] + register_custom_op_autotuning( - matmul_relu_epilogue_op, - configs=[ - CustomOpConfig(k_splits=2), - CustomOpConfig(k_splits=4), - CustomOpConfig(k_splits=8), - CustomOpConfig(k_splits=16), - CustomOpConfig(k_splits=32), - CustomOpConfig(k_splits=64), - CustomOpConfig(k_splits=128), - ], - name="matmul_relu_epilogue_autotuned", + matmul_relu_epilogue_dynamic_op, + config_generator=generate_k_split_configs, + name="matmul_relu_epilogue_dynamic_autotuned", input_gen_fns={ "a": lambda fake_tensor: torch.randn_like( fake_tensor, device=self.device @@ -297,38 +314,44 @@ def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): }, ) - # Create test inputs - a, b = self._create_decompose_k_inputs() - bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1 + # Test multiple shapes to verify dynamic config generation + test_shapes = [ + (256, 16384, 1024), + (256, 65536, 1024), + ] - # Compile the model using the custom op - @torch.compile - def test_model(a, b, bias): - return matmul_relu_epilogue_op(a, b, bias) + for m, k, n in test_shapes: + # Use helper function to create test inputs + a, b, bias = self._create_decompose_k_inputs(m, k, n) - torch._dynamo.reset() + @torch.compile + def test_model(a, b, bias): + return matmul_relu_epilogue_dynamic_op(a, b, bias) - with config.patch( - max_autotune=True, - benchmark_fusion=True, - ): - compiled_result = test_model(a, b, bias) + torch._dynamo.reset() - def 
reference_model(a, b, bias): - matmul_result = a @ b - biased = matmul_result + bias - activated = torch.relu(biased) - scaled = activated * 2.0 - return scaled + with config.patch( + max_autotune=True, + benchmark_fusion=True, + ): + compiled_result = test_model(a, b, bias) - expected = reference_model(a, b, bias) + def reference_model(a, b, bias): + matmul_result = a @ b + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled - torch.testing.assert_close( - compiled_result, - expected, - rtol=2e-1, - atol=5e-1, - ) + expected = reference_model(a, b, bias) + + torch.testing.assert_close( + compiled_result, + expected, + rtol=2e-1, + atol=5e-1, + msg=f"Failed for shape ({m}, {k}, {n})", + ) @skipIfXpu def test_multi_parameter_tuning(self): diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py index c26def3a54099..bd7221adc4065 100644 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -9,6 +9,7 @@ from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch from torch._inductor.test_case import run_tests, TestCase as InductorTestCase from torch._inductor.utils import ensure_cute_available +from torch.nn import functional as F from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -59,7 +60,7 @@ def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) + return F.grouped_mm(A_packed, B_batched, offs=offs) # Eager execution c_eager = grouped_gemm_fn(A, B, offsets) @@ -126,7 +127,7 @@ def test_grouped_gemm_assorted_layouts( assert B.stride(0) == 0 def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) + return F.grouped_mm(A_packed, B_batched, offs=offs) # --- eager --- c_eager = grouped_gemm_fn(A, B, offsets) diff --git a/test/inductor/test_deterministic.py b/test/inductor/test_deterministic.py index 7e79100f4c053..d7e4313f5fe3b 100644 --- a/test/inductor/test_deterministic.py +++ b/test/inductor/test_deterministic.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import contextlib import os +import pathlib import subprocess import sys import tempfile @@ -23,6 +24,9 @@ ) +REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent + + @instantiate_parametrized_tests class DeterministicTest(TestCase): def setUp(self) -> None: @@ -121,9 +125,6 @@ def test_run2run_determinism(self, model_name, training_or_inference, precision) the current working directory. 
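The decompose_k path autotuned above rests on a simple identity: split the shared K dimension into k_splits chunks, batch the partial matmuls, and sum the partial products. A hedged eager-only sketch under assumed small shapes (not the inductor implementation itself):

import torch

def decompose_k_matmul(a: torch.Tensor, b: torch.Tensor, k_splits: int) -> torch.Tensor:
    m, k = a.shape
    _, n = b.shape
    assert k % k_splits == 0
    # (k_splits, m, k // k_splits) @ (k_splits, k // k_splits, n) -> (k_splits, m, n)
    a_chunks = a.reshape(m, k_splits, k // k_splits).permute(1, 0, 2)
    b_chunks = b.reshape(k_splits, k // k_splits, n)
    partial = torch.bmm(a_chunks, b_chunks)
    return partial.sum(dim=0)  # equals a @ b up to floating-point accumulation order

a, b = torch.randn(64, 256), torch.randn(256, 32)
torch.testing.assert_close(decompose_k_matmul(a, b, k_splits=8), a @ b, rtol=1e-4, atol=1e-4)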
""" - if not os.path.exists("benchmarks/dynamo/huggingface.py"): - self.skipTest("Skip due to benchmarks/dynamo/huggingface.py not found.") - def _setup_env(env): env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" # disable autotune cache env["TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE"] = "0" @@ -137,7 +138,7 @@ def _setup_env(env): with tempfile.TemporaryDirectory() as tmpdir: saved_pkl = os.path.join(tmpdir, "saved.pkl") cmd = ( - f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor" + f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py --backend inductor" + f" --{precision} --accuracy --only {model_name} --{training_or_inference}" + f" --disable-cudagraphs --save-model-outputs-to={saved_pkl}" ) @@ -153,7 +154,7 @@ def _setup_env(env): # self.assertTrue("pass" in out.stdout.decode()) cmd = ( - f"{sys.executable} benchmarks/dynamo/huggingface.py --backend inductor" + f"{sys.executable} {REPO_ROOT}/benchmarks/dynamo/huggingface.py --backend inductor" + f" --{precision} --accuracy --only {model_name} --{training_or_inference}" + f" --disable-cudagraphs --compare-model-outputs-with={saved_pkl}" ) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 7a2f9ecdeae8b..c095243df7654 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from itertools import product from typing import Optional, TypeVar, Union -from unittest import expectedFailure, skip, skipUnless +from unittest import expectedFailure, mock, skip, skipUnless from unittest.mock import patch import torch @@ -28,6 +28,7 @@ from torch.nn.attention import SDPBackend from torch.nn.attention.experimental._paged_attention import PagedAttention from torch.nn.attention.flex_attention import ( + _apply_kernel_options, _create_empty_block_mask, _DEFAULT_SPARSE_BLOCK_SIZE, _identity, @@ -2277,6 +2278,51 @@ def test_shape(S, backend): test_shapes = [256, 255, 383, 384] _ = [test_shape(S, backend) for S in test_shapes] + @supported_platform + @skip_on_cpu + def test_mask_mod_handles_symint_addition(self, device): + dtype = torch.float16 + + def run(q, k, v): + ql = q.size(-2) + kl = k.size(-2) + frame = 32 + + def _opaque_mask(b, h, q_idx, kv_idx): + ref = ql // frame + mot = kl // frame + limit = (ref + mot) * frame + return q_idx < limit + + block_mask = create_block_mask( + _opaque_mask, + B=q.size(0), + H=q.size(1), + Q_LEN=ql, + KV_LEN=kl, + device=device, + ) + return flex_attention(q, k, v, block_mask=block_mask) + + compiled_run = torch.compile(run, fullgraph=True, dynamic=True) + + q = torch.randn(1, 2, 192, 32, device=device, dtype=dtype) + k = torch.randn(1, 2, 128, 32, device=device, dtype=dtype) + v = torch.randn(1, 2, 128, 32, device=device, dtype=dtype) + + eager_out = run(q, k, v) + compiled_out = compiled_run(q, k, v) + torch.testing.assert_close(eager_out, compiled_out, atol=1e-3, rtol=1e-3) + + # Exercise different dynamic shapes to ensure SymInt sums remain well-formed. 
+        q2 = torch.randn(1, 2, 160, 32, device=device, dtype=dtype)
+        k2 = torch.randn(1, 2, 96, 32, device=device, dtype=dtype)
+        v2 = torch.randn(1, 2, 96, 32, device=device, dtype=dtype)
+
+        eager_out2 = run(q2, k2, v2)
+        compiled_out2 = compiled_run(q2, k2, v2)
+        torch.testing.assert_close(eager_out2, compiled_out2, atol=1e-3, rtol=1e-3)
+
     @supported_platform
     def test_multiple_score_mod_calls(self, device):
         query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device=device)
@@ -3522,6 +3568,184 @@ def test_kernel_options_argument_is_respected(self, device):
         )
         FileCheck().check("BLOCK_M : tl.constexpr = 16").run(code[0])
 
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_auto_matches_triton_large(self, device):
+        """BACKEND='AUTO' should follow Triton heuristics on large shapes."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 256, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        def compile_and_run(kernel_options):
+            return run_and_get_code(
+                torch.compile(flex_attention, fullgraph=True),
+                q,
+                k,
+                v,
+                kernel_options=kernel_options,
+            )
+
+        default_out, default_code = compile_and_run({"BACKEND": "AUTO"})
+        triton_out, triton_code = compile_and_run({"BACKEND": "TRITON"})
+
+        torch.testing.assert_close(default_out, triton_out, atol=0.0, rtol=0.0)
+
+        default_src = "\n".join(default_code)
+        FileCheck().check("flex_attention").check_not("flex_decoding").run(default_src)
+
+        triton_src = "\n".join(triton_code)
+        FileCheck().check("flex_attention").check_not("flex_decoding").run(triton_src)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_matches_auto(self, device):
+        """BACKEND='TRITON_DECODE' should match heuristics on decode-friendly shapes."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (1, 2, 64, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        def compile_and_run(kernel_options):
+            return run_and_get_code(
+                torch.compile(flex_attention, fullgraph=True),
+                q,
+                k,
+                v,
+                kernel_options=kernel_options,
+            )
+
+        from torch._inductor.kernel.flex import flex_attention as flex_kernel_mod
+
+        with mock.patch.object(
+            flex_kernel_mod,
+            "create_flex_decoding_kernel",
+            wraps=flex_kernel_mod.create_flex_decoding_kernel,
+        ) as decode_kernel:
+            default_out, _ = compile_and_run({"BACKEND": "AUTO"})
+            self.assertTrue(
+                decode_kernel.called,
+                "Expected heuristics to dispatch to flex decoding kernel.",
+            )
+
+        with mock.patch.object(
+            flex_kernel_mod,
+            "create_flex_decoding_kernel",
+            wraps=flex_kernel_mod.create_flex_decoding_kernel,
+        ) as decode_kernel:
+            decode_out, _ = compile_and_run({"BACKEND": "TRITON_DECODE"})
+            self.assertTrue(
+                decode_kernel.called,
+                "Expected explicit BACKEND='TRITON_DECODE' to use flex decoding kernel.",
+            )
+
+        self.assertEqual(decode_out.shape, (1, 2, 64, 64))
+        torch.testing.assert_close(default_out, decode_out, atol=3e-3, rtol=3e-3)
+
+    @supported_platform
+    @skip_on_cpu
+    def test_backend_triton_decode_errors_when_not_supported(self, device):
+        """Requesting decode on unsupported shapes should raise a helpful error."""
+        make_tensor = functools.partial(
+            torch.randn,
+            (1, 2, 256, 64),
+            device=device,
+            dtype=torch.float16,
+            requires_grad=False,
+        )
+        q, k, v = make_tensor(), make_tensor(), make_tensor()
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        with self.assertRaisesRegex(
+            RuntimeError,
r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used", + ): + flex_compiled(q, k, v, kernel_options={"BACKEND": "TRITON_DECODE"}) + + @supported_platform + @skip_on_cpu + def test_backend_triton_decode_errors_with_non_power_of_two_gqa(self, device): + """BACKEND='TRITON_DECODE' should fail when GQA ratio is not a power of two.""" + q = torch.randn( + 1, 3, 64, 64, device=device, dtype=torch.float16, requires_grad=False + ) + k = torch.randn( + 1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False + ) + v = torch.randn( + 1, 1, 64, 64, device=device, dtype=torch.float16, requires_grad=False + ) + + flex_compiled = torch.compile(flex_attention, fullgraph=True) + with self.assertRaisesRegex( + RuntimeError, + r"BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used", + ): + flex_compiled( + q, + k, + v, + enable_gqa=True, + kernel_options={"BACKEND": "TRITON_DECODE"}, + ) + + @supported_platform + @skip_on_cpu + def test_backend_rejects_legacy_force_use_flag(self, device): + """Combining BACKEND with FORCE_USE_FLEX_ATTENTION should raise an error.""" + make_tensor = functools.partial( + torch.randn, + (2, 2, 128, 64), + device=device, + dtype=torch.float16, + requires_grad=False, + ) + q, k, v = make_tensor(), make_tensor(), make_tensor() + + flex_compiled = torch.compile(flex_attention, fullgraph=True) + with self.assertRaisesRegex( + RuntimeError, + r"BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION", + ): + flex_compiled( + q, + k, + v, + kernel_options={ + "BACKEND": "TRITON", + "FORCE_USE_FLEX_ATTENTION": True, + }, + ) + + @supported_platform + def test_backend_defaults_and_rejects_invalid(self, device): + device = torch.device(device) + query = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32) + key = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32) + value = torch.randn(1, 1, 4, 8, device=device, dtype=torch.float32) + + kernel_options = _apply_kernel_options( + query, key, value, return_lse=True, kernel_options={} + ) + self.assertEqual(kernel_options["BACKEND"], "AUTO") + + with self.assertRaisesRegex(ValueError, r"Invalid BACKEND value 'INVALID'"): + _apply_kernel_options( + query, + key, + value, + return_lse=True, + kernel_options={"BACKEND": "INVALID"}, + ) + @supported_platform def test_block_mask_non_divisible(self, device): seq = torch.arange(1023, device=device) // 128 @@ -4154,7 +4378,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_ score_mod_0 = self.score_mod_0 mask_fn_0 = self.mask_fn_0 - flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None + flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, 
l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None return (out,) @@ -4190,11 +4414,11 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"): - full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) + full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) fw_graph0 = self.fw_graph0 joint_graph0 = self.joint_graph0 mask_graph0 = self.mask_graph0 - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'BACKEND': 'AUTO', 'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0] getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1] getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None @@ -4214,7 +4438,7 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 class mask_graph0(torch.nn.Module): def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", 
arg3_1: "i32[]"): - full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) + full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) return full_default """.replace( # noqa: B950 "GPU_TYPE", torch.device(device).type diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index 5f3735ac87e0d..50f12291a0e83 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -6,7 +6,11 @@ import torch from torch._inductor.kernel.flex.flex_flash_attention import ensure_flash_available from torch._inductor.test_case import TestCase as InductorTestCase -from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from torch.nn.attention.flex_attention import ( + _DEFAULT_SPARSE_BLOCK_SIZE, + create_block_mask, + flex_attention, +) from torch.profiler import profile, ProfilerActivity from torch.testing._internal.common_device_type import ( dtypes, @@ -105,6 +109,28 @@ def create_test_tensors( return q, k, v +def _create_block_mask_for_device( + mask_mod, batch_size, num_heads, q_len, kv_len, *, device +): + """Match FlexAttention's block-height expectations per compute capability.""" + q_block = _DEFAULT_SPARSE_BLOCK_SIZE + kv_block = _DEFAULT_SPARSE_BLOCK_SIZE + dev = torch.device(device) + if dev.type == "cuda": + major, _ = torch.cuda.get_device_capability(dev) + if major >= 10: + q_block *= 2 + return create_block_mask( + mask_mod, + batch_size, + num_heads, + q_len, + kv_len, + device=device, + BLOCK_SIZE=(q_block, kv_block), + ) + + @contextmanager def cuda_kernel_profiler(kernel_pattern="flash_attncute"): """Context manager for profiling CUDA kernels.""" @@ -139,7 +165,7 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): v, score_mod=score_mod, block_mask=block_mask, - kernel_options={"force_flash": True}, + kernel_options={"BACKEND": "FLASH"}, ) out_triton = compiled_fn( q, @@ -147,7 +173,7 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): v, score_mod=score_mod, block_mask=block_mask, - kernel_options={"force_flash": False}, + kernel_options={"BACKEND": "TRITON"}, ) assert out_flash.shape == out_ref_fp32.shape == out_triton.shape @@ -167,6 +193,36 @@ def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {fwd_atol:.2e}" ) + needs_backward = any( + isinstance(t, torch.Tensor) and t.requires_grad for t in (q, k, v) + ) + if needs_backward: + grad = torch.randn_like(out_flash) + inputs = (q, k, v) + grads_ref = torch.autograd.grad(out_ref_fp32, inputs, grad) + grads_triton = torch.autograd.grad(out_triton, inputs, grad) + grads_flash = torch.autograd.grad(out_flash, inputs, grad) + + dq_atol = 2 * (grads_ref[0] + 0.3 - 0.3 - grads_ref[0]).abs().max().item() + dk_atol = 2 * (grads_ref[1] + 0.3 - 0.3 - grads_ref[1]).abs().max().item() + dv_atol = 2 * (grads_ref[2] + 0.3 - 0.3 - grads_ref[2]).abs().max().item() + + atol_pack = (dq_atol, dk_atol, dv_atol) + for grad_flash, grad_triton, grad_ref, atol in zip( + grads_flash, grads_triton, grads_ref, atol_pack + ): + assert torch.isfinite(grad_flash).all() + assert torch.isfinite(grad_triton).all() + assert torch.isfinite(grad_ref).all() + + triton_error = (grad_triton - grad_ref).abs().max().item() + flash_error = ( + 
(grad_flash - grad_ref.to(grad_flash.dtype)).abs().max().item() + ) + assert flash_error <= rtol * triton_error + atol, ( + f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {atol:.2e}" + ) + return out_flash, out_triton, out_ref_fp32 @@ -200,30 +256,28 @@ def test_flash_attention_unfriendly_seqlen_with_causal( @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_kernel_called(self, device, dtype): - """Test that flash attention kernel is actually called when force_flash=True.""" + """Test that flash attention kernel is actually called when BACKEND='FLASH'.""" q, k, v = create_test_tensors(dtype=dtype, device=device) compiled_fn = torch.compile(flex_attention) - # Test that flash kernel is called with force_flash=True + # Test that flash kernel is called with BACKEND='FLASH' with cuda_kernel_profiler("flash_attncute") as prof_result: - compiled_fn( - q, k, v, score_mod=_causal, kernel_options={"force_flash": True} - ) + compiled_fn(q, k, v, score_mod=_causal, kernel_options={"BACKEND": "FLASH"}) self.assertTrue( prof_result["found"], f"Flash attention kernel not found. Available kernels: {prof_result['kernel_names']}", ) - # Test that flash kernel is NOT called with force_flash=False + # Test that flash kernel is NOT called with BACKEND='TRITON' with cuda_kernel_profiler("flash_attncute") as prof_result: compiled_fn( - q, k, v, score_mod=_causal, kernel_options={"force_flash": False} + q, k, v, score_mod=_causal, kernel_options={"BACKEND": "TRITON"} ) self.assertFalse( prof_result["found"], - f"Flash attention kernel unexpectedly found when force_flash=False. Kernels: {prof_result['kernel_names']}", + f"Flash attention kernel unexpectedly found when BACKEND='TRITON'. Kernels: {prof_result['kernel_names']}", ) @dtypes(torch.float16, torch.bfloat16) @@ -284,8 +338,8 @@ def score_view_mod(score, b, h, q_idx, kv_idx): flash_vs_triton(q, k, v, score_mod=score_view_mod) @dtypes(torch.float16, torch.bfloat16) - def test_force_flash_error_with_requires_grad(self, device, dtype): - """Test that force_flash=True raises error when tensor requires gradients.""" + def test_flash_impl_error_with_requires_grad(self, device, dtype): + """Test that BACKEND='FLASH' raises error when tensor requires gradients.""" q, k, v = create_test_tensors(dtype=dtype, device=device) bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True) @@ -296,15 +350,108 @@ def score_mod_with_grad(score, b, h, q_idx, kv_idx): compiled_fn = torch.compile(flex_attention) with self.assertRaisesRegex( RuntimeError, - r"force_flash=True but flash attention cannot be used.*require gradients", + r"BACKEND='FLASH' but flash attention cannot be used.*require gradients", ): compiled_fn( q, k, v, score_mod=score_mod_with_grad, - kernel_options={"force_flash": True}, + kernel_options={"BACKEND": "FLASH"}, + ) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_rejects_mask_mod(self, device, dtype): + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = _create_block_mask_for_device( + causal_mask, 2, 4, 512, 512, device=device + ) + q.requires_grad_(True) + compiled_fn = torch.compile(flex_attention) + with self.assertRaisesRegex( + RuntimeError, + r"NYI: Flex Flash Attention doesn't support block_sparsity yet", + ): + compiled_fn( + q, k, v, block_mask=block_mask, kernel_options={"BACKEND": "FLASH"} + ).sum().backward() + + @dtypes(torch.float16, torch.bfloat16) + def 
test_flash_attention_backward_rejects_score_mod_capture(self, device, dtype): + q, k, v = create_test_tensors(dtype=dtype, device=device) + + bias = torch.randn(4, device=device, dtype=dtype) + + def score_mod_with_capture(score, b, h, q_idx, kv_idx): + return score + bias[h] + + q.requires_grad_(True) + compiled_fn = torch.compile(flex_attention) + with self.assertRaisesRegex( + RuntimeError, + r"NYI: Flex Flash Attention doesn't support score_mods in bwds yet", + ): + compiled_fn( + q, + k, + v, + score_mod=score_mod_with_capture, + kernel_options={"BACKEND": "FLASH"}, + ).sum().backward() + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_rejects_score_mod(self, device, dtype): + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def score_mod_twice(score, b, h, q_idx, kv_idx): + return score * 2 + + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + compiled_fn = torch.compile(flex_attention) + with self.assertRaisesRegex( + RuntimeError, + r"NYI: Flex Flash Attention doesn't support score_mods in bwds yet", + ): + compiled_fn( + q, + k, + v, + score_mod=score_mod_twice, + kernel_options={"BACKEND": "FLASH"}, + ).sum().backward() + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_backward_kernel_called(self, device, dtype): + q, k, v = create_test_tensors(dim=128, dtype=dtype, device=device) + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + flash_vs_triton(q, k, v) + + compiled_fn = torch.compile(flex_attention) + + def run_for_profile(): + q_run, k_run, v_run = ( + t.detach().clone().requires_grad_(True) for t in (q, k, v) ) + compiled_fn( + q_run, k_run, v_run, kernel_options={"BACKEND": "FLASH"} + ).sum().backward() + + with cuda_kernel_profiler("flash_attncuteflash_bwd") as prof_result: + run_for_profile() + + self.assertTrue( + prof_result["found"], + f"Flash attention backward kernel not found. 
Kernels: {prof_result['kernel_names']}", + ) @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_with_block_mask(self, device, dtype): @@ -314,7 +461,9 @@ def test_flash_attention_with_block_mask(self, device, dtype): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + causal_mask, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, block_mask=block_mask) @dtypes(torch.float16, torch.bfloat16) @@ -325,7 +474,9 @@ def test_flash_attention_block_mask_with_score_mod(self, device, dtype): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + causal_mask, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, score_mod=_times_two, block_mask=block_mask) @dtypes(torch.float16, torch.bfloat16) @@ -341,7 +492,9 @@ def custom_mask(b, h, q_idx, kv_idx): bias_value = mask_bias[h] return (q_idx >= kv_idx) | (bias_value > 0) - block_mask = create_block_mask(custom_mask, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + custom_mask, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, block_mask=block_mask) @dtypes(torch.float16, torch.bfloat16) @@ -370,7 +523,7 @@ def document_mask(b, _h, q_idx, kv_idx): doc_id_kv = document_ids[b, kv_idx] return doc_id_q == doc_id_kv - block_mask = create_block_mask( + block_mask = _create_block_mask_for_device( document_mask, 2, 1, seq_len, seq_len, device=device ) flash_vs_triton(q, k, v, block_mask=block_mask) @@ -392,7 +545,7 @@ def mask_with_view_buffer(b, h, q_idx, kv_idx): double_bias = bias_value * 2 return (q_idx >= kv_idx) | (double_bias > 0) - block_mask = create_block_mask( + block_mask = _create_block_mask_for_device( mask_with_view_buffer, batch_size, num_heads, @@ -420,7 +573,7 @@ def dual_buffer_mask(b, h, q_idx, kv_idx): bias_cond = (head_term + batch_term).to(torch.float32) > 0 return causal | bias_cond - block_mask = create_block_mask( + block_mask = _create_block_mask_for_device( dual_buffer_mask, batch_size, num_heads, seq_len, seq_len, device=device ) flash_vs_triton(q, k, v, block_mask=block_mask) @@ -463,7 +616,9 @@ def mask_with_buffer(b, h, q_idx, kv_idx): bias_value = mask_bias[h] return (q_idx >= kv_idx) | (bias_value > 0) - block_mask = create_block_mask(mask_with_buffer, 2, 4, 512, 512, device=device) + block_mask = _create_block_mask_for_device( + mask_with_buffer, 2, 4, 512, 512, device=device + ) flash_vs_triton(q, k, v, score_mod=score_with_buffer, block_mask=block_mask) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 46a63db754697..90714b58951b1 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -2478,6 +2478,66 @@ def layout_checker(choices): finally: clear_preprocessing_fns(clear_defaults=False) + @config.patch( + {"test_configs.max_mm_configs": 4, "max_autotune_gemm_backends": "TRITON"} + ) + def test_fixed_layout_at_lowering(self): + """ + Test that max-autotune with addmm/bmm/mm_plus_mm correctly handles + padding and maintains correct output strides. Specifically, when matrix + b with shape (4608, 1490) is padded, its stride should become 1536. 
+ """ + + def mm_func(a, b) -> torch.Tensor: + a_t = torch.permute(a, [1, 0]).to(torch.bfloat16) + b_dtype = b.to(torch.bfloat16) + # Add .to() to make sure that mm could be potentially padded + # Strides for output are not padded + return (a_t @ b_dtype).to(torch.float32) + + def addmm_func(a, b, bias) -> torch.Tensor: + a_t = torch.permute(a, [1, 0]).to(torch.bfloat16) + b_dtype = b.to(torch.bfloat16) + bias_dtype = bias.to(torch.bfloat16) + return torch.addmm(bias_dtype, a_t, b_dtype).to(torch.float32) + + def bmm_func(a, b) -> torch.Tensor: + a_t = torch.permute(a, [2, 0, 1]).to(torch.bfloat16) + b_dtype = b.to(torch.bfloat16) + return torch.bmm(a_t, b_dtype).to(torch.float32) + + def mm_plus_mm_func(a1, b1, a2, b2) -> torch.Tensor: + a1_t = torch.permute(a1, [1, 0]).to(torch.bfloat16) + b1_dtype = b1.to(torch.bfloat16) + a2_t = torch.permute(a2, [1, 0]).to(torch.bfloat16) + b2_dtype = b2.to(torch.bfloat16) + return (a1_t @ b1_dtype + a2_t @ b2_dtype).to(torch.float32) + + a = torch.randn((4608, 512), device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn((4608, 1490), device=GPU_TYPE) + bias = torch.randn(1490, device=GPU_TYPE) + + a_bmm = torch.randn((512, 4608, 8), device=GPU_TYPE, dtype=torch.bfloat16) + b_bmm = torch.randn((8, 4608, 1490), device=GPU_TYPE) + + # Test mm_plus_mm + a2 = torch.randn((4608, 512), device=GPU_TYPE, dtype=torch.bfloat16) + b2 = torch.randn((4608, 1490), device=GPU_TYPE) + + # 1490 padded to 1536, check in template code + output_code_padding_check = "stride_bk = 1536" + funcs_and_args = [ + (mm_func, (a, b)), + (addmm_func, (a, b, bias)), + (bmm_func, (a_bmm, b_bmm)), + (mm_plus_mm_func, (a, b, a2, b2)), + ] + + for f, args in funcs_and_args: + c_f = torch.compile(f, mode="max-autotune-no-cudagraphs") + _, code_out = run_and_get_code(c_f, *args) + FileCheck().check(output_code_padding_check).run(code_out[0]) + class TestMaxAutotunePrecompile(TestCase): def test_precompilation_threads(self): diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 1114810ceccdf..cae48673f2332 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -270,20 +270,11 @@ def f(x, y): ], ) @parametrize("split_reductions", (False, True)) - @parametrize( - "shape", ((1000000, 256), (32768, 2048), (32768, 768), (32768 + 1023, 768)) - ) + @parametrize("shape", ((32768, 2048), (32768, 768), (32768 + 1023, 768))) @parametrize("max_autotune", (False, True)) @parametrize("initial_xblock", (1, 2)) - @parametrize("add_1dim", (False, True)) def test_rms_norm_bwd( - self, - wdtype, - split_reductions, - shape, - max_autotune, - initial_xblock, - add_1dim, + self, wdtype, split_reductions, shape, max_autotune, initial_xblock ): # max_autotune can be slow and cost resource, trim down the tests # for max autotune @@ -296,9 +287,6 @@ def test_rms_norm_bwd( ): self.skipTest("Skip non-critical tests to save resources.") - if shape != (1000000, 256) and add_1dim: - self.skipTest("Skip non-critical tests to save resources.") - def f(x, w, eps): orig_dtype = x.dtype @@ -319,9 +307,6 @@ def fwd_bwd(f): # M, N = 1152 * 500, 384 M, N = shape x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True) - if add_1dim: - x = x[:, None, :] - w = torch.randn(N, dtype=wdtype, device=GPU_TYPE, requires_grad=True) dy = torch.randn_like(x) eps = 1e-5 @@ -397,6 +382,37 @@ def fwd_bwd(f): metrics.codegen_mix_order_reduction, ) + def test_layer_norm_bwd_with_dynamic_shape(self): + def f(x, w, 
eps): + return F.layer_norm(x, x.shape[-1:], weight=w, bias=None, eps=eps) + + def fwd_bwd(f): + x.grad = None + w.grad = None + out = f(x, w, eps) + out.backward(dy) + return x.grad, w.grad + + M0, M1, N = 251, 223, 128 + wbdtype = torch.float + xdtype = torch.float + x = torch.randn(M0, M1, N, dtype=xdtype, device=GPU_TYPE, requires_grad=True) + torch._dynamo.mark_dynamic(x, 0) + w = torch.randn(N, dtype=wbdtype, device=GPU_TYPE, requires_grad=True) + dy = torch.randn_like(x) + eps = 1e-5 + + opt_f = torch.compile(f) + + ref = fwd_bwd(f) + act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f) + + self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}") + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + @parametrize("split_reductions", (False, True)) @parametrize("shape", ((32768, 768), (32769, 768))) def test_layer_norm_bwd_no_bias(self, split_reductions, shape): diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index c4aabd0375090..369013e1670b6 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -3,6 +3,7 @@ import re import sys import unittest +from unittest import mock import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools @@ -12,6 +13,7 @@ from torch._inductor.utils import run_and_get_code from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS from torch.testing._internal.inductor_utils import HAS_PALLAS +from torch.utils._pallas import has_cuda_pallas, has_jax_tpu_backend from torch.utils._triton import has_triton @@ -746,7 +748,7 @@ def fn(x): self.assertEqual(result, expected) -@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +@unittest.skipUnless(has_cuda_pallas(), "requires jax and pallas") class PallasTestsCUDA(PallasTestsMixin, TestCase): DEVICE = "cuda" @@ -756,6 +758,28 @@ class PallasTestsCPU(PallasTestsMixin, TestCase): DEVICE = "cpu" +@unittest.skipUnless(has_jax_tpu_backend(), "requires JAX TPU backend") +@config.patch({"_debug_cpu_to_tpu_pallas": True}) +class PallasTestsTPU(PallasTestsMixin, TestCase): + DEVICE = "cpu" + + @mock.patch("torch._inductor.codegen.pallas.has_tpu_pallas", return_value=False) + def test_tpu_not_available_raises_error(self, mock_has_tpu_pallas): + def fn(a, b): + return a + b + + with self.assertRaisesRegex( + RuntimeError, + ( + "PALLAS_TARGET_TPU is set, but no TPU device was found. " + "Please make sure that you have a TPU available and that JAX is configured correctly." 
+ ), + ): + torch.compile(fn, backend="inductor", options={"cpu_backend": "pallas"})( + torch.randn(16), torch.randn(16) + ) + + if test_torchinductor.HAS_CPU and HAS_PALLAS: make_pallas(test_torchinductor.SweepInputsCpuTest) # make_pallas(test_torchinductor.CpuTests) diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index 8a48bee86ba4e..5ad37c10b2c1a 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -278,7 +278,6 @@ def f(a, b): inp = (T(10, 10), T(10, 10)) self.assertExpectedInline(count_numel(f, *inp), """680""") - @patch.object(config, "split_cat_fx_passes", False) @patch.object( config, "pre_grad_fusion_options", @@ -300,7 +299,6 @@ def f(*inputs): inp = (T(10, 10) for _ in range(16)) self.assertExpectedInline(count_numel(f, *inp), """6400""") - @patch.object(config, "split_cat_fx_passes", False) @patch.object( config, "pre_grad_fusion_options", diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f5d5c5107313f..3bc1dba12acd8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -127,13 +127,14 @@ ) from torch._inductor.utils import has_torchvision_roi_align from torch.testing._internal.common_utils import slowTest -from torch.testing._internal.inductor_utils import ( +from torch.testing._internal.inductor_utils import ( # noqa: F401 clone_preserve_strides_offset, GPU_TYPE, HAS_CPU, HAS_GPU, HAS_MPS, HAS_MULTIGPU, + HAS_TPU, IS_BIG_GPU, requires_gpu, RUN_CPU, @@ -2172,7 +2173,6 @@ def fn(a): @skipCPUIf(IS_MACOS, "fails on macos") @skip_if_halide # accuracy 4.7% off - @xfailIfS390X # accuracy failure def test_multilayer_var_lowp(self): def fn(a): return torch.var(a) @@ -2484,7 +2484,7 @@ def fn(a, b_int8pack, b_scales, c): @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") @skipIfRocm @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") - def test__dyn_quant_pack_4bit_weight(self): + def test__dyn_quant_pack_4bit_weight_fp32(self): q_group = 32 k = 128 n = 128 @@ -2515,12 +2515,54 @@ def fn(b, in_features, out_features): self.common(fn, (b, in_features, out_features)) + @xfail_if_mps_unimplemented + @xfail_if_triton_cpu + @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") + @skipIfRocm + @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") + @skip_if_halide # bf16 + def test__dyn_quant_pack_4bit_weight_bf16(self): + k = 128 + n = 128 + q_group = 32 + + if not self.is_dtype_supported(torch.bfloat16): + raise unittest.SkipTest( + f"torch.bfloat16 not supported for device {self.device}" + ) + + torch.manual_seed(1) + b = torch.rand((k, n), dtype=torch.bfloat16) + in_features = b.size(0) + out_features = b.size(1) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def fn(b, in_features, out_features): + b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) + return b_int4pack + + self.common(fn, (b, in_features, out_features)) + @xfail_if_mps_unimplemented @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_matmul_4bit 
implementation on CUDA") @skipIfRocm @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") - def test__dyn_quant_matmul_4bit(self): + def test__dyn_quant_matmul_4bit_fp32_input(self): q_group = 32 m = 32 k = 128 @@ -2560,6 +2602,68 @@ def fn(a, q_group, in_features, out_features): self.common(fn, (a, q_group, in_features, out_features)) + @skipCPUIf(IS_MACOS, "fails on M1, mismatch in bf16 support reporting") + @xfail_if_mps_unimplemented + @xfail_if_triton_cpu + @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") + @skipIfRocm + @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") + @skip_if_halide # bf16 + def test__dyn_quant_matmul_4bit_bf16_input(self): + m = 32 + k = 128 + n = 128 + q_group = k + + if not self.is_dtype_supported(torch.bfloat16): + raise unittest.SkipTest( + f"torch.bfloat16 not supported for device {self.device}" + ) + + torch.manual_seed(1) + a = torch.rand((m, k), dtype=torch.bfloat16) + b = torch.rand((k, n), dtype=torch.bfloat16) + + # codegen_dynamic_shape test fails without explicitly marking these dynamic + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(b, 1) + + in_features = b.size(0) + out_features = b.size(1) + + if not self.is_dtype_supported(torch.bfloat16): + raise unittest.SkipTest( + f"torch.bfloat16 not supported for device {self.device}" + ) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def fn(a, q_group, in_features, out_features): + b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) + res = torch.ops.aten._dyn_quant_matmul_4bit( + a, + b_int4pack, + q_group, + in_features, + out_features, + ) + return res + + self.common(fn, (a, q_group, in_features, out_features), atol=1, rtol=0.5) + def test_expanded_reduction(self): def fn(x, y): z = x * y @@ -5424,6 +5528,32 @@ def fn(x): check_lowp=not is_halide_backend(self.device), # misaligned addr fp16 ) + def test_lp_pool1d_with_inf_norm(self): + # https://github.com/pytorch/pytorch/issues/167197 + # Test that LPPool1d works with infinity norm (should behave like max pooling) + def fn(x): + return torch.nn.functional.lp_pool1d( + x, norm_type=float("inf"), kernel_size=2, stride=2 + ) + + self.common( + fn, + (torch.randn(3, 4, 8),), + ) + + def test_lp_pool2d_with_inf_norm(self): + # https://github.com/pytorch/pytorch/issues/167197 + # Test that LPPool2d works with infinity norm (should behave like max pooling) + def fn(x): + return torch.nn.functional.lp_pool2d( + x, norm_type=float("inf"), kernel_size=2, stride=2 + ) + + self.common( + fn, + (torch.randn(3, 4, 8, 8),), + ) + @tf32_on_and_off(0.006) @skip_if_gpu_halide # slow def test_alexnet_prefix(self): @@ -6203,6 +6333,15 @@ def fn(x): x = torch.randn([16, 16], device=self.device) self.assertEqual(cfn(x), fn(x)) + def test_pow_infinite(self): + def fn(a, b): + return torch.pow(a, b) + + opt = torch.compile(fn, backend="inductor") + a = torch.randn((3, 4, 8), device=self.device) + b = float("inf") + self.assertTrue(same(opt(a, b), fn(a, b))) + def test_glu(self): def fn(x): return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2) 
@@ -14629,21 +14768,6 @@ def test_weight_norm_conv2d(self): self.assertTrue(same((ref, ref_grad), (act, act_grad), tol=1e-3)) - @skipIfMPS - def test_inner_reduction_detection(self): - if self.device == "cpu": - self.skipTest("Skip for CPU device") - - x = torch.randn(100000, 1, 256, device=self.device) - - @torch.compile - def f(x): - return x.sum(dim=(0, 1)) - - code = run_and_get_triton_code(f, x) - self.assertTrue("ReductionHint.OUTER" in code) - self.assertFalse("ReductionHint.INNER" in code) - @skip_if_halide @requires_cuda_and_triton @skip_if_cpp_wrapper("skip cpp wrapper") @@ -15077,7 +15201,7 @@ def forward( def test_grouped_mm(self): @torch.compile(fullgraph=True) def f(a, b, offs, out_dtype): - return torch._grouped_mm( + return F.grouped_mm( a, b.transpose(-2, -1), offs=offs, out_dtype=out_dtype ) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 1c9b39a1bd08d..d1b62feed3b41 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -828,9 +828,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.fractional_max_pool3d": {f16, f32, f64}, "nn.functional.group_norm": {f16}, "nn.functional.hinge_embedding_loss": {f16}, - # Enabling all tests for this test fails randomly - # See https://github.com/pytorch/pytorch/issues/129238 - "nn.functional.huber_loss": {f16}, "nn.functional.interpolate.bicubic": {f16}, "nn.functional.interpolate.bilinear": {f16}, "nn.functional.interpolate.trilinear": {f16}, @@ -948,9 +945,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.fractional_max_pool3d": {f16, f32, f64}, "nn.functional.group_norm": {f16}, "nn.functional.hinge_embedding_loss": {f16}, - # Enabling all tests for this test fails randomly - # See https://github.com/pytorch/pytorch/issues/129238 - "nn.functional.huber_loss": {f16}, "nn.functional.interpolate.bicubic": {f16}, "nn.functional.interpolate.bilinear": {f16}, "nn.functional.interpolate.trilinear": {f16}, diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 7a9edd5570f3e..d70375ebc3345 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -246,9 +246,9 @@ def foo(x, y): ) def test_pointwise( self, - full_size: tuple[int], - view_size: tuple[int], - stride: Optional[tuple[int]], + full_size: tuple[int, ...], + view_size: tuple[int, ...], + stride: Optional[tuple[int, ...]], offset: Optional[int], require_block_ptr: bool, prefer_nd_tiling: bool, @@ -298,7 +298,7 @@ def get_input() -> torch.Tensor: ], ) def test_broadcast( - self, x_size: tuple[int], y_size: tuple[int], prefer_nd_tiling: bool + self, x_size: tuple[int, ...], y_size: tuple[int, ...], prefer_nd_tiling: bool ): """ Test that we can generate strided block pointers when inputs have different @@ -415,7 +415,7 @@ def load_args(reader): ((5, 6, 1, 1), (5, 6, 4, 3)), ], ) - def test_expand_broadcast(self, x_size: tuple[int], y_size: tuple[int]): + def test_expand_broadcast(self, x_size: tuple[int, ...], y_size: tuple[int, ...]): """ When the load and store have different shapes, we should use broadcast. 
""" @@ -423,7 +423,7 @@ def test_expand_broadcast(self, x_size: tuple[int], y_size: tuple[int]): def foo(x, y_size): return x.expand(y_size).clone() - def get_input(size: tuple[int]) -> torch.Tensor: + def get_input(size: tuple[int, ...]) -> torch.Tensor: device = torch.device(self.device) full = torch.randn(size).to(device) view = torch.as_strided(full, size, full.stride()) @@ -522,7 +522,7 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool): ) def test_reduction( self, - view_size: tuple[int], + view_size: tuple[int, ...], num_block_pointers: int, num_triton_kernels: int, prefer_nd_tiling: bool, @@ -574,7 +574,10 @@ def test_reduction( ], ) def test_mixed_pointwise_reduction( - self, view_size: tuple[int], num_block_pointers: int, num_triton_kernels: int + self, + view_size: tuple[int, ...], + num_block_pointers: int, + num_triton_kernels: int, ): """ Tests mixing pointwise with reduction ops. @@ -744,8 +747,8 @@ def foo(x): ) def test_nd_tiling_odd_shapes_pointwise( self, - full_size: tuple[int], - view_size: tuple[int], + full_size: tuple[int, ...], + view_size: tuple[int, ...], num_block_pointers: int, num_tiles: int, ): @@ -794,7 +797,7 @@ def get_input() -> torch.Tensor: ) def test_2d_reduction_odd_shapes( self, - view_size: tuple[int], + view_size: tuple[int, ...], num_block_pointers: int, num_triton_kernels: int, reduction_op: Callable, @@ -829,7 +832,7 @@ def test_2d_reduction_odd_shapes( ) def test_2d_welford_reduction( self, - size: tuple[int], + size: tuple[int, ...], expected_num_block_pointers: int, expected_num_triton_kernels: int, expect_fallback: bool, diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index eee4dba7f2772..1f205eecec1bf 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -1216,6 +1216,20 @@ def f(x, y): compiled_out = torch.compile(f)(x, y) self.assertEqual(compiled_out, eager_out) + @requires_gpu + def test_triton_kernel_to_cpu(self): + def f(x, y): + out = torch.zeros_like(x) + add_kernel[(1,)](x, y, out, 16, 16) + out_cpu = out.cpu() + 1 + return out_cpu + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + eager_out = f(x, y) + compiled_out = torch.compile(f)(x, y) + self.assertEqual(compiled_out, eager_out) + @requires_gpu def test_triton_kernel_out_of_order(self): @triton.jit diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index e1411f3101e22..ae0f137e152ce 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -285,6 +285,14 @@ def test_embeddingbag_include_last_offset(self): self.assertEqual(ref_out, out) self.assertEqual(ref_out, out2) + def test_embeddingbag_2d_include_last_offset(self): + # Test case from https://github.com/pytorch/pytorch/issues/167974 + embedding_sum = torch.nn.EmbeddingBag(10, 3, include_last_offset=True) + input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long) + res = embedding_sum(input) + # Check if number of bags matches + self.assertTrue(res.shape[0] == input.shape[0]) + class TestEmbeddingNNDeviceType(NNTestCase): def test_embedding_dense_grad(self, device): diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index dbd5d89ad6a61..26c0ab42905de 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -2,6 +2,7 @@ import json import os +import sys import tempfile import unittest from typing import Any @@ -364,6 +365,9 @@ def 
test_execution_trace_env_disabled(self, device): self.assertTrue(p.execution_trace_observer is None) @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", @@ -419,6 +423,9 @@ def fn(a, b, c): assert found_call_compiled_fx_graph @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") + @unittest.skipIf( + sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" + ) @unittest.skipIf( (not has_triton()) or (not TEST_CUDA and not TEST_XPU), "need triton and device(CUDA or XPU) availability to run", diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index e8d28d7eff032..3c5ef2aeeb83c 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -27,13 +27,16 @@ PRUNE_ALL = 1 KEEP_ELLIPSES = 2 KEEP_NAME_AND_ELLIPSES = 3 +IGNORE = 4 PRUNE_FUNCTIONS = { "torch/utils/_pytree.py(...): tree_map": KEEP_NAME_AND_ELLIPSES, "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES, + "": PRUNE_ALL, "": PRUNE_ALL, + "": IGNORE, "cudaStreamIsCapturing": PRUNE_ALL, # These show up only on CUDA, prune them so the CUDA and CPU expected results can be the same "cudaGetDeviceCount": PRUNE_ALL, @@ -117,6 +120,8 @@ def flatten(nodes, depth=0, out=None): if prune_level is None: out.append((depth, name)) flatten(node.children, depth + 1, out) + elif prune_level == IGNORE: + flatten(node.children, depth, out) elif prune_level == KEEP_NAME_AND_ELLIPSES: out.append((depth, name)) if node.children: @@ -720,10 +725,9 @@ def test_profiler_experimental_tree_with_stack_and_torch_function(self): test_profiler_tree.py(...): __torch_function__ torch/_tensor.py(...): __torch_function__ - - torch/_tensor.py(...): - - torch/_tensor.py(...): + torch/_tensor.py(...): + + torch/_tensor.py(...): aten::add torch/_tensor.py(...): _convert diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index 75e4ebffbdf42..ed3eabd702690 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -866,7 +866,7 @@ def _test_match_activations( ): if qconfig_dict is None: qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() - if prepare_fn == prepare_fx: + if prepare_fn is prepare_fx: m.eval() else: m.train() @@ -929,7 +929,7 @@ def _test_match_shadow_activations( ): if qconfig_dict is None: qconfig_dict = torch.ao.quantization.get_default_qconfig_mapping() - if prepare_fn == prepare_fx: + if prepare_fn is prepare_fx: m.eval() else: m.train() @@ -1082,7 +1082,7 @@ def _test_match_activations_mod_impl(self, prepare_fn=prepare_fx): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} expected_occurrence = { ns.call_module(OutputLogger): 2, @@ -1103,7 +1103,7 @@ def test_match_activations_mod_qat(self): def _test_match_activations_fun_impl(self, prepare_fn=prepare_fx): m = LinearReluLinearFunctional().eval() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': 
torch.ao.quantization.get_default_qat_qconfig('fbgemm')} expected_occurrence = { ns.call_module(OutputLogger): 2, @@ -1165,7 +1165,7 @@ def _test_add_shadow_loggers_mod_impl(self, prepare_fn=prepare_fx): nn.Conv2d(1, 1, 1), ).eval() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} res = self._test_match_shadow_activations( m, (torch.randn(1, 1, 4, 4),), results_len=2, @@ -1182,7 +1182,7 @@ def test_add_shadow_loggers_mod_qat(self): def _test_add_shadow_loggers_fun_impl(self, prepare_fn=prepare_fx): m = LinearReluLinearFunctional() qconfig_dict = None - if prepare_fn == prepare_qat_fx: + if prepare_fn is prepare_qat_fx: qconfig_dict = {'': torch.ao.quantization.get_default_qat_qconfig('fbgemm')} res = self._test_match_shadow_activations( m, (torch.randn(4, 4),), results_len=2, prepare_fn=prepare_fn, diff --git a/test/test_autograd.py b/test/test_autograd.py index 5960ac8add36d..bc6967cdfb038 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -403,6 +403,18 @@ def backward(ctx, g0, g1): out = Func.apply(a)[0] out.backward() + def test_unused_grad_requires_grad_with_materialize(self): + x = torch.ones(10, requires_grad=True) + y = torch.ones(10, requires_grad=True) + z = (x**2).sum() + + g = torch.autograd.grad( + z, (x, y), allow_unused=True, materialize_grads=True, create_graph=False + ) + + self.assertFalse(g[0].requires_grad) + self.assertFalse(g[1].requires_grad) + def test_legacy_function_deprecation_exception(self): # Trigger exception class MyFunction(Function): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 2b5606aec98d6..d448f95319416 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -2756,6 +2756,25 @@ def test_fmod_remainder_by_zero_integral(self, device, dtype): value = 255 if dtype == torch.uint8 else -1 self.assertTrue(torch.all(fn(x, zero) == value)) + @onlyNativeDeviceTypes + @dtypes(*integral_types()) + def test_fmod_remainder_overflow(self, device, dtype): + fn_list = (torch.fmod, torch.remainder) + for fn in fn_list: + if dtype in [torch.uint8, torch.uint16, torch.uint32, torch.uint64]: + continue + + min_val = torch.iinfo(dtype).min + dividend = torch.full((2, 3), min_val, dtype=dtype, device=device) + divisor = torch.full((3,), -1, dtype=dtype, device=device) + + result = fn(dividend, divisor) + expected = torch.zeros_like(dividend) + self.assertEqual(result, expected) + + result_scalar = fn(dividend, -1) + self.assertEqual(result_scalar, expected) + @dtypes(*all_types_and(torch.half)) def test_fmod_remainder(self, device, dtype): # Use numpy as reference @@ -3058,6 +3077,18 @@ def test_floor_divide_zero(self, device, dtype): with self.assertWarnsOnceRegex(UserWarning, "floor_divide"): a // b + @dtypes(torch.int8, torch.int16, torch.int32, torch.int64) + def test_floor_divide_int_min(self, device, dtype): + int_min = torch.iinfo(dtype).min + a = torch.tensor([int_min], dtype=dtype, device=device) + b = torch.tensor([-1], dtype=dtype, device=device) + + result = torch.floor_divide(a, b) + result_ = a // b + + self.assertEqual(result, a) + self.assertEqual(result_, a) + @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) def test_muldiv_scalar(self, device, dtype): x = make_tensor((10, 3), dtype=dtype, device=device, low=None, high=None) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 41ce5af6a28be..5c721395bf9cd 100644 --- 
a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -511,6 +511,13 @@ def test_meta_symint(self): r = torch.empty(a0, device="meta") self.assertIsInstance(r.shape[0], SymInt) + def test_hash_size(self): + # See issue #168254 + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 2) + r = torch.empty(a0, device="meta") + self.assertRaises(TypeError, lambda: hash(r.shape)) + def test_guard_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) diff --git a/test/test_fx.py b/test/test_fx.py index 71299ddb2400d..7fdd6552edc7b 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3001,7 +3001,7 @@ def forward(self, inp): for node in traced.graph.nodes: if node.op == "placeholder": ph = node - elif node.op == "call_function" and node.target == wrapped_named_tup: + elif node.op == "call_function" and node.target is wrapped_named_tup: node.update_arg(0, Pair(ph, 1.2)) node.update_kwarg("p2", Pair(3.4, ph)) call_func = node @@ -3164,7 +3164,7 @@ def forward(self, x, y): mod_false = symbolic_trace(mod, concrete_args={"y": False}) self.assertEqual(mod_true(3, True), 6) print(mod_true.code) - assert any(i.target == torch._assert for i in mod_true.graph.nodes) + assert any(i.target is torch._assert for i in mod_true.graph.nodes) with self.assertRaises(AssertionError): mod_true(3, False) self.assertEqual(mod_false(3, False), 3) @@ -4783,7 +4783,7 @@ def forward(self, x): self.assertEqual(len(gm.graph.nodes), 3) found = False for node in gm.graph.nodes: - if node.op == "call_function" and node.target == side_effect_func: + if node.op == "call_function" and node.target is side_effect_func: found = True self.assertTrue(found) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 6fe3fe2355a1e..6ed8d9f2fac51 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -601,7 +601,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) @@ -660,7 +660,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) @@ -688,7 +688,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) @@ -720,7 +720,7 @@ def forward(self, a, b): # Check the IR to make sure there's a call_function node with target == "Assert" self.assertTrue( any( - node.op == "call_function" and node.target == torch._assert + node.op == "call_function" and node.target is torch._assert for node in traced.graph.nodes ) ) diff --git a/test/test_linalg.py b/test/test_linalg.py index 7e3a1ebaa6f3a..ed3ca079748fd 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -27,7 +27,7 @@ runOnRocmArch, MI300_ARCH, NAVI_ARCH, TEST_CUDA) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver, - onlyCPU, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride, + onlyCPU, skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, 
precisionOverride, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyNativeDeviceTypes, dtypesIfCUDA, onlyCUDA, skipMeta, skipCUDAIfNoCusolver, skipCUDAIfNotRocm, dtypesIfMPS, largeTensorTest) from torch.testing import make_tensor @@ -2015,6 +2015,7 @@ def run_test_case(input, ord, dim, keepdim): run_test_case(input, ord, dim, keepdim) # Test degenerate shape results match numpy for linalg.norm matrix norms + @skipIf(np.lib.NumpyVersion(np.__version__) < '2.3.0', 'Numpy changed handling of degenerate inputs in 2.3.0') @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) @@ -2043,13 +2044,13 @@ def run_test_case(input, ord, dim, keepdim, should_error): S = 10 test_cases = [ # input size, p settings that cause error, dim - ((0, 0), [1, 2, inf, -1, -2, -inf], None), - ((0, S), [2, inf, -2, -inf], None), - ((S, 0), [1, 2, -1, -2], None), + ((0, 0), [-1, -2, -inf], None), + ((0, S), [-2, -inf], None), + ((S, 0), [-1, -2], None), ((S, S, 0), [], (0, 1)), ((1, S, 0), [], (0, 1)), - ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)), - ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)), + ((0, 0, S), [-1, -2, -inf], (0, 1)), + ((0, 0, S), [-1, -2, -inf], (1, 0)), ] for keepdim in [True, False]: @@ -2058,6 +2059,76 @@ def run_test_case(input, ord, dim, keepdim, should_error): for ord in ord_matrix: run_test_case(input, ord, dim, keepdim, ord in error_ords) + # TODO this is redundant with test_norm_matrix_degenerate_shapes above, + # remove when old numpy versions are dropped + @skipIf(np.lib.NumpyVersion(np.__version__) >= '2.3.0', 'Numpy changed handling of degenerate inputs in 2.3.0') + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_norm_matrix_degenerate_shapes_old_numpy(self, device, dtype): + def run_test_case(input, ord, dim, keepdim, should_error): + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + input_numpy = input.cpu().numpy() + ops = [torch.linalg.norm] + + if ord is not None and dim is not None: + ops.append(torch.linalg.matrix_norm) + + if should_error == 'both': + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + for op in ops: + with self.assertRaises(IndexError): + op(input, ord, dim, keepdim) + elif should_error == 'np_only': + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + for op in ops: + result = op(input, ord, dim, keepdim) + dim_ = dim + if dim_ is None: + dim_ = (0, 1) + expected_shape = list(input.shape) + if keepdim: + expected_shape[dim_[0]] = 1 + expected_shape[dim_[1]] = 1 + else: + del expected_shape[max(dim_)] + del expected_shape[min(dim_)] + expected = torch.zeros(expected_shape, dtype=dtype.to_real()) + self.assertEqual(expected, result, msg=msg) + else: + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + for op in ops: + result = op(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + + ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None] + S = 10 + test_cases = [ + # input size, p settings that cause error, + # p settings that error numpy but not torch, dim + ((0, 0), [-1, -2, -inf], [inf, 1, 2], None), + ((0, S), [-2, -inf], [inf, 2], None), + ((S, 0), [-1, -2], [1, 2], None), + ((S, S, 0), [], [], (0, 1)), + ((1, S, 0), [], [], (0, 1)), + ((0, 0, S), [-1, -2, -inf], [inf, 1, 2], (0, 1)), + ((0, 0, S), [-1, -2, -inf], [inf, 1, 2], (1, 0)), + ] + + for keepdim in [True, False]: + for 
input_size, error_ords, np_error_ords, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_matrix: + if ord in error_ords: + should_error = 'both' + elif ord in np_error_ords: + should_error = 'np_only' + else: + should_error = 'no' + run_test_case(input, ord, dim, keepdim, should_error) + def test_norm_fastpaths(self, device): x = torch.randn(3, 5, device=device) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 7a6585f3b63a8..ec1fc41547f83 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -8,6 +8,7 @@ from collections.abc import Callable import torch +import torch.nn.functional as F from torch.quantization._quantized_conversions import ( pack_int4_to_int8, @@ -404,7 +405,7 @@ def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major, dtype): b.requires_grad_(True) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.t(), offs=offs, out_dtype=dtype) gO = torch.rand_like(out) out.backward(gO) @@ -456,7 +457,7 @@ def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major, dtype): if check_zero_size: offs[0] = offs[1] - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) gO = torch.rand_like(out) if not check_zero_size: @@ -501,7 +502,7 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): b_contig = b if b_row_major else b.transpose(-2, -1) self.assertTrue(b_contig.is_contiguous() is not strided) - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.transpose(-2, -1), out_dtype=dtype) gO = torch.rand_like(out) out.backward(gO) @@ -541,7 +542,7 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): if check_zero_size: offs[0] = offs[1] - f = torch._grouped_mm + f = F.grouped_mm out = f(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) gO = torch.rand_like(out) if not check_zero_size: @@ -559,7 +560,7 @@ def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): self.grouped_mm_helper(a, blist, gOlist, agradlist, bgradlist, outlist) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") - # TODO(future PR): enable compile for torch._grouped_mm fallback path + # TODO(future PR): enable compile for torch.nn.functional.grouped_mm fallback path @unittest.skipIf(not SM90OrLater, "Grouped gemm with compile supported on SM90") @parametrize("op", ["2d/2d", "2d/3d", "3d/2d", "3d/3d"]) @parametrize("a_row_major", [False, True]) @@ -572,7 +573,7 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) align = 16 // dtype_AB.itemsize - f_ref = torch._grouped_mm + f_ref = F.grouped_mm options = {} if max_autotune: diff --git a/test/test_mps.py b/test/test_mps.py index a84ac7d355169..51f2637e4d55e 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9667,10 +9667,12 @@ def get_mps_memory_usage(): memory_footprints = [] for _ in range(100): output = F.scaled_dot_product_attention(query, key, value) + # syncronize to wait for the GPU computation to return + torch.mps.synchronize() current_mem, driver_mem = get_mps_memory_usage() memory_footprints.append((current_mem, driver_mem)) - # 5 MB different maximum allowed value(could be decreased even more) - torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=5, rtol=1) + # 1 kB different maximum allowed value + torch.testing.assert_close(memory_footprints[-1], memory_footprints[0], atol=1e-3, 
rtol=1e-3) def generate_qkv(self, batch: int, NH: int, q_len: int, s_len: int, head_dim: int, layout: str, dtype: torch.dtype): if layout == "contiguous": diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index c30ace4a70f5f..6ed34f2559a18 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -301,12 +301,18 @@ def test_from_numpy_no_leak_on_invalid_dtype(self): # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary # object. See https://github.com/pytorch/pytorch/issues/121138 x = np.array(b"value") + initial_refcount = sys.getrefcount(x) for _ in range(1000): try: torch.from_numpy(x) except TypeError: pass - self.assertTrue(sys.getrefcount(x) == 2) + final_refcount = sys.getrefcount(x) + self.assertEqual( + final_refcount, + initial_refcount, + f"Memory leak detected: refcount increased from {initial_refcount} to {final_refcount}", + ) @skipIfTorchDynamo("No need to test invalid dtypes that should fail by design.") @onlyCPU diff --git a/test/test_optim.py b/test/test_optim.py index de185725b5c2c..973e6d6fe6845 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1993,7 +1993,7 @@ def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info): @optims(optim_db, dtypes=[torch.float32]) def test_step_post_hook(self, device, dtype, optim_info): - def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def post_hook(opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any]): nonlocal data data += 2 @@ -2025,7 +2025,7 @@ def dummy_closure(): @optims(optim_db, dtypes=[torch.float32]) def test_step_pre_hook(self, device, dtype, optim_info): - def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def pre_hook(opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any]): nonlocal data data += 2 @@ -2058,19 +2058,27 @@ def dummy_closure(): @optims(optim_db, dtypes=[torch.float32]) def test_step_all_hooks(self, device, dtype, optim_info): - def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def global_pre_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(0) - def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def global_post_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(5) - def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def local_pre_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(1) - def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): + def local_post_hook( + opt: Optimizer, args: tuple[Any, ...], kwargs: dict[Any, Any] + ): nonlocal data data.append(2) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 515ce435b72a7..359236602e61e 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -299,8 +299,12 @@ def test_finalizer(self): lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 lib.define("foo123(Tensor x) -> Tensor") - # 1 for `lib`, 1 for sys.getrefcount - self.assertEqual(sys.getrefcount(lib), 2) + # 1 for `lib`, 1 for sys.getrefcount' for previous python version (<=3.12) + # In Python 3.13+, sys.getrefcount() was optimized to not create + # a temporary reference, so expected counts are 1 less than before + expected_refcount = 1 if sys.version_info >= (3, 14) else 2 + self.assertEqual(sys.getrefcount(lib), 
expected_refcount) + # We gained an additional reference that gets cleared when the finalizer runs self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1) # 1 for `lib` @@ -318,7 +322,7 @@ def foo123(x): saved_op_impls = lib._op_impls # del will definitely work if the following passes - self.assertEqual(sys.getrefcount(lib), 2) + self.assertEqual(sys.getrefcount(lib), expected_refcount) del lib # 1 for saved_op_impls @@ -326,7 +330,7 @@ def foo123(x): # This function should be the last user of lib._op_impls: # - lib should not have a reference anymore (it was del'ed) # - lib's finalizer should not have a reference anymore - self.assertEqual(sys.getrefcount(saved_op_impls), 2) + self.assertEqual(sys.getrefcount(saved_op_impls), expected_refcount) self.assertTrue(key not in torch.library._impls) diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 94d6ece0f6369..25c4efe35a1ab 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -11,7 +11,14 @@ import torch -from torch.nn.functional import pad, scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType +from torch.nn.functional import ( + grouped_mm, + pad, + scaled_mm, + scaled_grouped_mm, + ScalingType, + SwizzleType, +) from torch.testing._internal.common_cuda import ( IS_SM90, _get_torch_cuda_version, @@ -785,7 +792,7 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): ) # bf16 reference output - y_bf16 = torch._grouped_mm( + y_bf16 = grouped_mm( # Note: Reference result should be on reconstructed, not original values. # as-in float(fp4(t)) not t itself. xh, wh.t(), offs=input_group_end_offsets, out_dtype=torch.bfloat16 @@ -931,7 +938,7 @@ def _2d_to_blocked_scaled(X, K, G, offs, format): # Compute reference bf16 grouped gemm. # Note: Reference result should be on reconstructed, not original values. # as-in float(fp4(t)) not t itself. 
- y_bf16 = torch._grouped_mm( + y_bf16 = grouped_mm( xh, wh.transpose(-2, -1), offs=input_group_end_offsets, diff --git a/test/test_sparse.py b/test/test_sparse.py index 5ac9b1542aa72..42ebfbff83337 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -1001,7 +1001,6 @@ def test_shape(sparse_dims, nnz, with_size): @coalescedonoff @dtypes(torch.double, torch.cdouble) @dtypesIfMPS(torch.float32, torch.complex64) - @expectedFailureMPS @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") @gradcheck_semantics() def test_permute(self, device, dtype, coalesced, gradcheck): @@ -1041,7 +1040,8 @@ def test_shape(sparse_dims, nnz, with_size): else: self.assertFalse(s_permuted.is_coalesced()) - gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_()) + kwargs = {"eps": 1e-4} if device == "mps:0" else {} + gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_(), **kwargs) else: # otherwise check if exception is thrown fail_message = "transpositions between sparse and dense dimensions are not allowed" @@ -1686,7 +1686,6 @@ def fn(S, D1, D2, beta=beta, alpha=alpha): test_shape(7, 8, 9, 20, True, (1, 1)) @coalescedonoff - @expectedFailureMPS @dtypes(torch.double) @dtypesIfMPS(torch.float32) @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") @@ -1704,7 +1703,9 @@ def test_shape(d1, d2, d3, nnz, transposed): def fn(S, D): return torch.sparse.mm(S, D) - gradcheck(fn, (S, D), masked=True) + + kwargs = {"eps": 1e-4, "atol": 2e-5} if device == "mps:0" else {} + gradcheck(fn, (S, D), masked=True, **kwargs) test_shape(7, 8, 9, 20, False) test_shape(7, 8, 9, 20, True) @@ -3924,7 +3925,6 @@ def _test_mul_skips(self, device, dtype, coalesced): self.skipTest(f"Test with dtype={dtype}, device={device} runs only with coalesced inputs") @coalescedonoff - @expectedFailureMPS # NOTE: addcmul_out is not implemented for bool. 
@dtypes(*all_types_and_complex_and(torch.bfloat16, torch.float16)) @dtypesIfMPS(*all_mps_types()) diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 6284be2aebe9e..522a82cf9a222 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -357,6 +357,9 @@ def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], allowed_dtypes=(torch.cfloat, torch.cdouble)) + @toleranceOverride({ + torch.cfloat : tol(2e-4, 1.3e-6), + }) def test_reference_nd(self, device, dtype, op): if op.ref is None: raise unittest.SkipTest("No reference implementation") diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index cf2c836486c80..c8ad8276a116b 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -891,7 +891,7 @@ def test_threshold(x, y): torch.manual_seed(0) for torch_fn, dev, data_type in fn_dev_dtype: - if torch_fn == test_lgamma and dev == "cuda": + if torch_fn is test_lgamma and dev == "cuda": # lgamma_cuda does not support BF16 continue rand_a = torch.rand(1024, dtype=data_type, device=dev) diff --git a/test/test_torch.py b/test/test_torch.py index 01c6fb39a5a2a..66b2002a36d1e 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2987,7 +2987,7 @@ def filter_shape(shape, dim): t_np = t.cpu().numpy() actual = torch.gradient(t, spacing=spacing, dim=dims, edge_order=edge_order) - if space_fn == create_coordinate_tensors and spacing[0].device != 'cpu': + if space_fn is create_coordinate_tensors and spacing[0].device != 'cpu': spacing = [space.cpu().detach().numpy() for space in spacing] expected = np.gradient(t_np, *self._wrap_to_list(spacing), axis=dims, edge_order=edge_order) actual, expected = self._inf_nan_preprocess(list(actual), self._wrap_to_list(expected)) @@ -7464,6 +7464,9 @@ def test_parsing_intlist(self): "missing 1 required positional arguments", lambda: torch.tensor().new_zeros((5, 5), 0)) + # ensure ones() throws an error when extra positional (non-keyword) arguments are given. 
+ self.assertRaises(TypeError, lambda: torch.ones((3, 3), torch.float32)) + def test_from_buffer(self): a = bytearray([1, 2, 3, 4]) self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4]) diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py index 1a8371eee6345..3b3186a157895 100644 --- a/test/test_varlen_attention.py +++ b/test/test_varlen_attention.py @@ -110,6 +110,16 @@ def forward_sdpa( return self.out_proj(attn_out) +def pack_sequences(seqs, device): + x_packed = torch.cat(seqs, dim=0) + seq_lens = torch.tensor([len(s) for s in seqs], device=device) + cu_seq = torch.zeros(len(seqs) + 1, device=device, dtype=torch.int32) + cu_seq[1:] = seq_lens.cumsum(0) + max_len = seq_lens.max().item() + + return x_packed, cu_seq, max_len + + def create_variable_length_batch( shape: VarlenShape, device: torch.device, dtype: torch.dtype ): @@ -119,16 +129,15 @@ def create_variable_length_batch( seq_lengths.append(min(length, shape.max_seq_len)) seq_lengths = torch.tensor(seq_lengths, device=device) - total_tokens = seq_lengths.sum().item() - - x_packed = torch.randn( - total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True - ) - cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32) - cu_seq[1:] = seq_lengths.cumsum(0) + sequences = [ + torch.randn( + seq_len, shape.embed_dim, device=device, dtype=dtype, requires_grad=True + ) + for seq_len in seq_lengths + ] - max_len = seq_lengths.max().item() + x_packed, cu_seq, max_len = pack_sequences(sequences, device) x_padded = torch.zeros( shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype ) @@ -146,7 +155,6 @@ def create_variable_length_batch( "x_packed": x_packed, "x_padded": x_padded, "max_len": max_len, - "total_tokens": total_tokens, } @@ -428,6 +436,133 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal): start_idx = end_idx + @skipIfRocm(msg="ROCM does not support variable length attention") + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + @parametrize("is_causal", [False, True]) + @parametrize("num_perms", [1, 3, 5]) + def test_batch_invariance(self, device, dtype, is_causal, num_perms): + torch.manual_seed(42) + + batch_size, max_seq_len = 4, 128 + + seq_lengths = [] + for _ in range(batch_size): + length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64 + seq_lengths.append(min(length, max_seq_len)) + + sequences_qkv = [ + [ + torch.testing.make_tensor( + (seq_len, 2, 128), device=device, dtype=dtype, requires_grad=True + ) + for _ in range(3) + ] + for seq_len in seq_lengths + ] + sequences_q, sequences_k, sequences_v = map(list, zip(*sequences_qkv)) + + q_packed_orig = torch.cat(sequences_q, dim=0) + k_packed_orig = torch.cat(sequences_k, dim=0) + v_packed_orig = torch.cat(sequences_v, dim=0) + + seq_lens = torch.tensor(seq_lengths, device=device) + cu_seq_orig = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cu_seq_orig[1:] = seq_lens.cumsum(0) + + original_output = varlen_attn( + q_packed_orig, + k_packed_orig, + v_packed_orig, + cu_seq_orig, + cu_seq_orig, + max_seq_len, + max_seq_len, + is_causal, + ) + + original_grad_out = torch.randn_like(original_output) + original_grads = torch.autograd.grad( + outputs=original_output, + inputs=[q_packed_orig, k_packed_orig, v_packed_orig], + grad_outputs=original_grad_out, + ) + + for _ in range(num_perms): + perm = torch.randperm(batch_size) + 
permuted_sequences_q = [sequences_q[perm[i]] for i in range(batch_size)] + permuted_sequences_k = [sequences_k[perm[i]] for i in range(batch_size)] + permuted_sequences_v = [sequences_v[perm[i]] for i in range(batch_size)] + + q_packed_perm = torch.cat(permuted_sequences_q, dim=0) + k_packed_perm = torch.cat(permuted_sequences_k, dim=0) + v_packed_perm = torch.cat(permuted_sequences_v, dim=0) + + permuted_seq_lens = torch.tensor( + [seq_lengths[perm[i]] for i in range(batch_size)], device=device + ) + cu_seq_perm = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + cu_seq_perm[1:] = permuted_seq_lens.cumsum(0) + + permuted_output = varlen_attn( + q_packed_perm, + k_packed_perm, + v_packed_perm, + cu_seq_perm, + cu_seq_perm, + max_seq_len, + max_seq_len, + is_causal, + ) + + for i in range(batch_size): + orig_idx = perm[i].item() + + orig_start = cu_seq_orig[orig_idx].item() + orig_end = cu_seq_orig[orig_idx + 1].item() + orig_seq_output = original_output[orig_start:orig_end] + + perm_start = cu_seq_perm[i].item() + perm_end = cu_seq_perm[i + 1].item() + perm_seq_output = permuted_output[perm_start:perm_end] + + self.assertEqual(orig_seq_output, perm_seq_output) + + permuted_grad_out = torch.zeros_like(permuted_output) + for i in range(batch_size): + orig_idx = perm[i].item() + orig_start = cu_seq_orig[orig_idx].item() + orig_end = cu_seq_orig[orig_idx + 1].item() + + perm_start = cu_seq_perm[i].item() + perm_end = cu_seq_perm[i + 1].item() + + permuted_grad_out[perm_start:perm_end] = original_grad_out[ + orig_start:orig_end + ] + + permuted_grads = torch.autograd.grad( + outputs=permuted_output, + inputs=[q_packed_perm, k_packed_perm, v_packed_perm], + grad_outputs=permuted_grad_out, + ) + + for original_grad, permuted_grad in zip(original_grads, permuted_grads): + for i in range(batch_size): + orig_idx = perm[i].item() + + orig_start = cu_seq_orig[orig_idx].item() + orig_end = cu_seq_orig[orig_idx + 1].item() + orig_seq_grad = original_grad[orig_start:orig_end] + + perm_start = cu_seq_perm[i].item() + perm_end = cu_seq_perm[i + 1].item() + perm_seq_grad = permuted_grad[perm_start:perm_end] + + self.assertEqual(orig_seq_grad, perm_seq_grad) + device_types = ("cuda",) diff --git a/test/test_view_ops.py b/test/test_view_ops.py index 980439b7a6967..2a127241c49bd 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -1183,6 +1183,8 @@ def test_reshape(self, device): self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) + # ensure reshape() throws an error if extra positional arguments are given. + self.assertRaises(TypeError, lambda: x.reshape((9,), torch.float32)) y = torch.randn(4, 4, 4, device=device)[:, 0, :] # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape @@ -1726,6 +1728,9 @@ def can_broadcast(s0, s1): r"must match the existing size \(\d\)", ): torch.broadcast_to(t, s1) + # ensure broadcast_to() throws an error when extra positional arguments are given. + t = torch.tensor([1, 2, 3]) + self.assertRaises(TypeError, lambda: t.broadcast_to((3, 3), torch.float32)) def test_view(self, device): tensor = torch.rand(15, device=device) @@ -1812,6 +1817,11 @@ def test_view(self, device): self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) + # ensure view() throws an error if extra positional arguments are given. 
+ self.assertRaises( + TypeError, lambda: tensor.view((tensor.numel(),), torch.float32) + ) + @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) def test_reshape_view_semantics(self, device, dtype): tensor = make_tensor((15, 4), dtype=dtype, device=device) diff --git a/test/torch_np/numpy_tests/core/test_numeric.py b/test/torch_np/numpy_tests/core/test_numeric.py index c6b2d14aef6dc..209119ee8f012 100644 --- a/test/torch_np/numpy_tests/core/test_numeric.py +++ b/test/torch_np/numpy_tests/core/test_numeric.py @@ -2382,7 +2382,7 @@ def test_dtype_str_bytes(self, likefunc, dtype): # Regression test for gh-19860 a = np.arange(16).reshape(2, 8) b = a[:, ::2] # Ensure b is not contiguous. - kwargs = {"fill_value": ""} if likefunc == np.full_like else {} + kwargs = {"fill_value": ""} if likefunc is np.full_like else {} result = likefunc(b, dtype=dtype, **kwargs) if dtype is str: assert result.strides == (16, 4) diff --git a/third_party/cpuinfo b/third_party/cpuinfo index 5e3d2445e6a84..f858c30bcb16f 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit 5e3d2445e6a84d9599bee2bf78edbb4d80865e1d +Subproject commit f858c30bcb16f8effd5ff46996f0514539e17abc diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4796153f24f05..e1a518aca6704 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -421,7 +421,7 @@ # inplace or out-variants) # If the function does not modify its arguments, we also check the following properties # pertaining to its output: -# 2) Its TensorImpl has use_count of 1 +# 2) Its TensorImpl has use_count of 1 (or 2 if it has a PyObject) # 3) If the function is a view function, it has the same StorageImpl as that of # the input it is aliased with. 
Otherwise, its StorageImpl has use_count of 1 # @@ -496,10 +496,10 @@ """ ) -ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate( +ENFORCE_TENSOR_IMPL_USE_COUNT = CodeTemplate( """\ if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) - TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}"); + TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() == expected_fresh_use_count(${tensor_name}), "function: ${fn_name}"); """ ) @@ -1664,7 +1664,7 @@ def check_tensorimpl_and_storage( if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: stmts_after_call += [ - ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute( + ENFORCE_TENSOR_IMPL_USE_COUNT.substitute( tensor_name=ret_name, fn_name=type_wrapper_name(f) ) ] diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 23976a48473a3..d1de108283b11 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -47,6 +47,18 @@ namespace{ meta->grad_accumulator_.reset(); } } +[[maybe_unused]] size_t expected_fresh_use_count(const Variable& self) { + if (!self.defined()) { + // An UndefinedTensorImpl always has a use count of 0 + return 0; + } + if (self.unsafeGetTensorImpl()->pyobj_slot()->load_pyobj() != nullptr) { + // A TensorImpl with a Python object has a use count of 2 + return 2; + } + // A fresh TensorImpl (with no PyObject) has a use count of 1 + return 1; +} } namespace { diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index d1004cdc3a955..ff19dadcfdc02 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -119,6 +119,7 @@ def get_torch_version(sha: str | None = None) -> str: ) parser.add_argument("--cuda-version", "--cuda_version", type=str) parser.add_argument("--hip-version", "--hip_version", type=str) + parser.add_argument("--rocm-version", "--rocm_version", type=str) parser.add_argument("--xpu-version", "--xpu_version", type=str) args = parser.parse_args() @@ -126,6 +127,7 @@ def get_torch_version(sha: str | None = None) -> str: assert args.is_debug is not None args.cuda_version = None if args.cuda_version == "" else args.cuda_version args.hip_version = None if args.hip_version == "" else args.hip_version + args.rocm_version = None if args.rocm_version == "" else args.rocm_version args.xpu_version = None if args.xpu_version == "" else args.xpu_version pytorch_root = Path(__file__).parent.parent @@ -141,7 +143,7 @@ def get_torch_version(sha: str | None = None) -> str: with open(version_path, "w") as f: f.write("from typing import Optional\n\n") f.write( - "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n" + "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'rocm', 'xpu']\n" ) f.write(f"__version__ = '{version}'\n") # NB: This is not 100% accurate, because you could have built the @@ -151,4 +153,5 @@ def get_torch_version(sha: str | None = None) -> str: f.write(f"cuda: Optional[str] = {repr(args.cuda_version)}\n") f.write(f"git_version = {repr(sha)}\n") f.write(f"hip: Optional[str] = {repr(args.hip_version)}\n") + f.write(f"rocm: Optional[str] = {repr(args.rocm_version)}\n") f.write(f"xpu: Optional[str] = {repr(args.xpu_version)}\n") diff --git a/tools/linter/adapters/pylint_linter.py b/tools/linter/adapters/pylint_linter.py new file mode 100644 index 0000000000000..c4051272b1df6 --- /dev/null +++ b/tools/linter/adapters/pylint_linter.py @@ -0,0 +1,192 @@ +from __future__ 
import annotations + +import argparse +import json +import logging +import os +import subprocess +import sys +import time +from enum import Enum +from pathlib import Path +from typing import NamedTuple + + +class LintSeverity(str, Enum): + ERROR = "error" + WARNING = "warning" + DISABLED = "disabled" + + +class LintMessage(NamedTuple): + path: str | None + line: int | None + char: int | None + code: str + severity: LintSeverity + name: str + original: str | None + replacement: str | None + description: str | None + + +def run_command( + args: list[str], +) -> subprocess.CompletedProcess[bytes]: + logging.debug("$ %s", " ".join(args)) + start_time = time.monotonic() + try: + return subprocess.run( + args, + capture_output=True, + check=False, + ) + finally: + end_time = time.monotonic() + logging.debug("took %dms", (end_time - start_time) * 1000) + + +def check_pylint_installed(code: str) -> list[LintMessage]: + cmd = [sys.executable, "-mpylint", "--version"] + try: + subprocess.run(cmd, check=True, capture_output=True) + return [] + except subprocess.CalledProcessError as e: + msg = e.stderr.decode(errors="replace") + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=f"Could not run '{' '.join(cmd)}': {msg}", + ) + ] + + +def in_github_actions() -> bool: + return bool(os.getenv("GITHUB_ACTIONS")) + + +def check_files( + filenames: list[str], + config: str, + code: str, +) -> list[LintMessage]: + try: + proc = run_command( + ["pylint", f"--rcfile={config}", "-f", "json"] + filenames, + ) + except OSError as err: + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=(f"Failed due to {err.__class__.__name__}:\n{err}"), + ) + ] + if proc.returncode == 32: + stderr = str(proc.stderr, "utf-8").strip() + return [ + LintMessage( + path=None, + line=None, + char=None, + code=code, + severity=LintSeverity.ERROR, + name="command-failed", + original=None, + replacement=None, + description=stderr, + ) + ] + stdout = str(proc.stdout, "utf-8").strip() + errors = json.loads(stdout) + + return [ + LintMessage( + path=error["path"], + name=error["message-id"], + description=error["message"], + line=int(error["line"]), + char=int(error["column"]), + code=code, + severity=LintSeverity.ERROR, + original=None, + replacement=None, + ) + for error in errors + ] + + +def main() -> None: + parser = argparse.ArgumentParser( + description="pylint wrapper linter.", + fromfile_prefix_chars="@", + ) + parser.add_argument( + "--config", + required=True, + help="path to a pylintrc config file", + ) + parser.add_argument( + "--code", + default="PYLINT", + help="the code this lint should report as", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="verbose logging", + ) + parser.add_argument( + "filenames", + nargs="+", + help="paths to lint", + ) + args = parser.parse_args() + + logging.basicConfig( + format="<%(threadName)s:%(levelname)s> %(message)s", + level=logging.NOTSET + if args.verbose + else logging.DEBUG + if len(args.filenames) < 1000 + else logging.INFO, + stream=sys.stderr, + ) + + filenames: set[str] = set() + + # If a stub file exists, have pylint check it instead of the original file, in + # accordance with PEP-484 (see https://www.python.org/dev/peps/pep-0484/#stub-files) + for filename in args.filenames: + if 
filename.endswith(".pyi"): + filenames.add(filename) + continue + + stub_filename = filename.replace(".py", ".pyi") + if Path(stub_filename).exists(): + filenames.add(stub_filename) + else: + filenames.add(filename) + + lint_messages = check_pylint_installed(args.code) + check_files( + list(filenames), args.config, args.code + ) + for lint_message in lint_messages: + print(json.dumps(lint_message._asdict()), flush=True) + + +if __name__ == "__main__": + main() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index d92b9e19a76c5..c7a43f30e49d5 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -490,7 +490,8 @@ add_custom_target( "${Python_EXECUTABLE}" "${TOOLS_PATH}/generate_torch_version.py" --is-debug=${TORCH_VERSION_DEBUG} --cuda-version=${CUDA_VERSION} - --hip-version=${ROCM_VERSION_DEV} + --hip-version=${HIP_VERSION_CLEAN} + --rocm-version=${ROCM_VERSION_DEV} --xpu-version=${SYCL_COMPILER_VERSION} BYPRODUCTS ${TORCH_SRC_DIR}/version.py COMMENT "Regenerating version file..." diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 752bd594d066f..477b35b1811e4 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -100,7 +100,9 @@ class Logger: def _set_static_graph(self) -> None: ... class _WorkerServer: - def __init__(self, socket_path: str) -> None: ... + port: int + + def __init__(self, host_or_file: str, port: int = ...) -> None: ... def shutdown(self) -> None: ... def get_debug_level(): ... @@ -206,6 +208,7 @@ class Store: desired_value: str, ) -> bytes: ... def delete_key(self, key: str) -> bool: ... + def multi_get(self, keys: list[str]) -> list[bytes]: ... def num_keys(self) -> int: ... def set_timeout(self, timeout: timedelta): ... @overload @@ -872,3 +875,15 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... + +class _Request: + def body(self) -> bytes: ... + def get_param(self, str) -> str: ... + +class _Response: + def set_content(self, content: str | bytes, content_type: str) -> None: ... + def set_status(self, status: int) -> None: ... + +def _register_handler( + name: str, handler: Callable[[_Request, _Response], None] +) -> None: ... diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 117795db5ac3e..3c3a18ed4e063 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -1,5 +1,6 @@ import enum import types +from collections.abc import Callable from typing import Optional, overload from torch._dynamo.guards import GuardManagerWrapper @@ -27,6 +28,7 @@ class _CacheEntry: compile_id: CompileId # If we run into circular issues, just use object guard_manager: GuardManagerWrapper + backend: Callable next: _CacheEntry | None class _PrecompileEntry: diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index d60d89a6a4796..de12af50c1855 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -60,6 +60,7 @@ class _ExperimentalConfig: verbose: bool = ..., performance_events: list[str] = ..., enable_cuda_sync_events: bool = ..., + profile_all_threads: bool = ..., ) -> None: ... 
class ProfilerConfig: diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 1c44de2c1ad1e..87dc80e99bd79 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -99,6 +99,7 @@ always_optimize_code_objects, Constraint, dynamo_tls, + innermost_fn, skip_code, TorchPatcher, ) @@ -933,7 +934,11 @@ class GraphRuntimeEnv: argdefs: Optional[tuple[Any, ...]] def forward_callable( - self, backend_id: str, compiled_fn: Callable[..., Any] + self, + backend_id: str, + compiled_fn: Callable[..., Any], + *, + extra_globals: Optional[dict[str, Any]] = None, ) -> Callable[..., Any]: import_sources = { alias: importlib.import_module(module_name) @@ -942,6 +947,7 @@ def forward_callable( f_globals = { **import_sources, **self.used_globals, + **(extra_globals or {}), backend_id: compiled_fn, } return types.FunctionType( @@ -1026,13 +1032,16 @@ def forward_callable( self, *, compiled_fn: Optional[Callable[..., Any]] = None, + extra_globals: Optional[dict[str, Any]] = None, ) -> Callable[..., Any]: runtime_env = self.graph_capture_output.get_runtime_env() assert self.backend_input is not None backend_id = self.backend_input.backend_id # pyrefly: ignore [not-callable] compiled_fn = compiled_fn or self.backend_input.graph_module - return runtime_env.forward_callable(backend_id, compiled_fn) + return runtime_env.forward_callable( + backend_id, compiled_fn, extra_globals=extra_globals + ) def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]: @@ -1566,7 +1575,9 @@ def count_args(code: CodeType) -> int: # Check recompilations recompile_reason: Optional[str] = None if is_recompilation(cache_size) and frame: - reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame) + reasons = get_and_maybe_log_recompilation_reasons( + cache_entry, frame, innermost_fn(compiler_fn) + ) recompile_reason = ( "Unable to find recompilation reasons" if not reasons else reasons[0] ) @@ -1574,7 +1585,7 @@ def count_args(code: CodeType) -> int: inline_inbuilt_nn_modules_candidate = False if not config.inline_inbuilt_nn_modules and frame: inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons( - cache_entry, frame, skip_logging=True + cache_entry, frame, innermost_fn(compiler_fn), skip_logging=True ) inbuilt_nn_recompile_reason = ( None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0] diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 0956facde2559..4253fa031d2ec 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -251,7 +251,7 @@ def fail_callback( cache_entries = _debug_get_cache_entry_list(frame.f_code) if cache_entries: reasons = get_and_maybe_log_recompilation_reasons( - cache_entries[0], frame, skip_logging=True + cache_entries[0], frame, innermost_fn(callback), skip_logging=True ) if reasons: failures = textwrap.indent("\n".join(reasons), "- ") @@ -787,6 +787,11 @@ def get_compiler_config() -> Any: def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: from torch._dynamo.aot_compile import aot_compile_fullgraph + if torch._inductor.config.force_disable_caches: + raise RuntimeError( + "Cannot precompile with torch._inductor.config.force_disable_caches=True; caching is required." + ) + if not self.fullgraph: raise RuntimeError( "Graph breaks are not supported with aot compile. Please use torch.compile(fullgraph=True)." 
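The aot_compile change above fails fast when inductor caching is globally disabled. A minimal sketch of the resulting calling convention, assuming the aot_compile hook defined in eval_frame.py is reachable from the object returned by torch.compile; only the force_disable_caches flag, the fullgraph requirement, and the (args, kwargs) input shape come from the diff, and the toy function is illustrative.

import torch
import torch._inductor.config

def fn(x):
    return torch.sin(x) + 1

# aot_compile is rejected when inductor caching is globally disabled,
# so check the flag before doing any precompilation work.
assert not torch._inductor.config.force_disable_caches
compiled = torch.compile(fn, fullgraph=True)  # graph breaks are not supported
artifact = compiled.aot_compile(((torch.randn(8),), {}))  # expects an (args, kwargs) tuple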
diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index f11c78bdaa49e..5b0e8a402dd96 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -198,24 +198,20 @@ class RecompileError(TorchDynamoException): class ArgsMismatchError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class AttributeMutationError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class InfiniteGeneratorError(Unsupported): # Raised when the number of yielded values is greater than MAX_ITERATOR_LIMIT - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class SideEffectsError(Unsupported): - def __init__(self, msg: str) -> None: - super().__init__(msg) + pass class CondOpArgsMismatchError(ArgsMismatchError): @@ -223,9 +219,6 @@ class CondOpArgsMismatchError(ArgsMismatchError): Internal error from cond() due to arguments mismatch. """ - def __init__(self, msg: str) -> None: - super().__init__(msg) - class UserErrorType(Enum): DYNAMIC_CONTROL_FLOW = auto() diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 84641d66c6bd8..548a4b279b860 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -501,6 +501,7 @@ def pytreeify( torch._dynamo.eval_frame.check_user_input_output( flat_real_args[1 if root else 0 :], UserErrorType.INVALID_INPUT ) + f_globals = out.graph_capture_output.f_globals class Yield(Exception): pass @@ -522,7 +523,9 @@ def backend_dummy(*example_inputs): raise Yield try: - out.forward_callable(compiled_fn=backend_dummy)(*args, **kwargs) + out.forward_callable( + compiled_fn=backend_dummy, extra_globals=f_globals + )(*args, **kwargs) except Yield: assert self.gm_inputs is not None return self.gm_inputs @@ -557,7 +560,9 @@ def backend_dummy(*example_inputs): for i in range(self.num_outputs) ] - results = out.forward_callable(compiled_fn=backend_dummy)(*args, **kwargs) + results = out.forward_callable( + compiled_fn=backend_dummy, extra_globals=f_globals + )(*args, **kwargs) ret, self.out_spec = pytree.tree_flatten(results) return ret @@ -606,7 +611,6 @@ def dynamo_graph_capture_for_export( def inner(*args: Any, **kwargs: Any) -> Any: assert not torch._dynamo.config.install_free_tensors with ( - torch._dynamo.config.patch(replay_side_effects=False), torch._dynamo.config.patch(side_effect_replay_policy="warn"), get_metrics_context(), dynamo_timed("fullgraph_capture"), diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index a75118f9e5032..cf621921cd59b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -130,6 +130,7 @@ ChainedSource, ClosureSource, CodeSource, + CollectionsSource, ConstantSource, ConstDictKeySource, CurrentStreamSource, @@ -1442,6 +1443,13 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, CollectionsSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: collections, + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, TorchFunctionModeStackSource): out = root_guard_manager.lambda_manager( python_lambda=lambda _: get_torch_function_mode_stack_at( @@ -3677,6 +3685,7 @@ def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry: self.guard_manager, output_graph.local_scope, CompileContext.current_compile_id(), + backend=None, # no need to set this because we are trying to find the 
offending guard entry ) raise AssertionError( "Guard failed on the same frame it was created. This is a bug - please create an issue." @@ -4294,6 +4303,7 @@ def get_guard_fail_reason_helper( guard_manager: GuardManagerWrapper, f_locals: dict[str, object], compile_id: Optional[CompileId], + backend: Optional[Callable], ) -> str: """ Return the reason why `guard_manager` failed. @@ -4306,6 +4316,10 @@ def get_guard_fail_reason_helper( scope.update(guard_manager.closure_vars) reasons: list[str] = [] + cache_entry_backend = None + if guard_manager.cache_entry: + cache_entry_backend = guard_manager.cache_entry.backend + no_tensor_aliasing_check_failed = False verbose_code_parts: list[str] = [] @@ -4328,6 +4342,24 @@ def get_guard_fail_reason_helper( else: reasons = verbose_code_parts verbose_code_parts = [] + elif cache_entry_backend != backend: + # None of the guard entries failed - a backend match issue + reason = ( + "BACKEND_MATCH failure: torch.compile detected different backend callables." + " If this is unexpected, wrap your backend in functools.partial (or reuse the" + " same cached backend) to avoid creating a new backend function each time." + " More details: https://github.com/pytorch/pytorch/issues/168373" + ) + reasons.append(reason) + else: + # Unexpected recompilation - points to a bug + reason = ( + "Unexpected recompilation: runtime guards failed even though they passed" + " during recompilation-reason analysis." + " Please open an issue with a minimal repro:" + " https://github.com/pytorch/pytorch" + ) + reasons.append(reason) if no_tensor_aliasing_check_failed: reasons = recompilation_reason_for_no_tensor_aliasing_guard( @@ -4364,11 +4396,14 @@ def get_guard_fail_reason( code: types.CodeType, f_locals: dict[str, object], compile_id: CompileId, + backend: Callable, skip_logging: bool = False, ) -> str: if isinstance(guard_manager, DeletedGuardManagerWrapper): return f"{compile_id}: {guard_manager.invalidation_reason}" - reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id) + reason_str = get_guard_fail_reason_helper( + guard_manager, f_locals, compile_id, backend + ) if skip_logging: return reason_str guard_failures[orig_code_map[code]].append(reason_str) @@ -4389,6 +4424,7 @@ def get_guard_fail_reason( def get_and_maybe_log_recompilation_reasons( cache_entry: Optional[CacheEntry], frame: DynamoFrameType, + backend: Callable, skip_logging: bool = False, ) -> list[str]: """ @@ -4403,6 +4439,7 @@ def get_and_maybe_log_recompilation_reasons( cache_entry.code, frame.f_locals, cache_entry.compile_id, + backend, skip_logging, ) if reason: diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 5be6b8ccbf41d..a5a69cd177c27 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -1003,6 +1003,33 @@ def guard_source(self) -> GuardSource: return GuardSource.GLOBAL +@dataclasses.dataclass(frozen=True) +class CollectionsSource(Source): + """Points to the actual `collections` module - used instead of GlobalSource + in case the user has overridden `collections` in their local namespace""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + from .guards import GuardBuilder, install_guard + + install_guard(self.make_guard(GuardBuilder.ID_MATCH)) + + def name(self) -> str: + return "__import__('collections')" + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.extend_output( + [ + codegen.create_load_const(0), # level + create_build_tuple(0), # fromlist + 
codegen.create_import_name("collections"), + ] + ) + + def guard_source(self) -> GuardSource: + return GuardSource.GLOBAL + + @dataclasses.dataclass(frozen=True) class TorchFunctionModeStackSource(Source): ind: int diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 18f053a2ca675..f401b9d6178b9 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2050,22 +2050,23 @@ def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None: def FOR_ITER(self, inst: Instruction) -> None: it = self.pop().realize() + self.push(it) try: val = it.next_variable(self) - self.push(it) self.push(val) except (StopIteration, exc.ObservedUserStopIteration) as e: if isinstance(e, exc.ObservedUserStopIteration): exc.handle_observed_exception(self) - # leave iterator upon exhaustion in 3.12 if sys.version_info >= (3, 12): # CPython 3.12 actually jumps to the instruction after the END_FOR # and performs the action of END_FOR as part of FOR_ITER. We jump # to the END_FOR and run it, so we need to make sure 2 values are # on the stack for it to pop. - self.push(it) self.push(ConstantVariable.create(None)) + else: + # pop the iterator in Python < 3.12 + self.pop() self.jump(inst) def _create_exception_type(self, val: VariableTracker) -> VariableTracker: @@ -4166,6 +4167,33 @@ def speculate(self) -> SpeculationEntry: self.instructions[self.instruction_pointer - 1], ) + def _make_frame_loc( + self, filename: str, lineno: Optional[int], fallback_lineno: int + ) -> tuple[str, int]: + if lineno is None or lineno < 0: + return (filename, fallback_lineno) + return (filename, lineno) + + def _get_frame_loc_chain( + self, frame_loc: tuple[str, int] + ) -> tuple[tuple[str, int], ...]: + frame_loc_chain_list: list[tuple[str, int]] = [] + + if config.nested_graph_breaks: + current_tx: Optional[InstructionTranslatorBase] = self.parent + while current_tx is not None: + parent_frame_loc = self._make_frame_loc( + current_tx.f_code.co_filename, + current_tx.lineno, + current_tx.f_code.co_firstlineno, + ) + frame_loc_chain_list.append(parent_frame_loc) + current_tx = current_tx.parent + + frame_loc_chain_list.reverse() + frame_loc_chain_list.append(frame_loc) + return tuple(frame_loc_chain_list) + def log_graph_break( self, code_options: dict[str, Any], @@ -4176,14 +4204,25 @@ def log_graph_break( user_stack = torch._guards.TracingContext.extract_stack() try: - frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + if config.nested_graph_breaks and self.parent is not None: + frame_loc = self._make_frame_loc( + self.f_code.co_filename, + self.lineno, + self.f_code.co_firstlineno, + ) + else: + frame_loc = self._make_frame_loc( + user_stack[-1].filename, + user_stack[-1].lineno, + 0, + ) except IndexError: # first instruction frame_loc = ( code_options["co_filename"], code_options["co_firstlineno"], ) - + frame_loc_chain = self._get_frame_loc_chain(frame_loc) stack_above_dynamo_formatted = "" if config.verbose: stack_above_dynamo = get_stack_above_dynamo() @@ -4228,7 +4267,7 @@ def log_graph_break( if ( graph_break_log.isEnabledFor(logging.DEBUG) and not explain - and graph_break_dup_warning_checker.add(frame_loc) + and graph_break_dup_warning_checker.add(frame_loc_chain) # type: ignore[arg-type] ): # This log line MUST contain the string "Graph break in user code", # This log line is exercised from diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 97a3946b48bde..083c8b1f93807 100644 --- a/torch/_dynamo/trace_rules.py +++ 
b/torch/_dynamo/trace_rules.py @@ -64,6 +64,8 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, ReparametrizeModuleCallVariable, SkipFunctionVariable, TorchInGraphFunctionVariable, @@ -378,6 +380,8 @@ f"torch/testing/_internal/distributed/_tensor/common_dtensor.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, "torch/testing/_internal/common_distributed.py#forward": UserFunctionVariable, f"torch/testing/_internal/common_distributed.py#{TORCH_DYNAMO_RESUME_IN_PREFIX}": UserFunctionVariable, + "torch.utils._pytree._get_node_type": PyTreeGetNodeTypeFunctionVariable, + "torch.utils._pytree.tree_is_leaf": PyTreeTreeIsLeafFunctionVariable, } diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 74165b30bb2f0..439ce274b7ce6 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -64,6 +64,8 @@ LocalGeneratorObjectVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, + PyTreeGetNodeTypeFunctionVariable, + PyTreeTreeIsLeafFunctionVariable, SkipFunctionVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 4e248320e60b6..617f787e43d8a 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -14,6 +14,7 @@ """ import collections +import logging from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView from enum import Enum from typing import Any, NoReturn, Optional, TYPE_CHECKING @@ -32,6 +33,11 @@ if TYPE_CHECKING: from ..codegen import PyCodegen from ..symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + from .functions import UserFunctionVariable + + +log = logging.getLogger(__name__) class SourceType(Enum): @@ -151,9 +157,6 @@ class AttributeMutation(MutationType): allows mutation on the value's attributes. 
""" - def __init__(self, typ: SourceType) -> None: - super().__init__(typ) - class AttributeMutationExisting(AttributeMutation): """ @@ -443,7 +446,7 @@ def force_apply_to_var_sequence( for v in self.unpack_var_sequence(tx): fn(v) - def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker": + def call_obj_hasattr(self, tx: Any, name: str) -> "ConstantVariable": unimplemented( gb_type="Unsupported hasattr call", context=f"call_obj_hasattr {self} {name}", @@ -559,6 +562,81 @@ def call_method( hints=hints, ) + def call_tree_map( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Performance optimization to implement optree.tree_map faster than tracing it""" + is_leaf_var = tree_map_kwargs.get("is_leaf") + if is_leaf_var is not None and not ( + is_leaf_var.is_python_constant() + and is_leaf_var.as_python_constant() is None + ): + pred_result = is_leaf_var.call_function(tx, [self], {}) + try: + leaf_decision = pred_result.as_python_constant() + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + if leaf_decision: + return map_fn.call_function(tx, [self, *rest], {}) + + return self.call_tree_map_branch( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def call_tree_map_branch( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above)""" + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _tree_map_fallback( + self, + tx: Any, + tree_map_fn: "UserFunctionVariable", + map_fn: "VariableTracker", + rest: Sequence["VariableTracker"], + tree_map_kwargs: dict[str, "VariableTracker"], + ) -> "VariableTracker": + tree_map_fn_copy = tree_map_fn.clone() + tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute] + log.debug( + "tree_map fastpath fallback triggered for %s (rest=%s, kwargs=%s)", + self, + rest, + tree_map_kwargs, + ) + return tree_map_fn_copy.call_function( + tx, + [map_fn, self, *rest], + tree_map_kwargs, + ) + def set_name_hint(self, name: str) -> None: pass diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 746db0f3dfd62..ae6678628634a 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2280,11 +2280,11 @@ def call_zip( "1 kwargs (`strict`)", f"{len(kwargs)} kwargs", ) - strict = kwargs.pop("strict", False) + strict = kwargs.pop("strict", ConstantVariable.create(False)) iter_args = [BuiltinVariable(iter).call_function(tx, [arg], {}) for arg in args] return variables.ZipVariable( iter_args, - strict=strict, # type: ignore[arg-type] + strict=strict.as_python_constant(), mutation_type=ValueMutationNew(), ) diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 1e886c6ee7ad7..672fa1d804383 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -8,7 +8,9 @@ import enum import operator -from typing import Any, Literal, Optional, TYPE_CHECKING, Union +from collections.abc import Sequence +from typing import Any, Literal, Optional, overload, TYPE_CHECKING, Union +from 
typing_extensions import override import torch from torch._dynamo.source import AttrSource, GetItemSource @@ -28,6 +30,8 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator + from .functions import UserFunctionVariable + class ConstantVariable(VariableTracker): """ @@ -38,6 +42,14 @@ class ConstantVariable(VariableTracker): nested collections. """ + @overload + @staticmethod + def create(value: bool) -> "ConstantVariable": ... + + @overload + @staticmethod + def create(value: Any, **kwargs: Any) -> VariableTracker: ... + @staticmethod def create(value: Any, **kwargs: Any) -> VariableTracker: """ @@ -53,10 +65,10 @@ def create(value: Any, **kwargs: Any) -> VariableTracker: # Routing for supported collection literals. if isinstance(value, set): items = [ConstantVariable.create(x) for x in value] - return variables.SetVariable(items, **kwargs) + return variables.SetVariable(items, **kwargs) # type: ignore[arg-type] elif isinstance(value, frozenset): items = [ConstantVariable.create(x) for x in value] - return variables.FrozensetVariable(items, **kwargs) + return variables.FrozensetVariable(items, **kwargs) # type: ignore[arg-type] elif isinstance(value, slice): slice_args = (value.start, value.stop, value.step) slice_args_vars = tuple(ConstantVariable.create(arg) for arg in slice_args) @@ -266,9 +278,65 @@ def call_method( ) return super().call_method(tx, name, args, kwargs) + def call_tree_map( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if self.value is None: + none_is_leaf_var = tree_map_kwargs.get("none_is_leaf") + if none_is_leaf_var is not None: + try: + none_is_leaf = bool(none_is_leaf_var.as_python_constant()) + except NotImplementedError: + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + else: + tree_map_module = getattr( + getattr(tree_map_fn, "fn", None), "__module__", "" + ) + # torch.utils._pytree and torch.utils._cxx_pytree treat None as a leaf + # by default, while optree keeps it as an internal node unless + # none_is_leaf=True is provided. 
+ none_is_leaf = not tree_map_module.startswith("optree") + if none_is_leaf: + return map_fn.call_function(tx, [self, *rest], {}) + else: + for other in rest: + if not ( + other.is_python_constant() + and other.as_python_constant() is None + ): + return self._tree_map_fallback( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + return self.clone() + if isinstance(self.value, (int, float, bool, complex, str, bytes, torch.dtype)): + return map_fn.call_function(tx, [self, *rest], {}) + return super().call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + @override def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> "ConstantVariable": result = hasattr(self.value, name) return variables.ConstantVariable.create(result) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 24cd5007da37d..7a74f487ff96c 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -23,7 +23,7 @@ import inspect import operator import types -from collections.abc import Hashable as py_Hashable +from collections.abc import Hashable as py_Hashable, Sequence from typing import Any, Optional, TYPE_CHECKING, Union from torch._subclasses.fake_tensor import is_fake @@ -51,6 +51,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .functions import UserFunctionVariable + # [Adding a new supported class within the keys of ConstDictVariable] # - Add its tracker type to is_hashable @@ -316,6 +318,56 @@ def __contains__(self, vt: VariableTracker) -> bool: and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) ) + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: "UserFunctionVariable", + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + other_dicts: list[ConstDictVariable] = [] + for candidate in rest: + candidate = candidate.realize() + if not isinstance(candidate, ConstDictVariable) or len( + candidate.items + ) != len(self.items): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_dicts.append(candidate) + + new_items_hashed = type(self.items)() + for key_tracker, value in self.items.items(): + sibling_leaves: list[VariableTracker] = [] + for candidate in other_dicts: + try: + sibling_leaves.append(candidate.items[key_tracker]) + except KeyError: + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + new_items_hashed[key_tracker] = value.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + + updated_original_items = { + key_tracker.vt: new_items_hashed[key_tracker] + for key_tracker in new_items_hashed + } + + return self.clone( + items=new_items_hashed, + original_items=updated_original_items, + should_reconstruct_all=True, + source=None, + mutation_type=ValueMutationNew(), + ) + def len(self) -> int: return sum( not isinstance(x, variables.DeletedVariable) for x in self.items.values() @@ -806,7 +858,7 @@ def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTrack def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: # dict not allow setting arbitrary attributes. 
OrderedDict and # defaultdict allow arbitrary setattr, but not deletion of default attrs if any( @@ -905,7 +957,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is types.MappingProxyType: return ConstantVariable.create(name in types.MappingProxyType.__dict__) return super().call_obj_hasattr(tx, name) @@ -1296,13 +1348,6 @@ def install_dict_contains_guard( class FrozensetVariable(SetVariable): - def __init__( - self, - items: list[VariableTracker], - **kwargs: Any, - ) -> None: - super().__init__(items, **kwargs) - def debug_repr(self) -> str: if not self.items: return "frozenset()" @@ -1360,13 +1405,6 @@ def call_method( class DictKeySetVariable(SetVariable): - def __init__( - self, - items: list[VariableTracker], - **kwargs: Any, - ) -> None: - super().__init__(items, **kwargs) - def debug_repr(self) -> str: if not self.items: return "dict_keys([])" @@ -1446,7 +1484,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: assert self.kv is not None if name in self.python_type().__dict__: return ConstantVariable.create(True) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index e30eeeb2c2fde..deee9bcec42de 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -29,6 +29,7 @@ import sys import traceback import types +from collections import namedtuple from collections.abc import Callable, Sequence from types import CellType, FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar @@ -38,6 +39,7 @@ import torch from torch._dynamo.exc import get_stack_above_dynamo from torch._guards import Source +from torch.utils._pytree import is_namedtuple_class from .. 
import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -59,10 +61,13 @@ from ..source import ( AttrSource, ClosureSource, + CollectionsSource, ConstantSource, DefaultsSource, GetItemSource, SkipGuardSource, + TorchSource, + TypeSource, ) from ..utils import ( check_constant_args, @@ -109,12 +114,21 @@ _F = TypeVar("_F", bound=Callable[..., Any]) CO_VARARGS = 0x04 CO_VARKEYWORDS = 0x08 +_SUPPORTED_TREE_MAP_KWARGS = frozenset({"namespace", "none_is_leaf", "is_leaf"}) +_TREE_MAP_ONLY_SUPPORTED_KWARGS = frozenset({"is_leaf"}) # Module-level cache keyed by the function object _spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() +@functools.lru_cache +def get_pytree_SUPPORTED_NODES_source(): + return AttrSource( + AttrSource(AttrSource(TorchSource(), "utils"), "_pytree"), "SUPPORTED_NODES" + ) + + class FunctionSpec: def __init__(self, func: FunctionType): code = func.__code__ @@ -379,7 +393,7 @@ def call_function( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: result = False try: @@ -408,6 +422,15 @@ class UserFunctionVariable(BaseUserFunctionVariable): *BaseUserFunctionVariable._nonvar_fields, } + _TREE_MAP_MODULES = frozenset( + { + "optree", + "optree.ops", + "torch.utils._pytree", + "torch.utils._cxx_pytree", + } + ) + @classmethod def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) @@ -543,7 +566,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: result = hasattr(self.fn, name) return variables.ConstantVariable.create(result) @@ -644,8 +667,190 @@ def call_function( ]: with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx): return super().call_function(tx, args, kwargs) + + tree_map_result = self._maybe_call_tree_map_fastpath(tx, args, kwargs) + if tree_map_result is not None: + return tree_map_result + return super().call_function(tx, args, kwargs) + def _maybe_call_tree_map_fastpath( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[VariableTracker]: + rewrite = self._rewrite_tree_map_only_call(tx, args, kwargs) + if rewrite is not None: + tree_map_fn, tree_map_args, tree_map_kwargs = rewrite + else: + tree_map_fn = self + tree_map_args = args + tree_map_kwargs = kwargs + + if not ( + isinstance(tree_map_fn, UserFunctionVariable) + and tree_map_fn._is_tree_map_function() + and not ({*tree_map_kwargs} - _SUPPORTED_TREE_MAP_KWARGS) + and len(tree_map_args) >= 2 + ): + return None + + map_fn = tree_map_args[0] + first_tree = tree_map_args[1] + rest = tree_map_args[2:] + return first_tree.call_tree_map( + tx, + tree_map_fn, + map_fn, + rest, + tree_map_kwargs, + ) + + def _is_tree_map_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _is_tree_map_only_function(self) -> bool: + return ( + getattr(self.fn, "__name__", None) == "tree_map_only" + and getattr(self.fn, "__module__", None) in self._TREE_MAP_MODULES + ) + + def _rewrite_tree_map_only_call( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> Optional[ + 
tuple[ + "UserFunctionVariable", + Sequence[VariableTracker], + dict[str, VariableTracker], + ] + ]: + if not self._is_tree_map_only_function(): + return None + + if len(args) != 3: + return None + if {*kwargs} - _TREE_MAP_ONLY_SUPPORTED_KWARGS: + return None + + type_selector, map_fn, tree_arg = args + allowed_types = self._extract_tree_map_only_types(type_selector) + if allowed_types is None: + return None + + tree_map_callable = self._lookup_tree_map_function() + if tree_map_callable is None: + return None + + wrapped_map_fn = TreeMapOnlyFunctionVariable( + allowed_types, + map_fn, + source=getattr(map_fn, "source", None), + ) + tree_map_variable = variables.UserFunctionVariable(tree_map_callable) + return tree_map_variable, [wrapped_map_fn, tree_arg], dict(kwargs) + + def _lookup_tree_map_function(self) -> Optional[types.FunctionType]: + module_name = getattr(self.fn, "__module__", None) + if not module_name: + return None + module = sys.modules.get(module_name) + if module is None: + return None + tree_map = getattr(module, "tree_map", None) + if isinstance(tree_map, types.FunctionType): + return tree_map + return None + + def _extract_tree_map_only_types( + self, selector: VariableTracker + ) -> Optional[tuple[type, ...]]: + if not selector.is_python_constant(): + return None + try: + raw_value = selector.as_python_constant() + except NotImplementedError: + return None + + flattened = self._flatten_type_spec(raw_value) + if not flattened: + return None + if not all(isinstance(typ, type) for typ in flattened): + return None + return tuple(dict.fromkeys(flattened)) + + def _flatten_type_spec(self, value: Any) -> Optional[list[type]]: + if isinstance(value, type): + return [value] + if isinstance(value, tuple): + collected: list[type] = [] + for entry in value: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + union_type = getattr(types, "UnionType", None) + if union_type is not None and isinstance(value, union_type): + collected = [] + for entry in value.__args__: + flat = self._flatten_type_spec(entry) + if flat is None: + return None + collected.extend(flat) + return collected + return None + + +class TreeMapOnlyFunctionVariable(BaseUserFunctionVariable): + _nonvar_fields = { + "allowed_types", + *BaseUserFunctionVariable._nonvar_fields, + } + + def __init__( + self, + allowed_types: tuple[type, ...], + map_fn: VariableTracker, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.allowed_types = allowed_types + self.map_fn = map_fn + + def python_type(self) -> type: + return FunctionType + + def _matches_allowed_type(self, node: VariableTracker) -> bool: + try: + node_type = node.python_type() + except NotImplementedError: + return False + return any(issubclass(node_type, allowed) for allowed in self.allowed_types) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not args: + return self.map_fn.call_function(tx, args, kwargs) + leaf = args[0] + if self._matches_allowed_type(leaf): + return self.map_fn.call_function(tx, args, kwargs) + if len(args) != 1 or kwargs: + # Defer to the original map function so we fall back to normal + # tracing instead of triggering a graph break. 
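To make the rewrite concrete, this is the eager semantics of tree_map_only that the wrapper reproduces: the map function fires only on leaves matching the type selector, all other leaves pass through unchanged. A minimal sketch:

    import torch
    import torch.utils._pytree as pytree

    tree = {"t": torch.ones(2), "n": 3, "s": "keep"}
    # Only the int leaf is transformed; the tensor and the string are untouched.
    out = pytree.tree_map_only(int, lambda x: x * 10, tree)
    print(out["n"], out["s"])  # 30 keep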
+ return self.map_fn.call_function(tx, args, kwargs) + return leaf + class BuiltinMethodVariable(BaseUserFunctionVariable): def __init__( @@ -780,7 +985,7 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) @@ -1432,7 +1637,7 @@ def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if name == "__code__": return variables.ConstantVariable.create(hasattr(self, "code")) if name == "__defaults__": @@ -1749,7 +1954,7 @@ def call_function( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create(hasattr(self.value, name)) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -2109,7 +2314,7 @@ def call_function( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: # functools.partial uses slots, so attributes are constant return variables.ConstantVariable.create( hasattr(functools.partial(identity), name) @@ -2717,3 +2922,97 @@ def call_function( tensor=tensor, # type: ignore[arg-type] block_shape=block_shape, # type: ignore[arg-type] ) + + +class PyTreeGetNodeTypeFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree._get_node_type` function is very hot function. We want to special case it to reduce Dynamo tracing time. + + def _get_node_type(tree: Any) -> Any: + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): + return namedtuple + return node_type + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if len(args) != 1: + raise_type_error_exc( + tx, + f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + ) + type_source = None + if args[0].source: + install_guard(args[0].source.make_guard(GuardBuilder.TYPE_MATCH)) + type_source = TypeSource(args[0].source) + python_type = args[0].python_type() + if is_namedtuple_class(python_type): + type_source = AttrSource(CollectionsSource(), "namedtuple") + return VariableTracker.build(tx, namedtuple, type_source) + return VariableTracker.build(tx, python_type, source=type_source) + + +class PyTreeTreeIsLeafFunctionVariable(UserFunctionVariable): + """ + `torch.utils._pytree.tree_is_leaf` function is a hot function. We want to special case it to reduce Dynamo tracing time. + + def tree_is_leaf( + tree: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + ) -> bool: + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + When is_leaf is None (the common case), we can optimize by not tracing into the function. + When is_leaf is not None, we fall back to regular tracing since it requires executing user code. 
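For reference, the eager behaviour both fast paths reproduce (tree_is_leaf with the default is_leaf=None, and the namedtuple special case in _get_node_type):

    from collections import namedtuple

    import torch
    import torch.utils._pytree as pytree

    Point = namedtuple("Point", ["x", "y"])

    print(pytree.tree_is_leaf(torch.ones(2)))  # True: tensors are leaves
    print(pytree.tree_is_leaf([1, 2]))         # False: list is a registered node
    print(pytree.tree_is_leaf(Point(1, 2)))    # False: namedtuples are implicit nodes
    print(pytree._get_node_type(Point(1, 2)) is namedtuple)  # True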
+ """ + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # tree_is_leaf(tree, is_leaf=None) + if len(args) < 1 or len(args) > 2: + raise_type_error_exc( + tx, + f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", + ) + + # Check if is_leaf parameter is provided + is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) + if len(args) == 2: + is_leaf = args[1] + + if not ( + isinstance(is_leaf, variables.ConstantVariable) and is_leaf.value is None + ): + return super().call_function(tx, args, kwargs) + + # Optimize the case where is_leaf is None + # return _get_node_type(tree) not in SUPPORTED_NODES + tree = args[0] + node_type_var = PyTreeGetNodeTypeFunctionVariable( + torch.utils._pytree._get_node_type + ).call_function(tx, [tree], {}) + + # If the SUPPORTED_NODES was seen earlier and mutated, there would be a + # source and that will give us the mutated SUPPORTED_NODES. + supported_nodes_var = VariableTracker.build( + tx, + torch.utils._pytree.SUPPORTED_NODES, + source=get_pytree_SUPPORTED_NODES_source(), + ) + out = supported_nodes_var.call_method(tx, "__contains__", [node_type_var], {}) + return ConstantVariable.create(not out.value) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 4ae8868f15e84..afb6522ac0e5c 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -3159,7 +3159,7 @@ def _call_function( # to (body_node, lifted_args_tuple, {}) body_node = p_args[0] lifted_args = p_args[1:] - p_args = (body_node, lifted_args, {}) + p_args = (body_node, tuple(lifted_args), {}) # add hints into p_kwargs p_kwargs = {} diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 162ec02a9a9b7..c111dca9f2d68 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -262,7 +262,7 @@ def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> "ConstantVariable": if name == "__iter__" or name == "__next__": return variables.ConstantVariable.create(True) return super().call_obj_hasattr(tx, name) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 2ac355bd53417..05129fcf8fb45 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -140,6 +140,50 @@ def getitem_const( def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return list(self.items) + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not isinstance(self, (ListVariable, TupleVariable)): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + + other_lists: list[BaseListVariable] = [] + for candidate in rest: + if ( + not isinstance(candidate, BaseListVariable) + or len(candidate.items) != len(self.items) + or self.python_type() != candidate.python_type() + ): + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + other_lists.append(candidate) + + new_items: list[VariableTracker] = [] + for idx, item in enumerate(self.items): + sibling_leaves = [candidate.items[idx] for candidate in 
other_lists] + new_items.append( + item.call_tree_map( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + ) + ) + + return self.clone( + items=new_items, + source=None, + mutation_type=ValueMutationNew(), + ) + def call_method( self, tx: "InstructionTranslator", @@ -480,7 +524,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is range: return variables.ConstantVariable.create(name in range.__dict__) return super().call_obj_hasattr(tx, name) @@ -932,7 +976,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is not list: return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) @@ -1089,7 +1133,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is collections.deque: return variables.ConstantVariable.create(name in collections.deque.__dict__) return super().call_obj_hasattr(tx, name) @@ -1130,7 +1174,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is not tuple: return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) @@ -1292,7 +1336,7 @@ def get_item_dyn( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create(hasattr(torch.Size, name)) @@ -1540,7 +1584,7 @@ def check_and_create_method() -> Optional[VariableTracker]: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create( name in self.dynamic_attributes or hasattr(self.tuple_cls, name) ) @@ -1653,7 +1697,7 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: return variables.ConstantVariable.create(hasattr(iter([]), name)) def python_type(self) -> type: @@ -1726,7 +1770,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + ) -> ConstantVariable: if self.python_type() is range_iterator: ri = iter(range(0)) return ConstantVariable(hasattr(ri, name)) diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index e754699d862ad..4b5198ffe8533 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -72,6 +72,8 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs): """ @@ -230,7 +232,7 @@ def unpack_var_sequence(self, tx): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> "ConstantVariable": mod = tx.output.get_submodule(self.module_key) result = hasattr(mod, name) install_guard( diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 
326178ef00874..0787ef7c49b57 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -83,6 +83,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .functions import UserFunctionVariable + log = logging.getLogger(__name__) @@ -612,6 +614,16 @@ def unpack_var_sequence(self, tx: "InstructionTranslator", idxes=None): for i in idxes ] + def call_tree_map( + self, + tx, + tree_map_fn: "UserFunctionVariable", + map_fn, + rest, + tree_map_kwargs, + ) -> "VariableTracker": + return map_fn.call_function(tx, [self, *rest], {}) + def valid_size(self): return self._size is not None @@ -1266,6 +1278,19 @@ def method_to_local(self, *args, **kwargs): tx = InstructionTranslator.current_tx() # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function + + grad_placements_vt = kwargs.get( + "grad_placements", ConstantVariable.create(None) + ) + if isinstance(grad_placements_vt, variables.UserDefinedObjectVariable): + # grad_placement is a sequence-like structure, iterate over the value + grad_placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [grad_placements_vt], {} + ) + + if kwargs.get("grad_placements") is not None: + kwargs["grad_placements"] = grad_placements_vt + args_as_value = [x.as_python_constant() for x in args] kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ec378a5512a01..fb676295535df 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -114,6 +114,8 @@ from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator + from .constant import ConstantVariable + def is_standard_setattr(val): return val in (object.__setattr__, BaseException.__setattr__) @@ -913,7 +915,7 @@ def is_standard_new(self): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> "ConstantVariable": if self.source: source = AttrSource(self.source, name) install_guard(source.make_guard(GuardBuilder.HASATTR)) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index d653db0c23a74..7a0a87dab19dc 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -79,7 +79,7 @@ def aot_compile_warning(): def aot_compile( f: Callable, - args: tuple[Any], + args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None, *, dynamic_shapes: Optional[dict[str, Any]] = None, diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index f4a08f8739993..155f52595740c 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<<0e870e558fb4362f69b825842ab606cf0becd10a008003ac676156becf20b65b>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -167,6 +167,8 @@ union Argument { 240: list as_sym_floats; 250: OptionalTensorArgument as_optional_tensor; 260: ComplexValue as_complex; + 280: list> as_int_lists; + 290: map as_string_to_argument; } struct NamedArgument { diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index a9cec8b185c58..0d95ca32e6455 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: 
Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 14) +SCHEMA_VERSION = (8, 15) TREESPEC_VERSION = 1 @@ -212,6 +212,8 @@ class Argument(_Union): as_sym_floats: Annotated[list[SymFloatArgument], 240] as_optional_tensor: Annotated[OptionalTensorArgument, 250] as_complex: Annotated[ComplexValue, 260] + as_int_lists: Annotated[list[list[int]], 280] + as_string_to_argument: Annotated[dict[str, "Argument"], 290] class ArgumentKind(IntEnum): diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 951351e7786aa..6f13741416cb3 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>> +# checksum<> AOTInductorModelPickleData: kind: struct fields: @@ -75,6 +75,10 @@ Argument: type: OptionalTensorArgument as_complex: type: ComplexValue + as_int_lists: + type: List[List[int]] + as_string_to_argument: + type: Dict[str, Argument] ArgumentKind: kind: enum fields: @@ -551,5 +555,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 14 +- 15 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 84978f0066712..c64aaff9ae1f2 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -39,6 +39,7 @@ from torch.utils._triton import has_triton from ..utils import remove_proxy_from_state_dict +from . import schema from .schema import ( # type: ignore[attr-defined] Argument, ArgumentKind, @@ -1195,6 +1196,13 @@ def serialize_input(self, arg, arg_type: Optional[Any] = None) -> Argument: ) elif arg is None: return Argument.create(as_none=True) + elif isinstance(arg, dict): + serialized_dict = {} + for key, value in arg.items(): + if not isinstance(key, str): + raise SerializeError(f"Dict keys must be strings, got {type(key)}") + serialized_dict[key] = self.serialize_input(value) + return Argument.create(as_string_to_argument=serialized_dict) elif isinstance(arg, (list, tuple)): if len(arg) == 0: if arg_type is not None: @@ -1326,6 +1334,11 @@ def serialize_optional_tensor_args(a): return Argument.create( as_optional_tensors=list(map(serialize_optional_tensor_args, arg)) ) + elif all( + isinstance(a, tuple) and all(type(x) is int for x in a) for a in arg + ): + # list of int tuples + return Argument.create(as_int_lists=[list(t) for t in arg]) else: raise SerializeError( f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" @@ -2735,6 +2748,12 @@ def deserialize_input(self, inp: Argument) -> Any: return self.deserialize_sym_argument(inp.as_sym_float) elif typ_ == "as_sym_bool": return self.deserialize_sym_argument(inp.as_sym_bool) + elif isinstance(value, dict): + if typ_ == "as_string_to_argument": + # Deserialize dict[str, Argument] recursively + return {k: self.deserialize_input(v) for k, v in value.items()} + else: + raise SerializeError(f"Unknown dict type: {typ_}") elif isinstance(value, list): if len(value) == 0: return [] @@ -2744,6 +2763,9 @@ def deserialize_input(self, inp: Argument) -> Any: elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"): # convert from serialized.python.types.List to python list return list(value) + elif typ_ == "as_int_lists": + # Convert list of lists back to list of tuples for Triton grids + return [tuple(dims) for dims in value] elif typ_ in ("as_sym_ints", "as_sym_bools", "as_sym_floats"): return [self.deserialize_sym_argument(arg) for arg in value] 
elif typ_ == "as_optional_tensors": @@ -3239,7 +3261,18 @@ def serialize( return artifact +def _resolve_schema_cls(cls): + if isinstance(cls, str): + resolved = getattr(schema, cls, None) + if resolved is not None: + return resolved + if isinstance(cls, typing.ForwardRef): + return _resolve_schema_cls(cls.__forward_arg__) + return cls + + def _dict_to_dataclass(cls, data): + cls = _resolve_schema_cls(cls) assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." if typing.get_origin(cls) is Annotated: return _dict_to_dataclass(cls.__origin__, data) @@ -3255,12 +3288,13 @@ def _dict_to_dataclass(cls, data): _type = next(iter(data.keys())) _value = next(iter(data.values())) assert isinstance(_type, str) - field_type = cls.__annotations__[_type] + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) + field_type = type_hints[_type] # pyrefly: ignore [missing-attribute] return cls.create(**{_type: _dict_to_dataclass(field_type, _value)}) elif dataclasses.is_dataclass(cls): fields = {} - type_hints = typing.get_type_hints(cls) + type_hints = typing.get_type_hints(cls, globalns=vars(schema)) # For forward compatibility consideration, we ignore all the keys # that are not showing up in the dataclass definition. for f in dataclasses.fields(cls): @@ -3365,6 +3399,10 @@ def _get_argument(a: Argument): return a.as_custom_obj elif a.type == "as_operator": return None + elif a.type == "as_int_lists": + return None + elif a.type == "as_string_to_argument": + return None else: raise AssertionError(f"Unknown input type to the ExportedProgram: {a}") diff --git a/torch/_functorch/_activation_checkpointing/knapsack.py b/torch/_functorch/_activation_checkpointing/knapsack.py index 0a3eaa5a9344c..b2f0a124c64c1 100644 --- a/torch/_functorch/_activation_checkpointing/knapsack.py +++ b/torch/_functorch/_activation_checkpointing/knapsack.py @@ -119,3 +119,149 @@ def dp_knapsack( max_runtime = dp[n][quantized_max_memory].item() return max_runtime, saved_items, recomputable_items + + +def dp_knapsack_sliding_hirschberg( + memory: list[float], runtime: list[float], max_memory: float +) -> tuple[float, list[int], list[int]]: + # Scaling factor to convert floating point weights to integers + S = 10000 + + # q_ prefix stands for quantized + q_memory = [int(round(m * S)) for m in memory] + runtimes = [float(v) for v in runtime] + + q_max_memory = int(round(max_memory * S)) + + q_memory_length = len(q_memory) + if q_memory_length == 0: + return 0.0, [], [] + + item_indices = list(range(q_memory_length)) + dp_profile_size = q_max_memory + 1 + + # Current DP profile (row) + dp_profile = torch.zeros(dp_profile_size, dtype=torch.float32, device="cpu") + # Store a candidate for next dp_profile - current dp row + item + candidate_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + left_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + right_profile = torch.empty(dp_profile_size, dtype=torch.float32, device="cpu") + + saved_items: list[int] = [] + recomputable_items: list[int] = [] + + # Explicit stack to optimize memory and avoid recursion + # Stack stores segments as (start index, end index, capacity for segment) + stack: list[tuple[int, int, int]] = [(0, q_memory_length, q_max_memory)] + + # LIFO + while stack: + start, end, capacity = stack.pop() + length = end - start + if length == 0: + continue + + # Leaf + if length == 1: + index = item_indices[start] + memory_item = q_memory[index] + runtime_item = runtimes[index] + if memory_item <= capacity and 
runtime_item > 0.0: + saved_items.append(index) + else: + recomputable_items.append(index) + continue + + # Split the segment into two halves + middle = start + (length // 2) + left_start, left_end = middle, end + right_start, right_end = start, middle + + # Assign items to both halves + left_items = item_indices[left_start:left_end] + right_items = item_indices[right_start:right_end] + + # Working only on items allowed by segment's capacity + capacity = capacity + 1 + dp_view = dp_profile[:capacity] + candidate_view = candidate_profile[:capacity] + left_dp_local = left_profile[:capacity] + right_dp_local = right_profile[:capacity] + + # Left part + dp_view.zero_() + for index in left_items: + memory_item = q_memory[index] + runtime_item = runtimes[index] + + if memory_item == 0: + # Weight is 0, so add it to all capacities; a "free lunch", essentially + dp_view.add_(runtime_item) + continue + + # If item is too heavy, we skip it + if memory_item >= capacity: + continue + + # Add the current item so we can then pick the highest value + dp_view_candidate = candidate_view[: capacity - memory_item] + torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate) + # Take the highest - either previous (without current) or with current + torch.maximum( + dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:] + ) + + # Store the left profile + left_dp_local.copy_(dp_view) + + # Right part + dp_view.zero_() + for index in right_items: + memory_item = q_memory[index] + runtime_item = runtimes[index] + + if memory_item == 0: + dp_view.add_(runtime_item) + continue + + if memory_item >= capacity: + continue + + dp_view_candidate = candidate_view[: capacity - memory_item] + torch.add(dp_view[:-memory_item], runtime_item, out=dp_view_candidate) + torch.maximum( + dp_view[memory_item:], dp_view_candidate, out=dp_view[memory_item:] + ) + + # Store the reversed right profile + right_dp_local.copy_(dp_view.flip(-1)) + + # In-place compute item-wise sum of left and right to pick the split point where the sum is highest + left_dp_local.add_(right_dp_local) + + # Pick the index of highest value of a pair, which we then use as a split point + best_split = int(torch.argmax(left_dp_local).item()) + + left_capacity = best_split + right_capacity = capacity - best_split + + # Clamp (might be removed if we're 100% sure that there is no edge case that will mess up the indices math) + if left_capacity < 0: + left_capacity = 0 + if right_capacity < 0: + right_capacity = 0 + if left_capacity > q_max_memory: + left_capacity = q_max_memory + if right_capacity > q_max_memory: + right_capacity = q_max_memory + + # Push right then left, so left is processed next + stack.append((right_start, right_end, right_capacity)) + stack.append((left_start, left_end, left_capacity)) + + saved_items = sorted(saved_items) + recomputable_items = sorted(recomputable_items) + + max_runtime = sum(runtime[i] for i in saved_items) + recomputable_items.reverse() + return max_runtime, saved_items, recomputable_items diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index fcbf861e537db..5af4fc9ee1195 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -10,6 +10,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Optional import torch from torch import Tensor @@ -449,7 +450,7 @@ def was_tensor_metadata_updated(arg, new_arg): # Returns the number of 
detected copy_ -def assert_functional_graph(fx_g: torch.fx.Graph) -> int: +def _is_functional_graph(fx_g: torch.fx.Graph) -> tuple[Optional[str], int]: allowed_mutation_ops = [ torch.ops.aten.copy_.default, torch.ops.aten.set_.source_Tensor, @@ -462,6 +463,7 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # NB: It would also be nice to verify that the mutations all happen at the # end, but we also do some administrative views after mutations so this # isn't actually true. (TODO: Could this cause problems for Inductor?) + error = None for n in fx_g.nodes: if n.op == "placeholder": placeholders.add(n) @@ -471,14 +473,18 @@ def assert_functional_graph(fx_g: torch.fx.Graph) -> int: # this is mostly a hack to avoid failing XLA tests. # See https://github.com/pytorch/pytorch/pull/122434#issuecomment-2101012113 if "set_buffer_donor_" not in str(n.args[0]): - assert n.args[0] in placeholders, ( - f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" - ) + if n.args[0] not in placeholders: + error = f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}" mutation_count += 1 else: - assert not n.target._schema.is_mutable, ( - f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" - ) + if n.target._schema.is_mutable: + error = f"aot_autograd expected to have an entirely functional graph, but found {n.format_node()}" + return error, mutation_count + + +def assert_functional_graph(fx_g: torch.fx.Graph) -> int: + error, mutation_count = _is_functional_graph(fx_g) + assert error is None, error return mutation_count diff --git a/torch/_functorch/_aot_autograd/graph_capture.py b/torch/_functorch/_aot_autograd/graph_capture.py index b6ea08a802240..f17a516183975 100644 --- a/torch/_functorch/_aot_autograd/graph_capture.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ -33,7 +33,7 @@ handle_effect_tokens_fn, ) from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta -from .streams import assign_backward_streams +from .streams import assign_backward_streams, insert_backward_syncs from .utils import ( call_and_expect_output_descs, copy_fwd_metadata_to_bw_nodes, @@ -477,6 +477,8 @@ def aot_dispatch_autograd_graph( # After copying metadata, assign streams to gradient accumulation nodes assign_backward_streams(fx_g) + insert_backward_syncs(fx_g) + fx_g.graph.eliminate_dead_code() if not aot_config.disable_functionalization: # There should be *NO* mutating ops in the graph at this point. 
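Stepping back to the functional_utils change: _is_functional_graph lets callers query functional-ness without raising, while assert_functional_graph keeps its old asserting contract. A minimal sketch, assuming a small make_fx-traced function (the example function is hypothetical, not from the patch):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._functorch._aot_autograd.functional_utils import (
        _is_functional_graph,
        assert_functional_graph,
    )

    gm = make_fx(lambda x: torch.sin(x) + 1)(torch.randn(4))

    error, num_copies = _is_functional_graph(gm.graph)  # (None, 0) for a functional graph
    assert error is None and num_copies == 0

    assert_functional_graph(gm.graph)  # still hard-fails on non-functional graphs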
diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index bc4dc87ddeced..2ef84cb488604 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -27,6 +27,7 @@ from torch._prims_common import CUDARngStateHelper from torch.fx.experimental.proxy_tensor import ( _proxy_tensor_disable_update_tensor_tracker, + get_proxy_mode, maybe_disable_thunkify, maybe_enable_thunkify, ) @@ -295,6 +296,10 @@ def inner_fn( (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( fn, primals ) + mode = get_proxy_mode() + assert mode is not None, "Expected non-None proxy mode" + for node in mode.tracer.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" # TODO: I think this hook can also be eliminated now if joint_fn_handle and joint_fn_handle.post_forward: diff --git a/torch/_functorch/_aot_autograd/streams.py b/torch/_functorch/_aot_autograd/streams.py index f78a2c6cad1de..1fc8a965740fd 100644 --- a/torch/_functorch/_aot_autograd/streams.py +++ b/torch/_functorch/_aot_autograd/streams.py @@ -3,15 +3,25 @@ import torch.fx import torch.fx.traceback from torch._dynamo.graph_utils import _get_flat_args +from torch._dynamo.variables.streams import get_current_stream, new_event Node: TypeAlias = torch.fx.Node +Graph: TypeAlias = torch.fx.Graph def is_gradient_acc(node: Node) -> bool: return node.meta.get("is_gradient_acc", False) +def is_bwd_node(node: Node) -> bool: + return node.meta.get("partitioner_tag") == "is_backward" + + +def get_device(node: Node) -> torch.device: + return node.meta["val"].device + + def get_stream(node: Node) -> Optional[int]: maybe_annotation = node.meta.get("custom", None) if maybe_annotation is not None: @@ -20,6 +30,13 @@ def get_stream(node: Node) -> Optional[int]: return None +def get_stream_or_current_stream(node: Node) -> int: + ind = get_stream(node) + if ind is None: + ind = get_current_stream(get_device(node)) + return ind + + def set_stream(node: Node, ind: int) -> None: if "custom" in node.meta: node.meta["custom"].update({"stream": ind}) @@ -27,6 +44,45 @@ def set_stream(node: Node, ind: int) -> None: node.meta["custom"] = {"stream": ind} +def insert_record_event_after_node(graph: Graph, node: Node, event_ind: int) -> None: + with graph.inserting_after(node): + node = graph.call_function( + torch.ops.streams.record_event.default, + ( + event_ind, + get_stream_or_current_stream(node), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + +def insert_wait_event_before_node(graph: Graph, node: Node, event_ind: int) -> None: + with graph.inserting_before(node): + node = graph.call_function( + torch.ops.streams.wait_event.default, + ( + event_ind, + get_stream_or_current_stream(node), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + +def insert_sync( + graph: Graph, + consumer: Node, + producer: Node, + node_to_wait_event_ind: dict[Node, int], +) -> None: + if producer not in node_to_wait_event_ind: + node_to_wait_event_ind[producer] = new_event() + + insert_record_event_after_node( + graph, producer, node_to_wait_event_ind[producer] + ) + insert_wait_event_before_node(graph, consumer, node_to_wait_event_ind[producer]) + + def assign_backward_streams(gm: torch.fx.GraphModule) -> None: """Assigns backward streams to gradient accumulation nodes""" @@ -51,3 +107,18 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None: if ind is not None: set_stream(node, ind) break + 
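The record/wait helpers above follow the usual CUDA event handshake; a conceptual eager-mode analogue is sketched below (assumes a CUDA device; the pass itself emits torch.ops.streams.record_event/wait_event graph nodes rather than eager calls):

    import torch

    producer_stream = torch.cuda.Stream()
    consumer_stream = torch.cuda.Stream()
    done = torch.cuda.Event()

    with torch.cuda.stream(producer_stream):
        y = torch.randn(1024, device="cuda").sin()  # producer work on its stream
        done.record(producer_stream)

    consumer_stream.wait_event(done)  # consumer waits before reading y
    with torch.cuda.stream(consumer_stream):
        z = y * 2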
+ +def insert_backward_syncs(gm: torch.fx.GraphModule) -> None: + """Inserts stream syncs for backward nodes if consumer and producer are on different streams""" + node_to_wait_event_ind = {} + for node in gm.graph.nodes: + if is_bwd_node(node): + flat_args = _get_flat_args(node, {}) + cur_node_stream = get_stream(node) + + for arg in flat_args: + if is_bwd_node(arg): + arg_stream = get_stream(arg) + if arg_stream != cur_node_stream and get_device(arg).type != "cpu": + insert_sync(gm.graph, node, arg, node_to_wait_event_ind) diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 790cf71a83a23..42d6f308f831a 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -162,8 +162,9 @@ def remote_autograd_cache_default() -> Optional[bool]: activation_memory_budget_runtime_estimator = "flops" # This controls the solver used for the 0-1 knapsack. By default we use a -# quantized DP solution ("dp"). The other approaches are a "greedy" and a "ilp" -# (which has a scipy dependency). +# quantized DP solution ("dp"). The other approaches are a "greedy", a "ilp" +# (which has a scipy dependency) and "dp_knapsack_sliding_hirschberg", which +# used memory-efficient quantized DP solution activation_memory_budget_solver = "dp" # This dumps out a SVG visualization of the expected runtime vs. activation diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index e7f8075b0281e..3b79a50ff9e21 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -10,6 +10,7 @@ import os import os.path import re +import warnings from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass, replace @@ -46,11 +47,13 @@ from ._activation_checkpointing.graph_info_provider import GraphInfoProvider from ._activation_checkpointing.knapsack import ( dp_knapsack, + dp_knapsack_sliding_hirschberg, greedy_knapsack, ilp_knapsack, ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput +from ._aot_autograd.functional_utils import _is_functional_graph from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems @@ -297,6 +300,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" +def _has_tag_is_forward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "is_forward" + + def _has_tag_must_be_in_forward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_forward" @@ -1021,105 +1028,136 @@ def default_partition( Returns: Returns the generated forward and backward Fx graph modules. 
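For the new solver, both the config switch and the direct interface are shown below as a minimal sketch; the memory/runtime numbers are made up, and the return value is (max_runtime, saved_indices, recomputable_indices) like the other solvers:

    import torch._functorch.config as functorch_config
    from torch._functorch._activation_checkpointing.knapsack import (
        dp_knapsack_sliding_hirschberg,
    )

    # Route the memory-budget partitioner to the new solver.
    functorch_config.activation_memory_budget_solver = "dp_knapsack_sliding_hirschberg"

    # Or call it directly: choose which activations to save under a memory budget.
    max_runtime, saved, recomputable = dp_knapsack_sliding_hirschberg(
        memory=[0.5, 1.0, 0.25], runtime=[2.0, 1.0, 3.0], max_memory=1.0
    )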
""" - if has_recomputable_ops(joint_module): - return min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) + # Respect the original placement of ops rather than rely on dataflow. + forward_nodes = [] + last_node = None + for node in joint_module.graph.nodes: + if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node): + last_node = node + assert last_node is not None + for node in joint_module.graph.nodes: + if not _is_tangent(node): + forward_nodes.append(node) + if node is last_node: + break forward_node_names = OrderedSet( - node.name for node in forward_only_graph.nodes if node.op != "output" + node.name for node in forward_nodes if node.op != "output" + ) + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + if _is_functional_graph(joint_module.graph)[0] is not None: + # Fall-back to previous behavior to avoid bc-breaking, although can + # eventually flip the switch to make this a hard error. + warnings.warn( + "Trying to unsafely apply AC to a non-functional graph with the " + "default partitioner. Falling back to min-cut partitioner." 
+ ) + return min_cut_rematerialization_partition( + joint_module, + _joint_inputs, + num_fwd_outputs=num_fwd_outputs, + static_lifetime_input_indices=static_lifetime_input_indices, + ) + + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True) + + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(joint_module) + + force_save_bw_mutation_src(joint_module) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs ) - order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + saved_values = [] saved_sym_nodes = [] - def is_mutated_later_in_fw(node): - if _has_tag_is_backward(node): - return False - tensor_arg_aliases = [ - x - for x in node.args - if isinstance(x, fx.Node) - and "val" in x.meta - and isinstance(x.meta["val"], torch.Tensor) - ] - while len(tensor_arg_aliases) > 0: - a = tensor_arg_aliases.pop() - for u in a.users: - if not isinstance(u.target, torch._ops.OpOverload): - continue - # If we witness a mutation on our node later, and that mutation is not "must be in backward", - # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values) - if ( - # one of the args was mutated - u.target._schema.is_mutable - # and the mutation happens "later" - and order[u] > order[node] - # and the mutation happened during the forward - and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u)) - ): - for idx, alias_info in enumerate(u.target._schema.arguments): - if alias_info.is_write and u.args[idx] is a: - return True - elif u.target.is_view: - tensor_arg_aliases.append(u) - return False + distributed_enabled = torch.distributed.is_available() + + def is_tensor(node): + return "tensor_meta" in node.meta or isinstance( + node.meta.get("val"), torch._subclasses.FakeTensor + ) + + def is_multi_output(node): + return ( + all(user.target == operator.getitem for user in node.users) + and len(node.users) > 0 + ) + + def is_impure(node): + # wait tensor is an "impure" op according to DCE's definition of impure + # (see is_impure in torch/fx/node.py), but it survives past + # functionalization and can be safely dup'd and reordered under the + # assumption SPMD. 
+ return ( + node.is_impure(impure_random=False) + and node.op + not in ( + "placeholder", + "output", + ) + and ( + not distributed_enabled + or node.target is not torch.ops._c10d_functional.wait_tensor.default + ) + ) for node in joint_module.graph.nodes: if node.name not in forward_node_names: - # if a node isn't "required" to be in the forward, but any of its arguments - # are later mutated in the forward, then it must have been run in the forward - # (if not, and the node's arg was saved for backward, we would have mutated a saved value) - # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated - if is_mutated_later_in_fw(node): - saved_values.append(node) + continue + if node.target is torch.ops.aten._assert_scalar.default: continue if is_sym_node(node): # Symints must be kept separate from tensors so that PythonFunction only calls # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes.append(node) - elif ( - "tensor_meta" not in node.meta - and node.op == "call_function" - and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) - ): - # Since we can't save tuple of tensor values, we need to flatten out what we're saving - users = node.users - assert all(user.target is operator.getitem for user in users) - saved_values.extend(users) - else: - backward_usages = [ - n for n in node.users if n.name not in forward_node_names - ] - if "tensor_meta" in node.meta and all( - is_sym_node(n) for n in backward_usages - ): - # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, - # and not the actual tensor data, - # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. - # - # Note that saving the tensor could also cause compilation problems: - # If the user mutated an input in the forward and uses its sizes/strides in the backward, - # then we would be obligated to clone the input before saving it to appease autograd. - # (This is how we originally found this bug). - saved_sym_nodes.extend(backward_usages) - else: - saved_values.append(node) + continue + if is_multi_output(node): + # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE. + continue + if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE: + saved_values.append(node) + continue + if is_impure(node): + assert not graph_has_recomputable_ops, ( + "Trying to apply AC on a graph with impure op", + node, + node.target, + ) + saved_values.append(node) + continue + assert is_tensor(node) or node.op != "call_function", ( + f"Expected {node} to be a tensor" + ) + backward_usages = [n for n in node.users if n.name not in forward_node_names] + if all(is_sym_node(n) for n in backward_usages): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). 
+ saved_sym_nodes.extend(backward_usages) + continue + if not must_recompute(node): + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - return _extract_fwd_bwd_modules( + if config._sync_decision_cross_ranks: + saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values) + + if static_lifetime_input_nodes is None: + static_lifetime_input_nodes = node_info.static_lifetime_input_nodes + fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, @@ -1127,6 +1165,37 @@ def is_mutated_later_in_fw(node): static_lifetime_input_nodes=static_lifetime_input_nodes, ) + # Run DCE while overriding the definition of is_impure_node + def is_not_collective(node): + if not distributed_enabled: + return True + if node.target is torch.ops._c10d_functional.wait_tensor.default: + return False + if node.target is torch.ops._c10d_functional.all_gather_into_tensor.default: + return False + return True + + fw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) + bw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective) + + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + # raise all getitem ops to as early as possible + # this is helpful for memory, especially in the case of aot_eager backend + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) + if len(node_info.required_bw_nodes) > 0: + bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) + + return fw_module, bw_module + INT_INF = int(1e6) @@ -1621,7 +1690,16 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: break -def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: +def is_getitem_of_multi_output(node): + if node.target != operator.getitem: + return False + parent = node.args[0] + return "tensor_meta" not in parent.meta and node.op == "call_function" + + +def cleanup_recompute_tags( + joint_module: fx.GraphModule, *, is_default_partition: bool +) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -1658,6 +1736,20 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` # in forward graph outputs. With this, we can break the above circular dependency. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + elif ( + "ac_graph_id" not in node.meta + and any(must_recompute(user) for user in node.users) + and not ( + # Avoid saving getitem nodes which are not labeled with "ac_graph_id" + is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta + ) + and is_default_partition + ): + # This node is not part of the AC region and a user is marked as recompute. + # This means it's an input to the AC region and we should save it. + # For ease of landing, gate this to default partitioner only, but we should think + # about flipping the switch in general as well. 
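Roughly, the scenario this targets is an activation-checkpointed region inside a compiled function that goes through default_partition. A hedged sketch (the backend choice is an assumption; aot_eager is used here because it exercises the default partitioner):

    import torch
    from torch.utils.checkpoint import checkpoint

    def block(x):
        return torch.nn.functional.gelu(x @ x.T)

    @torch.compile(backend="aot_eager")
    def f(x):
        # Inputs feeding the checkpointed region get tagged MUST_SAVE rather
        # than being recomputed, per the cleanup_recompute_tags change above.
        return checkpoint(block, x, use_reentrant=False).sum()

    f(torch.randn(8, 8, requires_grad=True)).backward()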
+ node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module @@ -2274,6 +2366,8 @@ def _optimize_runtime_with_given_memory( return ilp_knapsack(memory, runtimes, max_memory) elif SOLVER == "dp": return dp_knapsack(memory, runtimes, max_memory) + elif SOLVER == "dp_knapsack_sliding_hirschberg": + return dp_knapsack_sliding_hirschberg(memory, runtimes, max_memory) elif SOLVER == "dynamic_memory_budget_dp": log.warning( "dynamic_memory_budget_dp is an experimental solver. " @@ -2765,6 +2859,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward): return module +def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + required_bw_nodes.update(node.users) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( + _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" + ) + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + ) + static_lifetime_input_nodes = OrderedSet( + p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices + ) + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, + required_fw_nodes, + required_bw_nodes, + unclaimed_nodes, + fw_order, + static_lifetime_input_nodes, + ) + + def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, @@ -2813,68 +2960,16 @@ def min_cut_rematerialization_partition( graph_has_recomputable_ops = has_recomputable_ops(joint_module) graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) force_save_bw_mutation_src(joint_module) - def classify_nodes(joint_module, static_lifetime_input_indices): - name_to_node = get_name_to_node(joint_module.graph) - required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - elif _must_be_in_backward(node): - required_bw_nodes.add(node) - - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) 
- ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) - required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - ) - unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - ) - static_lifetime_input_nodes = OrderedSet( - p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices - ) - fw_cnt = 0 - fw_order = {} - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - fw_order[node] = fw_cnt - fw_cnt += 1 - return NodeInfo( - inputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - fw_order, - static_lifetime_input_nodes, - ) - if static_lifetime_input_indices is None: static_lifetime_input_indices = [] - node_info = classify_nodes(joint_module, static_lifetime_input_indices) + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs + ) # networkx blows up on graphs with no required backward nodes # Since there's nothing to partition anyway, and the default partitioner can "handle" diff --git a/torch/_higher_order_ops/_invoke_quant.py b/torch/_higher_order_ops/_invoke_quant.py index 1fc1e1114a036..b7a9fb94b93e2 100644 --- a/torch/_higher_order_ops/_invoke_quant.py +++ b/torch/_higher_order_ops/_invoke_quant.py @@ -26,9 +26,6 @@ class InvokeQuantUnpacked(BaseHOP): def __init__(self) -> None: super().__init__("invoke_quant") - def __call__(self, subgraph, *operands, scheme=None): - return super().__call__(subgraph, *operands, scheme=scheme) - invoke_quant = InvokeQuantUnpacked() diff --git a/torch/_higher_order_ops/hints_wrap.py b/torch/_higher_order_ops/hints_wrap.py index 3f21c518cbd74..583623393a0a1 100644 --- a/torch/_higher_order_ops/hints_wrap.py +++ b/torch/_higher_order_ops/hints_wrap.py @@ -34,7 +34,7 @@ def __call__(self, body_fn, args, kwargs, hints): backend compiler. 
""" if not isinstance(args, tuple): - raise RuntimeError(f"args must be a tuple, got {type(args)}") + args = tuple(args) if not all(isinstance(t, (torch.Tensor, int, float, bool)) for t in args): raise RuntimeError( diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 0e398897a7eab..628c889f6cbc7 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -382,14 +382,14 @@ def _get_specialization(args): # type: ignore[no-untyped-def] try: # Latest versions of Triton take specialize_extra as an arg to create_specialize_impl specialize_impl = triton.runtime.jit.create_specialize_impl( - specialize_extra=backend.get_arg_specialization + specialize_extra=backend.get_arg_specialization # pyrefly: ignore [missing-attribute] ) except TypeError: # Unknown arg `specialize_extra` # Older versions of Triton take specialize_extra as an arg to specialize_impl specialize_impl = functools.partial( # pyrefly: ignore # missing-argument triton.runtime.jit.create_specialize_impl(), - specialize_extra=backend.get_arg_specialization, + specialize_extra=backend.get_arg_specialization, # pyrefly: ignore [missing-attribute] ) # create_specialize_impl is removed in https://github.com/triton-lang/triton/pull/7771 # switch to native_specialize_impl instead @@ -413,7 +413,7 @@ def _native_specialize_impl( specialize_impl = functools.partial( specialize_impl_orig, - specialize_extra=backend.get_arg_specialization, + specialize_extra=backend.get_arg_specialization, # pyrefly: ignore [missing-attribute] ) from triton._utils import find_paths_if, get_iterable_path diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 810649e7b7b25..8e6fde9280c4a 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -272,7 +272,7 @@ def aoti_load_package( def aot_compile( gm: torch.fx.GraphModule, - args: tuple[Any], + args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None, *, options: Optional[dict[str, Any]] = None, diff --git a/torch/_inductor/autoheuristic/learnedheuristic_interface.py b/torch/_inductor/autoheuristic/learnedheuristic_interface.py index cb2568d8a6801..84a941b076c31 100644 --- a/torch/_inductor/autoheuristic/learnedheuristic_interface.py +++ b/torch/_inductor/autoheuristic/learnedheuristic_interface.py @@ -39,9 +39,6 @@ def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]: class LearnedHeuristicRegression(LearnedHeuristic): - def __init__(self) -> None: - super().__init__() - def get_feedback(self, context: AHContext, choice: Choice) -> float: return 1.0 @@ -64,9 +61,6 @@ def get_decision( class LearnedHeuristicDecision(LearnedHeuristic): - def __init__(self) -> None: - super().__init__() - def get_choice(self, idx: int) -> Optional[str]: return None diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 47542cb6aef77..a5379219a2373 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -561,9 +561,11 @@ def can_fuse_horizontal( shared_data_score: int, ) -> bool: """Hook for heuristics to prevent horizontal (consumer/consumer) fusions""" - if ( - shared_data_score < config.score_fusion_memory_threshold - ) and not MixOrderReduction.can_fuse(node1, node2): + if MixOrderReduction.can_fuse(node1, node2): + # For mix order reduction, we disregard shared data or + # distance. 
+ return True + if shared_data_score < config.score_fusion_memory_threshold: WhyNoFuse(node1, node2)("score_fusion_memory_threshold") return False if scheduler.are_long_distant_nodes(node1, node2): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 88f203421cc1c..18b209de94cb3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3786,9 +3786,6 @@ class TilingSelect: In the future, we can implement advanced heuristic in a subclass. """ - def __init__(self): - super().__init__() - def select_tiling( self, fn_list, diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 61a97fd740cbc..3a65d1c895d1c 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -158,11 +158,12 @@ def _generate_kernel_call_helper( ) new_args = [] for idx, arg in enumerate(call_args): - if "*" in arg_types[idx]: + if isinstance(arg_types[idx], str) and "*" in arg_types[idx]: new_args.append(f"({arg_types[idx]})({arg}.data_ptr())") else: - # arg is a scalar - new_args.append(arg) + # arg is a scalar - ensure it's a string for C++ codegen + # With Triton support, arg might be a SymPy expression or other type + new_args.append(str(arg) if not isinstance(arg, str) else arg) # debug printer related logic for cpp kernel type. debug_printer_manager = V.graph.wrapper_code.debug_printer debug_printer_manager.set_printer_args( @@ -2561,13 +2562,13 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: codegen_arg = codegen_arg.removeprefix("&") if codegen_arg == "nullptr": - return "from(std::nullopt)" + return "torch::stable::detail::from(std::nullopt)" var_name = f"tmp_var_{next(tmp_var_number)}" dispatch_lines.writeline( f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};" ) - return f"from({var_name})" + return f"torch::stable::detail::from({var_name})" raii_var = self.create_tmp_raii_handle_var_if_needed( codegen_arg, dispatch_lines @@ -2584,11 +2585,11 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: dispatch_lines.writeline( f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});" ) - return f"from({var_name})" + return f"torch::stable::detail::from({var_name})" # If the RAII tensor _is_ a temporary scoped to this fallback call, # simply release and steal the handle. 
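The _generate_kernel_call_helper change earlier in this file's diff now tolerates non-string scalar call args (e.g. SymPy expressions produced for dynamic shapes) by stringifying them before they are spliced into the generated C++ call. A tiny illustration of that coercion with toy values, not the wrapper's real argument list:

    import sympy

    s0 = sympy.Symbol("s0")
    call_args = ["buf0", s0 * 2 + 1, 128]   # mixed string / SymPy / int scalars
    rendered = [a if isinstance(a, str) else str(a) for a in call_args]
    assert rendered == ["buf0", "2*s0 + 1", "128"]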
- return f"from({raii_var}.release())" - return f"from({codegen_arg})" + return f"torch::stable::detail::from({raii_var}.release())" + return f"torch::stable::detail::from({codegen_arg})" codegen_args = get_args() ivalue_args = ( @@ -2609,7 +2610,7 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: if len(output_args) == 1 and (output := output_args[0]) is not None: # result is a single tensor dispatch_lines.writeline( - f"{output} = to(dispatch_vars[0]);" + f"{output} = torch::stable::detail::to(dispatch_vars[0]);" ) else: # result is a tuple of tensors @@ -2617,7 +2618,7 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: if output_arg is None: continue dispatch_lines.writeline( - f"{output_arg} = to(dispatch_vars[{idx}]);" + f"{output_arg} = torch::stable::detail::to(dispatch_vars[{idx}]);" ) dispatch_lines.writeline("}") diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 22d0981febecd..c4b7188bd9e62 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1330,19 +1330,6 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): including those which allow flexible fusions with epilogues. """ - def __init__( - self, - input_nodes: list[Buffer], - layout: Layout, - alpha: float, - beta: float, - input_reorder: Optional[list[int]] = None, - use_fast_accum: Optional[bool] = None, - ): - super().__init__( - input_nodes, layout, alpha, beta, input_reorder, use_fast_accum - ) - @staticmethod def add_cutlass_gemm_choices( choices: list[ChoiceCaller], diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index dc4349ad7bbf5..9b465e3d1ffab 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -161,7 +161,9 @@ def set_printer_args( # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls if kernel_type == "extern": args_to_print_or_save_extern = [ - arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg")) + arg + for arg in args_to_print_or_save + if isinstance(arg, str) and arg.startswith(("buf", "arg")) ] self.args_to_print_or_save = args_to_print_or_save_extern elif kernel_type == "cpp": @@ -172,7 +174,7 @@ def set_printer_args( else arg ) for arg in args_to_print_or_save - if arg.startswith(("buf", "arg")) + if isinstance(arg, str) and arg.startswith(("buf", "arg")) ] else: self.args_to_print_or_save = args_to_print_or_save diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 512cf89795b0d..23bf0e1bbe31a 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -7,6 +7,7 @@ import torch # noqa: TC001 from torch.utils._ordered_set import OrderedSet +from torch.utils._pallas import has_tpu_pallas from .. import config from ..runtime.runtime_utils import torch_dtype_to_jax @@ -886,6 +887,17 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove kernel_name = name or "" interpret_is_cpu = V.graph.get_current_device_or_throw().type == "cpu" + is_tpu = torch._inductor.config._debug_cpu_to_tpu_pallas + if is_tpu: + if not torch._inductor.config.pallas_take_first_jax_device_only: + raise RuntimeError( + "Pallas backend currently only supports using the first JAX device." + ) + if not has_tpu_pallas(): + raise RuntimeError( + "PALLAS_TARGET_TPU is set, but no TPU device was found. 
" + "Please make sure that you have a TPU available and that JAX is configured correctly." + ) interpret_literal = "True" if interpret_is_cpu else "False" # For GPU (Triton backend), import pltriton for masked loads/stores @@ -1065,19 +1077,38 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove if alias_params: code.writeline("# Convert Torch -> JAX for donated outputs") for alias_name in alias_params: - code.writeline( - f"{alias_name}_jax = jax.dlpack.from_dlpack({alias_name})" - ) + # TODO: The `jax.device_put` path is a temporary workaround for a Mosaic compiler bug + # that occurs with DLPack. Once TorchTPU provides a direct method for placing a + # `torch.Tensor` on a TPU device, this should be reverted to use the + # `jax.dlpack.from_dlpack` path. + if is_tpu: + code.writeline( + f"{alias_name}_jax = jax.device_put({alias_name}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{alias_name}_jax = jax.dlpack.from_dlpack({alias_name})" + ) code.writeline("# Convert Torch -> JAX for in-place tensors") for ptr in pointer_tail: if ptr.startswith("in_out_ptr"): - code.writeline(f"{ptr}_jax = jax.dlpack.from_dlpack({ptr})") + if is_tpu: + code.writeline( + f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline(f"{ptr}_jax = jax.dlpack.from_dlpack({ptr})") code.writeline("# Convert Torch -> JAX for inputs") for ptr in pointer_tail: if ptr.startswith("in_ptr"): - code.writeline( - f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.contiguous())" - ) + if is_tpu: + code.writeline( + f"{ptr}_jax = jax.device_put({ptr}.cpu().numpy(), device=jax.devices('tpu')[0])" + ) + else: + code.writeline( + f"{ptr}_jax = jax.dlpack.from_dlpack({ptr}.contiguous())" + ) code.writeline("# Prepare output metadata from PyTorch tensor") code.writeline( @@ -1116,9 +1147,15 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove ) for idx in copy_output_indices: name = output_params[idx] - code.writeline( - f"{name}.copy_(torch.from_dlpack(result_values[{idx}]))" - ) + if is_tpu: + code.writeline( + f"res_cpu = jax.device_get(result_values[{idx}])" + ) + code.writeline(f"{name}.copy_(torch.from_dlpack(res_cpu))") + else: + code.writeline( + f"{name}.copy_(torch.from_dlpack(result_values[{idx}]))" + ) return code.getvalue() diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 65e8f88b1c425..cf0e5bf849106 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -47,7 +47,7 @@ from ..optimize_indexing import indexing_dtype_strength_reduction from ..runtime.coordinate_descent_tuner import CoordescTuner from ..runtime.hints import DeviceProperties -from ..runtime.runtime_utils import green_text, next_power_of_2, yellow_text +from ..runtime.runtime_utils import green_text, last_power_of_2, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse from ..utils import ( cache_property_on_self, @@ -1610,10 +1610,7 @@ def benchmark_codegened_module( def _codegen_mix_order_reduction(self, node1, node2): numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1) - if not V.graph.sizevars.statically_known_gt( - numel, - rnumel, - ): + if not V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)): return self._codegen_mix_order_reduction(node2, node1) def _pick_split_size(): @@ -1625,7 +1622,10 @@ def _pick_split_size(): device_prop = DeviceProperties.create(node1.get_device()) num_sm = 
device_prop.multi_processor_count estimated_num_splits = num_sm * 8 - split_size = max(next_power_of_2(numel // estimated_num_splits), 16) + + # split_size is decided based on hint + numel_hint = V.graph.sizevars.size_hint(numel) + split_size = max(last_power_of_2(numel_hint // estimated_num_splits), 16) split_size = min(split_size, 128) return split_size @@ -1634,10 +1634,7 @@ def _pick_split_size(): # pyrefly: ignore [bad-assignment] metrics.codegen_mix_order_reduction += 1 - assert V.graph.sizevars.statically_known_gt( - numel, - rnumel, - ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(numel, rnumel)) # split epilogue out of node2 node2_reductions, node2_epilogue = self._split_mix_order_reduction_epilogue( @@ -1681,7 +1678,6 @@ def _bench(candidate_split_size): split_size, 8, ) - # print(f"Autotuning pick split size {split_size}") kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction( kernel_features, @@ -1726,6 +1722,8 @@ def _bench(candidate_split_size): if node.get_outputs()[0].node.get_name() not in rename: node.mark_run() + V.graph.wrapper_code.make_comment("# Call mix order reduction kernel") + self.codegen_comment(node_schedule, None) # workspace args is still needed after the call kernel.call_kernel(kernel.kernel_name, deallocate_ws=False) V.graph.removed_buffers |= kernel.removed_buffers @@ -1733,7 +1731,9 @@ def _bench(candidate_split_size): # a extra round of reduction assert len(converted_nodes) == len(kernel.saved_partial_accumulate) - nsplit = (numel + split_size - 1) // split_size + nsplit = V.graph.wrapper_code.codegen_python_sizevar( + (numel + split_size - 1) // split_size + ) for idx, partial_accum in enumerate(kernel.saved_partial_accumulate): buffer_name = partial_accum.buffer_name diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4ac481478196a..9b718f0c780c1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2262,10 +2262,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): kexpr: Callable[[sympy.Expr], str] = texpr allow_block_ptr = True tma_compatibility_checker_cls = TMACompatibilityChecker - block_ptr_options_cls: type[BlockPtrOptions] = BlockPtrOptions - tensor_descriptor_options_cls: type[TensorDescriptorOptions] = ( - TensorDescriptorOptions - ) def __init__( self, @@ -2731,9 +2727,9 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: self.filter_masks(mask_vars) options_class = ( - self.block_ptr_options_cls + BlockPtrOptions if config.triton.use_block_ptr - else self.tensor_descriptor_options_cls + else TensorDescriptorOptions ) nonlocal tma_compatibility_checker if config.triton.use_block_ptr: @@ -2757,7 +2753,7 @@ def match_block_expr() -> Optional[BlockDescriptorOptions]: can_lift=can_lift, transpose_contiguous=transpose_contiguous, ) - if isinstance(options_class, TensorDescriptorOptions): + if options_class == TensorDescriptorOptions: tma_compatibility_checker = cast( TMACompatibilityChecker, tma_compatibility_checker ) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 615913933326e..41b12d05cd32e 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -627,7 +627,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) diff --git 
a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index c6c56c86b2c24..0eab3cac9b4a7 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -3755,6 +3755,8 @@ def __init__( self.kernel_autotune_calls = root.kernel_autotune_calls # Only store kernel src to name mapping in the main graph self.src_to_kernel = root.src_to_kernel + # Same here, only define user-defined Triton kernels in the main graph + self.user_defined_kernel_cache = root.user_defined_kernel_cache def set_launcher_fn_name(self) -> None: # This sets up the name of the function containing the launcher code of diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 681aef9afb35f..55279f393d3aa 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -341,12 +341,58 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: - sz_bytes = 0 - for node in fx_node.all_input_nodes: - if (t := node.meta.get("val")) is not None: - numel = get_size_numel(t.size()) - sz_bytes += numel * get_dtype_size(t.dtype) - return sz_bytes + """Estimate the size of a collective operation in bytes, including inputs and outputs.""" + input_bytes = None + + args, kwargs = fx_node.args, fx_node.kwargs + kwargs = dict(kwargs) + + # dont double count pre-allocated buffer passed in + kwargs.pop("out", None) + + def tensor_bytes(t) -> int: + return get_size_numel(t.size()) * get_dtype_size(t.dtype) + + def add_inp_bytes(inp: torch.fx.Node): + t = inp.meta.get("val", None) + if t is None: + return + + nonlocal input_bytes + if input_bytes is None: + input_bytes = 0 + input_bytes += tensor_bytes(t) + + pytree.tree_map_only( + torch.fx.Node, + add_inp_bytes, + (args, kwargs), + ) + + output_tensor = fx_node.meta.get("val", None) + + if input_bytes is None or output_tensor is None: + return 0 + + output_bytes = ( + get_size_numel(output_tensor.size()) * output_tensor.element_size() + ) # pyre-ignore + + return input_bytes + output_bytes + + +def estimate_fx_collective_memory_footprint(fx_node: torch.fx.Node) -> int: + """Estimate the memory footprint of a collective operation in bytes. + + This returns the total bytes that need to be live concurrently in memory. + For all_reduce, we divide by 2 since it can be done in-place. + """ + from torch._inductor.fx_passes.bucketing import ( + is_all_reduce_tensor as is_all_reduce, + ) + + size = estimate_fx_collective_size(fx_node) + return size if not is_all_reduce(fx_node) else size // 2 def estimate_nccl_collective_runtime_from_fx_node( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e4660f90e1eb4..45fa2d74acaed 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -303,9 +303,6 @@ def prologue_fusion_enabled() -> bool: ] ] = None -# Deprecated -split_cat_fx_passes = True - # Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability. 
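The estimate_fx_collective_size rewrite above counts both input and output bytes (skipping a pre-allocated out= buffer so it is not double counted), and the new estimate_fx_collective_memory_footprint halves the figure for all_reduce because it can run in place. A back-of-the-envelope illustration with made-up shapes and world size, purely as a sketch of the arithmetic:

    import torch

    world_size = 8
    inp = torch.empty(1024, 1024, dtype=torch.bfloat16)      # 2 MiB input shard

    input_bytes = inp.numel() * inp.element_size()            # 2 MiB
    # all_gather_into_tensor materializes a world_size-times larger output
    output_bytes = input_bytes * world_size                   # 16 MiB
    all_gather_footprint = input_bytes + output_bytes         # 18 MiB live at once

    # all_reduce input and output are the same size and can share storage,
    # so its footprint is half of input + output
    all_reduce_footprint = (input_bytes + input_bytes) // 2   # == input_bytes

    assert all_gather_footprint == 18 * 1024 * 1024
    assert all_reduce_footprint == 2 * 1024 * 1024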
efficient_conv_bn_eval_fx_passes = False @@ -950,6 +947,11 @@ class aten_distributed_optimizations: # "benchmark": Use CUDA events with power-of-2 rounding and interpolation collective_estimator: Literal["analytical", "benchmark"] = "analytical" + # Maximum memory increase above baseline for prefetch operations + # Uses minimum of absolute cap and ratio of baseline + max_memory_increase_gb: Optional[float] = None # Absolute cap in GB + max_memory_increase_ratio: Optional[float] = None # Ratio of baseline peak memory + def parallel_compile_enabled_internally() -> bool: """ @@ -1206,6 +1208,13 @@ def decide_compile_threads() -> int: enable_autograd_for_aot: bool = False +_debug_cpu_to_tpu_pallas: bool = Config( + env_name_force="PALLAS_TARGET_TPU", default=False +) +pallas_take_first_jax_device_only: bool = Config( + env_name_force="PALLAS_TAKE_FIRST_JAX_DEVICE_ONLY", default=True +) + def get_worker_log_path() -> Optional[str]: log_loc = None diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 5641c4294356f..aba2c5182264a 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -66,7 +66,8 @@ def _schedulable_wait_node(node: torch.fx.Node) -> bool: if not is_wait_tensor(node): return False assert isinstance(node.args[0], torch.fx.Node) - assert isinstance(node.args[0].target.name(), str) + if not isinstance(node.args[0].target, Callable): + return False is_callable: bool = node.args[0].op == "call_function" coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name()) is_collective: bool = coll != NCCL_COLL.UNSUPPORTED @@ -489,15 +490,34 @@ def all_reduce_merge_fn_to_trace( return new_outs +# List of all torch dtypes for serialization through custom ops +# TODO: custom ops support list[dtype] input +_ALL_DTYPES = tuple( + [ + getattr(torch, attr) + for attr in dir(torch) + if isinstance(getattr(torch, attr), torch.dtype) + ] +) + + @torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={}) def _pre_bucket_all_gather( ag_ins: list[torch.Tensor], group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[ + int + ], # dtype enum values, that inputs are converted to before all_gather rank: int, ) -> torch.Tensor: - ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + # Convert int indices back to torch.dtype + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] bucket_dtype_size_bytes = dtype.itemsize ins_split_sizes = [ _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes @@ -507,8 +527,14 @@ def _pre_bucket_all_gather( new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) - ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins] - torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) + # View each destination slice as its output dtype, then copy + # The copy operation handles dtype conversion from input dtype to output dtype + foreach_copy_dsts_typed = [ + dst.view(out_dtype) + for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True) + ] + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + torch._foreach_copy_(foreach_copy_dsts_typed, 
ag_ins_flattened) return new_ag_out @@ -517,9 +543,14 @@ def _pre_bucket_all_gather_fake( group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[int], rank: int, ) -> torch.Tensor: - ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] bucket_dtype_size_bytes = dtype.itemsize ins_split_sizes = [ _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes @@ -541,12 +572,9 @@ def all_gather_merge_fn_to_trace_custom_ops( out_dtypes: list[torch.dtype], # type: ignore[name-defined] rank: int, ) -> list[torch.Tensor]: - ag_ins = [ - torch._prims.convert_element_type(_ag_in, out_dtype) - if _ag_in.dtype != out_dtype - else _ag_in - for _ag_in, out_dtype in zip(_ag_ins, out_dtypes) - ] + # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion + # by viewing destination slices as output dtypes and letting copy do the conversion + ag_ins = _ag_ins ins_sizes = [ag_in.shape for ag_in in ag_ins] ins_split_sizes_bytes = [ ag_in.numel() * out_dtype.itemsize @@ -557,8 +585,13 @@ def all_gather_merge_fn_to_trace_custom_ops( _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes ] ag_input_numel = sum(ins_split_sizes) + + # Convert out_dtypes to indices for custom_op + # TODO: custom ops support list[dtype] input + out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes] + new_ag_out = torch.ops.bucketing._pre_bucket_all_gather( - ag_ins, group_size, group_name, dtype, rank + ag_ins, group_size, group_name, dtype, out_dtype_ints, rank ) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) wait_tensor = torch.ops.c10d_functional.wait_tensor( @@ -721,6 +754,20 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def] return replacements, new_nodes +def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool: + node_in = n.args[0] + return ( + is_all_gather_into_tensor(n) + and isinstance(node_in, torch.fx.Node) + and node_in.op == "call_function" + and ( + node_in.target is torch.ops.prims.convert_element_type.default + or node_in.target is torch.ops.aten._to_copy.default + ) + and len(node_in.users) == 1 + ) + + def process_collective_bucket( g: torch.fx.Graph, bucket_nodes: list[torch.fx.Node], @@ -755,13 +802,7 @@ def process_collective_bucket( # Handle convert_element_type operations (for all_gather) node_in = n.args[0] - if ( - is_all_gather_into_tensor(n) - and isinstance(node_in, torch.fx.Node) # Add type check - and node_in.op == "call_function" - and node_in.target is torch.ops.prims.convert_element_type.default - and len(node_in.users) == 1 - ): + if has_mergeable_all_gather_convert_dtype(n): ag_node_to_pre_nodes[n].append(node_in) node_in = node_in.args[0] diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 9db694f1d8629..021abb0d6b13b 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -957,3 +957,92 @@ def repl(inp, other): pass_dict=pass_patterns[1], extra_check=_other_is_broadcasted_in_dim, )(div_softmax_pattern) + + +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + 
dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, torch.fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedious. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_graph_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + # pyrefly: ignore [bad-argument-type] + pass_dict=patterns, + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise operation in joint graph. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. + """ + from torch._inductor import metrics + + # pyrefly: ignore # bad-assignment + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + # Create a replacement that uses torch.where for the pointwise operation + def repl_fn(shape, background_val, dim, selector, val): + # Create a tensor of indices for the scatter dimension + length = shape[dim] + indices = torch.arange(length, device=selector.device, dtype=torch.int64) + + # Reshape indices to have size 'length' at dim, then broadcast + view_shape = [1] * len(shape) + view_shape[dim] = length + indices_view = indices.view(*view_shape) + + # Broadcast selector to match full tensor shape + selector_expanded = selector.expand(shape) + + # Create a mask for where to scatter + mask = selector_expanded == indices_view + + # Use torch.where to implement the scatter pointwise operation + return torch.where(mask, val, background_val) + + # replace the scatter operation with pointwise equivalent + # pyrefly: ignore [bad-argument-type] + match.replace_by_example(repl_fn, [shape, background_val, dim, selector, val]) diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index f5c131a7eab96..c8af70dc598f4 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -172,6 +172,8 @@ def __init__( max_coll_distance=0, custom_runtime_estimation=None, collective_estimator="analytical", + max_memory_increase_gb=None, + max_memory_increase_ratio=None, ) self.module_bucket_plans = module_bucket_plans self.nodes_in_subgraph: list[list[fx.Node]] = [] diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index 4060a29c7c3db..b5ef930b8fa8f 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -3,15 +3,17 @@ from dataclasses import dataclass from typing import Any, Literal, Optional +import 
torch import torch.fx as fx from torch._dynamo.utils import counters from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch._inductor.fx_passes.bucketing import ( + _schedulable_wait_node, bucket_key, BucketMode, + has_mergeable_all_gather_convert_dtype, is_all_gather_into_tensor as is_all_gather, is_reduce_scatter_tensor as is_reduce_scatter, - is_wait_tensor, ) from torch._inductor.fx_passes.overlap_scheduling import ( CollBucket, @@ -50,12 +52,12 @@ def __call__(self, reason: str, *args: Any) -> None: def is_collective_or_wait(n: fx.Node) -> bool: """Check if node is a collective start or wait.""" - if is_wait_tensor(n): + if _schedulable_wait_node(n): return True # Collective starts have exactly one use: the wait_tensor if len(n.users) == 1: user = next(iter(n.users.keys())) - if is_wait_tensor(user): + if _schedulable_wait_node(user): return True return False @@ -185,7 +187,7 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]: if node in self.collective_info and get_group_name(node) == pg: node_type = "starts" hiding_nodes |= self.collective_info[node].hiding_nodes - elif is_wait_tensor(node): + elif _schedulable_wait_node(node): wait_input = node.args[0] if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg: node_type = "waits" @@ -207,6 +209,7 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]: prev_event = event position += 1 + return head def _populate_node_to_event(self, pg: str) -> None: @@ -231,7 +234,6 @@ def _add_hiding_interval_constraints(self) -> None: self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn) def bucket_collectives(self) -> None: - """Main entry point for bucketing collectives.""" # Group collectives by PG first pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet) for start in self.collective_info: @@ -281,6 +283,15 @@ def bucket_collectives(self) -> None: # Apply topological sort with all dependencies from torch._dynamo.graph_deduplication import _stable_topological_sort + for n, deps in additional_deps.items(): + torch._check( + not n._erased, lambda: f"Erased node deps not transferred: {n}" + ) + for d in deps: + torch._check( + not d._erased, lambda: f"Erased node deps not transferred: {d}" + ) + _stable_topological_sort(self.graph, additional_deps) # After topological sort, preserve dependencies using effect tokens @@ -315,7 +326,7 @@ def _find_buckets( # Sort collectives by node index for efficient distance checking sorted_collectives = sorted(collective_group, key=lambda n: self.node_idx[n]) - for start_node in sorted_collectives: + for i, start_node in enumerate(sorted_collectives): if start_node in processed: continue @@ -325,25 +336,17 @@ def _find_buckets( total_bytes=self.collective_info[start_node].size_bytes, ) processed.add(start_node) - start_node_idx = self.node_idx[start_node] # Check candidates in sorted order, break when beyond max distance - for candidate in sorted_collectives: + for candidate in sorted_collectives[i + 1 : i + 1 + self.max_coll_distance]: if candidate in processed: continue - candidate_idx = self.node_idx[candidate] - # Check if candidate is within max distance from the bucket start - distance = abs(candidate_idx - start_node_idx) - if distance > self.max_coll_distance: - # Since sorted, all remaining candidates will be too far - if candidate_idx > start_node_idx: - break - continue - candidate_bytes = self.collective_info[candidate].size_bytes + # proxy on memory use, if we see a too large bucket, + # dont look for another, later bucket if 
bucket_info.total_bytes + candidate_bytes > max_bucket_bytes: - continue + break if self._can_add_to_bucket(bucket_info, candidate): bucket_info.collectives.append(candidate) @@ -762,6 +765,11 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: old_starts = list(bucket) old_waits = [self.collective_info[n].wait_node for n in bucket] + fused_convert_dtypes = [] + for n in old_starts: + if has_mergeable_all_gather_convert_dtype(n): + fused_convert_dtypes.append(n.args[0]) + # Find where to place the bucketed operations next_node = bucket[0] while next_node in bucket: @@ -795,7 +803,7 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: ) # Get new nodes - new_waits = [n for n in new_nodes if is_wait_tensor(n)] + new_waits = [n for n in new_nodes if _schedulable_wait_node(n)] assert len(new_waits) == 1 new_wait = new_waits[0] @@ -809,6 +817,22 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: for old_wait in old_waits: erased_to_new[old_wait] = new_wait + # Handle convert_element_type nodes that were fused and erased + # The bucketed operation may have a _pre_bucket op that handles dtype conversion + if fused_convert_dtypes: + # all gather bucketing may fuse in dtype conversion into the bucketing + # if so, we need to transfer hiding deps from the old dtype conversion + # to the new bucketing node + new_convert_dtypes_node = new_start.kwargs["out"] + assert isinstance(new_convert_dtypes_node, fx.Node) + assert ( + new_convert_dtypes_node.target + == torch.ops.bucketing._pre_bucket_all_gather.default + ) + + for n in fused_convert_dtypes: + erased_to_new[n] = new_convert_dtypes_node + # Transfer all dependencies from old nodes to new nodes self.aug_graph.transfer_erased_node_deps(erased_to_new) diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 0649e36f23361..436a3ab0db81b 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -11,13 +11,9 @@ import torch import torch.fx as fx from torch._dynamo.utils import counters, dynamo_timed -from torch._inductor.comm_analysis import estimate_fx_collective_size +from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor -from torch._inductor.fx_passes.memory_estimator import ( - _is_releasable, - build_memory_profile, - MemoryTracker, -) +from torch._inductor.fx_passes.memory_estimator import MemoryTracker from torch.fx.operator_schemas import normalize_function from torch.utils._ordered_set import OrderedSet from torch.utils._python_dispatch import _disable_current_modes @@ -30,6 +26,27 @@ from ..pattern_matcher import stable_topological_sort +@dataclass +class WhyNoOverlap: + """Track reasons why a collective cannot overlap with compute.""" + + compute_name: str + collective_name: str + + def __init__(self, compute_node: fx.Node, collective_node: fx.Node) -> None: + self.compute_name = compute_node.name + self.collective_name = collective_node.name + + def __call__(self, reason: str, *args: Any) -> None: + if log.isEnabledFor(logging.DEBUG): + log.debug( + "cannot overlap %s with %s: " + reason, # noqa: G003 + self.collective_name, + self.compute_name, + *args, + ) + + def get_group_name(n: fx.Node) -> str: """Extract the group name from a collective operation node.""" opt_args_kwargs = normalize_function( @@ -45,21 +62,26 @@ def get_group_name(n: fx.Node) -> str: def get_custom_estimation( 
n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + override_size: int | None = None, ) -> float | None: if custom_runtime_estimation is None: return None - return custom_runtime_estimation(n) + return custom_runtime_estimation(n, override_size) def estimate_collective_time( n: fx.Node, override_size: int | None = None, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> float: """Estimate the runtime of a collective operation, optionally with an overridden size.""" - if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: + if ( + est := get_custom_estimation(n, custom_runtime_estimation, override_size) + ) is not None: return est # Use analytical model (benchmarking is handled separately in alignment) @@ -99,7 +121,8 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]: def benchmark_node_with_cache_key( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> tuple[float, str | None]: """Benchmark a compute node and return (runtime, cache_key).""" assert is_compute_node(n) @@ -142,7 +165,9 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: if unbacked_tensor: return 0, key - if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: + if ( + est := get_custom_estimation(n, custom_runtime_estimation, None) + ) is not None: set_cached_node_time(key, est) return est, key @@ -154,7 +179,8 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: def benchmark_node( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> float: return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0] @@ -236,8 +262,10 @@ def __init__( insert_overlap_deps: bool, compute_overlap_multipler: float, max_coll_distance: int, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None, collective_estimator: Literal["analytical", "benchmark"], + max_memory_increase_gb: float | None = 1.0, + max_memory_increase_ratio: float | None = 0.05, ): self.gm = gm self.graph = gm.graph @@ -262,18 +290,47 @@ def __init__( self.collective_info: dict[fx.Node, CollectiveInfo] = {} self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet() - # Memory tracking using abstracted MemoryTracker - self.original_peak_memory = max( - build_memory_profile(self.graph, _is_releasable) + # Identify compute nodes early (needed for baseline memory computation) + self.compute_nodes = [n for n in self.nodes if is_compute_node(n)] + self.current_compute_index = 0 + + # Compute baseline memory profile from original schedule + self.original_mem_before_compute_index: list[int] = [] + self.original_peak_memory = self._compute_baseline_memory() + + # Maximum allowed peak memory = baseline + max(absolute, ratio * baseline) + # When both limits are specified, use the more permissive one + memory_increase_bytes = None + if max_memory_increase_gb is not None: + memory_increase_bytes = gb_to_bytes(max_memory_increase_gb) + if max_memory_increase_ratio is not None: + ratio_increase = 
int(self.original_peak_memory * max_memory_increase_ratio) + memory_increase_bytes = ( + max(memory_increase_bytes, ratio_increase) + if memory_increase_bytes is not None + else ratio_increase + ) + if memory_increase_bytes is None: + memory_increase_bytes = 0 + + self.allowed_peak_memory_bytes = ( + self.original_peak_memory + memory_increase_bytes ) + + # Track cumulative prefetch memory at each compute index + # When we prefetch a collective at compute index i that will be used at index j, + # it adds memory from i to j, so we need to track this cumulative effect + self.cumulative_prefetch_mem_by_compute_index: list[int] = [ + 0 for _ in range(len(self.compute_nodes)) + ] + self.memory_tracker = MemoryTracker(self.graph) self.wait_to_start: dict[fx.Node, fx.Node] = {} self._identify_collectives() + self.wasted_compute = 0.0 self.compute_index_domination = self._calculate_compute_node_domination_index() - self.compute_nodes = [n for n in self.nodes if is_compute_node(n)] - self.current_compute_index = 0 # Scheduling state self.potentially_hidden_collectives = ( @@ -302,6 +359,88 @@ def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]: return ancestors + def _compute_baseline_memory(self) -> int: + """ + Simulate the original schedule to compute baseline memory profile. + Returns the peak memory observed during simulation. + """ + baseline_tracker = MemoryTracker(self.graph) + + last_compute_max_memory = 0 + peak_memory = 0 + + for node in self.nodes: + baseline_tracker.schedule_node(node) + current_mem = baseline_tracker.current_memory_bytes + + # Record the max memory between this and previous compute node + last_compute_max_memory = max(last_compute_max_memory, current_mem) + + if is_compute_node(node): + self.original_mem_before_compute_index.append(last_compute_max_memory) + last_compute_max_memory = current_mem + + peak_memory = max(peak_memory, current_mem) + + return peak_memory + + def _prefetch_would_exceed_memory_budget(self, start_node: fx.Node) -> bool: + """ + Check if prefetching this collective would exceed memory budget at ANY compute node + between now and when it's used. + """ + info = self.collective_info[start_node] + size = info.size_bytes + + domination_index = self.compute_index_domination[start_node] + + # If off-path, assume it doesn't increase memory + if domination_index == sys.maxsize: + return False + + # check current mem + if ( + self.memory_tracker.current_memory_bytes + size + > self.allowed_peak_memory_bytes + ): + return True + + start_index = self.current_compute_index + + # then, check future mem + for compute_idx in range(start_index, domination_index): + cumulative_prefetch = self.cumulative_prefetch_mem_by_compute_index[ + compute_idx + ] + + # Check 1: Would cumulative prefetch exceed in-flight limit? + if (cumulative_prefetch + size) > self.max_in_flight_bytes: + return True + + # Check 2: Would total memory (baseline + cumulative prefetch) exceed budget? + baseline_mem = self.original_mem_before_compute_index[compute_idx] + projected = baseline_mem + cumulative_prefetch + size + + if projected > self.allowed_peak_memory_bytes: + return True + + return False + + def _update_cumulative_prefetch_memory( + self, collective: fx.Node, info: CollectiveInfo + ) -> None: + """ + Update cumulative prefetch memory for all compute indices this collective will be live. 
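The bookkeeping described here can be restated in isolation: a prefetched collective adds its size to every compute index between the point where it is issued and the compute node that first needs its result (its domination index), and off-path collectives are ignored. A minimal sketch with toy numbers (list-based, no scheduler state):

    import sys

    num_compute_nodes = 6
    cumulative_prefetch = [0] * num_compute_nodes

    def record_prefetch(current_index, domination_index, size_bytes):
        # Off the compute path: assumed not to increase memory pressure.
        if domination_index == sys.maxsize:
            return
        # Live from the issuing compute node up to (not including) the consumer.
        for idx in range(current_index, domination_index):
            cumulative_prefetch[idx] += size_bytes

    record_prefetch(current_index=1, domination_index=4, size_bytes=16 << 20)
    assert cumulative_prefetch == [0, 16 << 20, 16 << 20, 16 << 20, 0, 0]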
+ """ + domination_index = self.compute_index_domination[collective] + if domination_index == sys.maxsize: + return + + for compute_idx in range(self.current_compute_index, domination_index): + self.cumulative_prefetch_mem_by_compute_index[compute_idx] += ( + info.size_bytes + ) + def off_compute_path(self, n: fx.Node) -> bool: """Check if a node is off the compute path (doesn't block any compute).""" return self.compute_index_domination[n] == sys.maxsize @@ -318,7 +457,7 @@ def _identify_collectives(self) -> None: info = CollectiveInfo( start_node=start, wait_node=node, - size_bytes=estimate_fx_collective_size(start), + size_bytes=estimate_fx_collective_memory_footprint(start), estimated_time_ms=coll_time_ms, exposed_time_ms=coll_time_ms, # Initially fully exposed ) @@ -431,7 +570,10 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( # Benchmark CUDA events (non-deterministic, needs alignment) # Skip collectives with custom estimation for n in collective_nodes: - if get_custom_estimation(n, self.custom_runtime_estimation) is not None: + if ( + get_custom_estimation(n, self.custom_runtime_estimation, None) + is not None + ): continue # Benchmark actual size @@ -518,16 +660,14 @@ def run(self) -> torch.fx.GraphModule: if node in self.scheduled: continue - if is_compute_node(node): - self._handle_compute(node) + if node.op == "placeholder": + self._schedule(node) elif node in self.collective_info: self._handle_collective_start(node) elif _schedulable_wait_node(node): self._handle_wait(node) - elif node.op == "placeholder": - self._schedule(node) else: - self._handle_other(node) + self._handle_compute_or_other(node) self._reorder_graph() @@ -566,9 +706,58 @@ def _add_effect_tokens_for_overlap(self) -> None: if additional_deps: preserve_node_ordering(self.graph, additional_deps) - def _handle_other(self, node: fx.Node) -> None: + def get_non_collective_runtime_estimate(self, node: fx.Node) -> float | None: + """Get runtime estimation for a node in ms. Returns None if no estimation is available.""" + + # TODO: non custom estimation of aten nodes, potentially requires notion of fusion group + if is_compute_node(node): + return benchmark_node(node, self.custom_runtime_estimation) + + if self.custom_runtime_estimation is None: + return None + + return self.custom_runtime_estimation(node, None) + + def _reduce_exposed_time_of_in_flight_collectives( + self, node: fx.Node, available_compute: float + ) -> float: + """Reduce exposed time of in-flight collectives using available compute time and return available time""" + + # TODO: separate overlap time per process group + for info in self.in_flight.values(): + if info.exposed_time_ms == 0: + continue + overlap_amount = min(info.exposed_time_ms, available_compute) + info.exposed_time_ms -= overlap_amount + available_compute -= overlap_amount + info.hiding_nodes.add(node) + if available_compute == 0: + break + return available_compute + + def _handle_compute_or_other(self, node: fx.Node) -> None: + """Handle scheduling compute or other nodes and attempt to overlap with collectives.""" + runtime_estimate = self.get_non_collective_runtime_estimate(node) + + # TODO: we could consider skipping overlapping for overlapable, unary chains to collectives. + # using these nodes for overlap prevents bucketing. 
potentially if chain time < latency + if runtime_estimate is None: + assert not is_compute_node(node), "should have estimate for compute nodes" + self._schedule(node) + return + + available_compute = runtime_estimate * self.compute_overlap_multipler + initial_compute = available_compute # Track initial compute time for wasted compute/path calculations + + available_compute = self._reduce_exposed_time_of_in_flight_collectives( + node, available_compute + ) + self._schedule_collectives_for_overlap(node, available_compute, initial_compute) self._schedule(node) + if is_compute_node(node): + self.current_compute_index += 1 + def _schedule(self, node: fx.Node) -> None: """Schedule a node.""" assert node not in self.scheduled @@ -639,9 +828,8 @@ def _should_force_wait_for_memory(self) -> bool: """Check if we need to force a wait due to memory pressure""" if not self.in_flight: return False - return self.in_flight_bytes >= self.max_in_flight_bytes or ( - self.memory_tracker.current_memory_bytes - self.original_peak_memory - ) > gb_to_bytes(1.0) + + return self.in_flight_bytes >= self.max_in_flight_bytes def _force_oldest_wait(self) -> None: """Schedule the oldest in flight wait""" @@ -687,81 +875,54 @@ def _handle_wait(self, node: fx.Node) -> None: del self.in_flight[coll_start] self._schedule(node) - def _handle_compute(self, node: fx.Node) -> None: - """Handle scheduling compute and finding overlaps.""" - - compute_time = benchmark_node(node, self.custom_runtime_estimation) - available_compute = compute_time * self.compute_overlap_multipler - - # TODO: separate overlap time per process group - # First reduce exposed time of in-flight collectives - for info in self.in_flight.values(): - if info.exposed_time_ms == 0: - continue - overlap_amount = min(info.exposed_time_ms, available_compute) - info.exposed_time_ms -= overlap_amount - available_compute -= overlap_amount - info.hiding_nodes.add(node) - if available_compute == 0: - break - - # Then, look for unscheduled collectives we can overlap - if available_compute: - self._schedule_collectives_for_overlap(node, available_compute) - - self._schedule(node) - self.current_compute_index += 1 - def _schedule_collectives_for_overlap( - self, compute_node: fx.Node, available_compute_time: float + self, compute_node: fx.Node, available_compute_time: float, initial_time: float ) -> None: """Opportunistically schedule collectives that can be hidden by compute.""" + if available_compute_time == 0: + return + + reduced_time = initial_time - available_compute_time compute_ancestors = self.node_ancestors[compute_node] - # Filter collectives by distance and compute index domination - possible_collectives = [] - for collective in self.unscheduled_collectives: - distance = abs(self.node_idx[compute_node] - self.node_idx[collective]) - if distance > self.max_node_distance: + # Compile-time filtering: limit candidates by distance to bound O(compute * collectives) cost + candidates = [] + for i, collective in enumerate(self.unscheduled_collectives): + if i > self.max_node_distance: break - # Skip collectives that are too far ahead in compute index, but allow scheduling - # collectives which are off compute path (which typically release memory) - # TODO: we could potentially be more strict about limiting the amount of - # pre-fetched memory before memory peak, and adjust allowed collective mem. 
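The scheduler treats a compute (or estimable non-collective) node's runtime, scaled by compute_overlap_multipler, as a budget that first pays down the exposed time of already in-flight collectives and only then is spent on prefetching new candidates, as in _reduce_exposed_time_of_in_flight_collectives above. A stripped-down sketch of that accounting with toy millisecond values; the list stands in for the in-flight dict and everything else is elided:

    def overlap_with_compute(in_flight_exposed_ms, available_compute_ms):
        # Greedily burn the compute budget against the exposed (not yet hidden)
        # time of in-flight collectives, oldest first.
        hidden = []
        for i, exposed in enumerate(in_flight_exposed_ms):
            if available_compute_ms <= 0:
                break
            overlap = min(exposed, available_compute_ms)
            in_flight_exposed_ms[i] = exposed - overlap
            available_compute_ms -= overlap
            if overlap > 0:
                hidden.append(i)
        return available_compute_ms, hidden

    exposed = [1.0, 2.0, 0.5]    # ms still exposed per in-flight collective
    leftover, hidden = overlap_with_compute(exposed, available_compute_ms=2.5)
    assert exposed == [0.0, 0.5, 0.5] and leftover == 0.0 and hidden == [0, 1]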
- if not self.off_compute_path(collective): - if ( - self.compute_index_domination[collective] - - self.current_compute_index - ) > self.max_compute_pre_fetch: - continue + if ( + not self.off_compute_path(collective) + and self.compute_index_domination[collective] + - self.current_compute_index + > self.max_compute_pre_fetch + ): + continue - possible_collectives.append(collective) + candidates.append(collective) - possible_collectives = sorted( - possible_collectives, + candidates = sorted( + candidates, key=lambda n: (self.compute_index_domination[n], self.node_idx[n]), ) - log.debug( - "Scheduling collectives for overlap: compute_node=%s, available_time=%.2f ms, candidates=%d, current_memory=%d bytes", - compute_node.name, - available_compute_time, - len(possible_collectives), - self.memory_tracker.current_memory_bytes, - ) - - for collective in possible_collectives: + for collective in candidates: if available_compute_time == 0: break + why = WhyNoOverlap(compute_node, collective) info = self.collective_info[collective] - # Skip if compute depends on collective or vice versa if ( collective in compute_ancestors or compute_node in self.node_ancestors[collective] ): + why("dependency conflict") + continue + + # Check if prefetching would exceed memory budget + if self._prefetch_would_exceed_memory_budget(collective): + why("prefetch would exceed memory budget") continue while ( @@ -772,10 +933,11 @@ def _schedule_collectives_for_overlap( self._force_oldest_wait() if (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes: + why("in-flight memory limit") continue # Check if we can reach this collective without scheduling compute, other collectives, or waits - path = self._find_schedulable_path(collective, compute_node) + path = self._find_schedulable_path(collective, compute_node, why) if path is None: continue @@ -787,31 +949,41 @@ def _schedule_collectives_for_overlap( self.current_compute_index, ) - # Schedule path to this collective + # Track compute runtime of nodes we must schedule to reach collective and + # add back available overlap time corresponding to prior in-flight collectives + path_estimates = [self.get_non_collective_runtime_estimate(p) for p in path] + path_time = sum(p for p in path_estimates if p is not None) + additional_time = min(path_time, reduced_time) + reduced_time -= additional_time + available_compute_time += additional_time + self._schedule_path_to_collective(path, compute_node) self._handle_collective_start(collective) + self._update_cumulative_prefetch_memory(collective, info) - # Update the exposed time for this newly scheduled collective - # after scheduling, which will account for latency reduction of bucketing + # Update exposed time overlap_amount = min(available_compute_time, info.exposed_time_ms) info.exposed_time_ms -= overlap_amount info.hiding_nodes.add(compute_node) available_compute_time -= overlap_amount + self.wasted_compute += available_compute_time + def _find_schedulable_path( - self, target: fx.Node, curr_compute_node: fx.Node | None + self, target: fx.Node, curr_compute_node: fx.Node | None, why: WhyNoOverlap ) -> OrderedSet[fx.Node] | None: """Find path to target by collecting unscheduled dependencies.""" - - # TODO - following path faster than doing set difference here + # Get unscheduled ancestors unscheduled_ancestors = self.node_ancestors[target] - self.scheduled # only schedule non distributed, non compute nodes for node in unscheduled_ancestors: if is_compute_node(node): + why("path blocked by compute node %s", 
node.name) return None if node in self.unscheduled_collectives: + why("path blocked by unscheduled collective %s", node.name) return None # if we schedule a wait tensor whose start collective is hidden by the @@ -823,8 +995,13 @@ def _find_schedulable_path( if _schedulable_wait_node(node): info = self.collective_info[self.wait_to_start[node]] if info.hiding_nodes and curr_compute_node not in info.hiding_nodes: + why( + "path blocked by wait node %s with different hiding compute", + node.name, + ) continue elif node not in self.potentially_hidden_waits: + why("path blocked by wait node %s that could be hidden", node.name) continue return None @@ -899,13 +1076,21 @@ def _reorder_graph(self) -> None: if c.exposed_time_ms == c.estimated_time_ms ] - potentially_hidden_collectives = self.compute_potential_hidden_collectives( - limit_coll_per_compute=True - ) + potentially_hidden_collectives = self.compute_potential_hidden_collectives() bad_exposed = [ c for c in exposed if c.start_node in potentially_hidden_collectives ] + # Compute total exposed and potential exposed time + total_exposed = sum(c.exposed_time_ms for c in self.collective_info.values()) + hideable_exposed_ms = sum( + self.collective_info[c].exposed_time_ms + for c in potentially_hidden_collectives + ) + total_potential_exposed = sum( + c.estimated_time_ms for c in self.collective_info.values() + ) + counters["inductor"]["overlap_scheduling_exposed"] += len(exposed) counters["inductor"]["overlap_scheduling_bad_exposed"] += len(bad_exposed) counters["inductor"]["overlap_scheduling_potentially_hidden"] += len( @@ -916,12 +1101,18 @@ def _reorder_graph(self) -> None: log.info( "Overlap scheduling results: exposed=%d, bad_exposed=%d, potentially_hidden=%d, " - "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes", + "original_peak_memory=%d bytes, rescheduled_peak_memory=%d bytes, " + "total_exposed_ms=%.2f, hideable_exposed_ms=%.2f, total_potential_exposed_ms=%.2f, " + "wasted_compute_ms=%.2f", len(exposed), len(bad_exposed), len(potentially_hidden_collectives), self.original_peak_memory, self.memory_tracker.peak_memory, + total_exposed, + hideable_exposed_ms, + total_potential_exposed, + self.wasted_compute, ) self.reorder_graph() @@ -936,31 +1127,25 @@ def _bucket_collectives(self) -> None: collective_info=self.collective_info, node_ancestors=self.node_ancestors, scheduled=self.scheduled, - max_bucket_memory_gb=1.0, # Could make this configurable + max_bucket_memory_gb=2.0, # Could make this configurable max_coll_distance=self.max_node_distance, insert_overlap_deps=self.insert_overlap_deps, ) bucketer.bucket_collectives() def compute_potential_hidden_nodes( - self, nodes_to_check: Iterable[fx.Node], limit_coll_per_compute: bool = False + self, nodes_to_check: Iterable[fx.Node] ) -> dict[fx.Node, fx.Node]: """ Returns a dict containing a mapping of nodes which could potentially be hidden to their hiding node """ - used_compute_nodes: OrderedSet[fx.Node] = OrderedSet() - def could_be_hidden(start: fx.Node) -> fx.Node | None: for compute_node in self.compute_nodes: - if limit_coll_per_compute and compute_node in used_compute_nodes: - continue if ( start not in self.node_ancestors[compute_node] and compute_node not in self.node_ancestors[start] ): - if limit_coll_per_compute: - used_compute_nodes.add(compute_node) return compute_node return None @@ -976,32 +1161,29 @@ def could_be_hidden(start: fx.Node) -> fx.Node | None: return potentially_hidden - def compute_potential_hidden_collectives( - self, limit_coll_per_compute: 
bool = False - ) -> dict[fx.Node, fx.Node]: + def compute_potential_hidden_collectives(self) -> dict[fx.Node, fx.Node]: """Compute which collective operations could be hidden by compute.""" - return self.compute_potential_hidden_nodes( - self.collective_info.keys(), limit_coll_per_compute - ) + return self.compute_potential_hidden_nodes(self.collective_info.keys()) - def compute_potential_hidden_waits( - self, limit_coll_per_compute: bool = False - ) -> dict[fx.Node, fx.Node]: + def compute_potential_hidden_waits(self) -> dict[fx.Node, fx.Node]: """Compute which wait operations could be hidden by compte.""" wait_nodes = [info.wait_node for info in self.collective_info.values()] - return self.compute_potential_hidden_nodes(wait_nodes, limit_coll_per_compute) + return self.compute_potential_hidden_nodes(wait_nodes) def schedule_overlap_bucketing( gm: torch.fx.GraphModule, - max_in_flight_gb: float = 2.0, - max_compute_pre_fetch: int = 5, + max_in_flight_gb: float = 5, + max_compute_pre_fetch: int = 200, collective_bucketing: bool = False, insert_overlap_deps: bool = False, compute_overlap_multipler: float = 1.0, - max_coll_distance: int = 1000, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + max_coll_distance: int = 200, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, collective_estimator: Literal["analytical", "benchmark"] = "analytical", + max_memory_increase_gb: float | None = 1.0, + max_memory_increase_ratio: float | None = 0.05, ) -> torch.fx.GraphModule: """Schedule nodes to maximize compute-collective overlap. @@ -1009,19 +1191,22 @@ def schedule_overlap_bucketing( gm: Input graph module to optimize. max_in_flight_gb: Maximum GB of concurrent collective data. Too much in flight memory can cause memory fragmentation within the CUDA Caching Allocator. - max_compute_pre_fetch: Maximum compute node prefetch distance. + max_compute_pre_fetch: Maximum mm nodes to pre fetch. Note: should already be limited by max_in_flight_gb and + max_memory_increase_gb collective_bucketing: Enable overlap-preserving collective bucketing. insert_overlap_deps: Insert overlap dependencies using control deps operator. This should only be used if compiling with inductor, or for subsequent passes before removing the ops prior to execution. compute_overlap_multipler: Scale factor for compute time used to hide collectives. This can be used to address over or under aggressive overlapping. - max_coll_distance: Maximum node distance for overlap or bucketing. Mostly intended to reduce compile time. + max_coll_distance: Maximum pre fetch or bucketing candidates. Mainly intended for compile time custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node. If None, uses default estimations. This is currently limited to collectives and compute nodes. collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas, "benchmark" uses CUDA events with power-of-2 rounding and interpolation. + max_memory_increase_gb: Maximum GB increase above baseline memory (absolute cap). If None, no absolute limit. + max_memory_increase_ratio: Maximum increase as ratio of baseline peak memory. If None, no ratio limit. + Uses minimum of absolute and ratio limits when both are specified. 
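A small, self-contained sketch may help make the interaction of these two limits concrete. The helper below is illustrative only (its name and placement are not part of this change); it simply mirrors the documented behavior of taking the minimum of the enabled caps.

```python
def effective_memory_increase_budget_bytes(
    baseline_peak_bytes: int,
    max_memory_increase_gb: float | None,
    max_memory_increase_ratio: float | None,
) -> float:
    # Collect whichever caps are enabled; None disables that cap entirely.
    limits = []
    if max_memory_increase_gb is not None:
        limits.append(max_memory_increase_gb * 1024**3)
    if max_memory_increase_ratio is not None:
        limits.append(max_memory_increase_ratio * baseline_peak_bytes)
    # Documented semantics: the minimum of the two limits applies when both are set.
    return min(limits) if limits else float("inf")


# With the defaults (1.0 GB absolute, 0.05 ratio) and a 40 GB baseline peak,
# the 1 GB absolute cap is smaller than the 2 GB ratio cap and therefore wins.
assert effective_memory_increase_budget_bytes(40 * 1024**3, 1.0, 0.05) == 1024**3
```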
""" - return OverlapScheduler( gm, compute_overlap_multipler=compute_overlap_multipler, @@ -1032,4 +1217,6 @@ def schedule_overlap_bucketing( collective_bucketing=collective_bucketing, insert_overlap_deps=insert_overlap_deps, collective_estimator=collective_estimator, + max_memory_increase_gb=max_memory_increase_gb, + max_memory_increase_ratio=max_memory_increase_ratio, ).run() diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index e0362f2aaafd4..a21e78821e52b 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -16,13 +16,13 @@ from torch._decomp import register_decomposition from torch._dynamo.utils import counters from torch._inductor import comms -from torch._inductor.virtualized import ops +from torch._inductor.virtualized import ops # noqa: F401 from torch._logging import trace_structured from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq from torch.utils._ordered_set import OrderedSet -from .. import config, ir, pattern_matcher +from .. import config, ir, pattern_matcher # noqa: F401 from ..codegen.common import custom_backend_passes from ..comms import remove_fsdp2_unsharded_param_graph_input_usage from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage @@ -303,6 +303,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): "custom_runtime_estimation", "insert_overlap_deps", "collective_estimator", + "max_memory_increase_gb", + "max_memory_increase_ratio", ) for key in config_keys: if (val := getattr(dist_opts, key)) is not None: @@ -802,95 +804,6 @@ def is_valid_mm_plus_mm(match: Match): return True -def scatter_upon_const_tensor_extra_check(m): - if not config.optimize_scatter_upon_const_tensor: - return False - full_shape = m.kwargs["shape"] - selector = m.kwargs["selector"] - dim = m.kwargs["dim"] - if dim < 0: - dim += len(full_shape) - - selector_ft = selector.meta["val"] - assert selector_ft.dim() == len(full_shape) - - for idx, select_sz, full_sz in zip( - itertools.count(), selector_ft.shape, full_shape - ): - if idx == dim: - continue - - # TODO: the pattern can be updated to support the case that index tensor - # is shorter. But that will need a more complex condition expression - # especially for multi-dimensional tensors. - # Skip it for now. - if isinstance(full_sz, fx.Node): - full_sz = full_sz.meta["val"] - if select_sz < full_sz: - return False - - # Actually we can support small size larger than 1. It would be a bit - # tedius. E.g., we load all the index values (not many) and compare - # them with the position in tensor to decide what value to return. - return selector_ft.size(dim) == 1 - - -@register_lowering_pattern( - CallFunction( - aten.scatter.value, - CallFunction( - aten.full, - KeywordArg("shape"), - KeywordArg("background_val"), - dtype=KeywordArg("dtype"), - ), - KeywordArg("dim"), - KeywordArg("selector"), - KeywordArg("val"), # scalar value - ), - extra_check=scatter_upon_const_tensor_extra_check, -) -def scatter_upon_const_tensor( - match: Match, shape, background_val, dtype, dim, selector, val -): - """ - Match the pattern of full+scatter into a pointwise. - - TODO: Right now the scatter value must be a scalar. But we could support it - when it is a tensor as well. 
- """ - from torch._inductor import metrics - - # Check if inputs are tensors instead of inductor IR nodes - if isinstance(selector, torch.Tensor): - # Return a fake tensor with the proper shape that this operator is intended to return - device = selector.device if hasattr(selector, "device") else torch.device("cpu") - return torch.empty(shape, dtype=dtype, device=device) - - # pyrefly: ignore [bad-assignment] - metrics.num_matches_for_scatter_upon_const_tensor += 1 - - selector_loader = selector.make_loader() - - def inner_fn(idx): - selector_idx = list(idx) - selector_idx[dim] = 0 - - selector = selector_loader(selector_idx) - return ops.where( - selector == ops.index_expr(idx[dim], torch.int64), - ops.constant(val, dtype), - ops.constant(background_val, dtype), - ) - - return ir.Pointwise.create( - device=selector.get_device(), - dtype=dtype, - inner_fn=inner_fn, - ranges=shape, - ) - - @register_lowering_pattern( CallFunction( aten.add, diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 1eaab41130675..517d6c3e39d1b 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1114,11 +1114,35 @@ def constant_name(self, name: str, device_override: Optional[torch.device]) -> s with torch.utils._python_dispatch._disable_current_modes(): # caller might have OrderedSet fake tensor mode which will create a fake tensor # when calling .to, so unset modes here - return self.allocate_non_dup_const_name( + non_dup_const_name = self.allocate_non_dup_const_name( f"{name}_{device_override.type}{device_override.index or 0}", self.constants[name].to(device_override), ) + assert non_dup_const_name in self.constants, ( + f"{non_dup_const_name} should be in V.graph.constants already" + ) + + # register device-copied buffers and parameters to graph as well + # to codegen correct torch::aot_inductor::ConstantType for them rather than `Unknown` + if any( + name == normalize_name(buffer_name) + for buffer_name in self.named_buffers + ): + self.named_buffers[non_dup_const_name] = self.constants[ + non_dup_const_name + ] + + if any( + name == normalize_name(param_name) + for param_name in self.named_parameters + ): + self.named_parameters[non_dup_const_name] = self.constants[ + non_dup_const_name + ] + + return non_dup_const_name + # pyrefly: ignore [bad-override] def placeholder( self, diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 72d8383d2b812..0f29d38cb44d0 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -530,6 +530,16 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]: return list(sym_vars) +def try_get_name(x): + if isinstance(x, TensorBox): + x = x.data + if isinstance(x, BaseView): + x = x.unwrap_view() + if isinstance(x, StorageBox): + x = x.data + return x.get_name() if isinstance(x, Buffer) else None + + class IRNode: """Base class for all intermediate representation (IR) nodes in TorchInductor. @@ -1435,9 +1445,7 @@ def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]: strides = V.graph.sizevars.stride_hints( j, reduction_vars, list(ranges1.keys()) ) - # A 0 stride does not make a reduction contiguous. - # This can happen when the reduction ranges contains a 1. 
- outer = all(s == 0 or s > 1 for s in strides) + outer = all(s > 1 for s in strides) if outer: num_outer += 1 else: @@ -7431,6 +7439,8 @@ class DeviceCopy(ExternKernelOut): def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode: if ( not x.is_extern() + # Can not apply this optimization if x has been mutated + and try_get_name(x) not in V.graph.mutated_buffers and all(r in V.graph.constants for r in x.get_read_names()) and not config.aot_inductor.use_runtime_constant_folding ): diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 23878f757cc5e..12cc68dcb9844 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -350,9 +350,51 @@ def autotune_custom_op( return selected_result +def _generate_dynamic_configs( + tensor_inputs: list[Buffer], + config_generator: Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]], + default_impl: Callable[..., Any], + operation_name: str, +) -> list[CustomOpConfig]: + """Generate configs dynamically based on input tensors at lowering time.""" + import inspect + + sig = inspect.signature(default_impl) + param_names = list(sig.parameters.keys()) + + with V.fake_mode: + fake_tensors = [] + for inp in tensor_inputs: + raw_shape = inp.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + fake_tensor = torch.empty( + concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() + ) + fake_tensors.append(fake_tensor) + + fake_tensors_dict = dict(zip(param_names, fake_tensors)) + + configs = config_generator(fake_tensors_dict) + + if not isinstance(configs, (list, tuple)): + raise TypeError( + f"config_generator must return a list or tuple of CustomOpConfig, " + f"got {type(configs)}" + ) + if not configs: + raise ValueError(f"config_generator returned empty list for {operation_name}. ") + + return list(configs) + + def register_custom_op_autotuning( custom_op: torch._library.custom_ops.CustomOpDef, - configs: Union[list[CustomOpConfig], list[Callable[..., Any]]], + configs: Optional[Union[list[CustomOpConfig], list[Callable[..., Any]]]] = None, + config_generator: Optional[ + Callable[[dict[str, torch.Tensor]], list[CustomOpConfig]] + ] = None, name: Optional[str] = None, input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None, ) -> None: @@ -361,11 +403,15 @@ def register_custom_op_autotuning( Args: custom_op: Custom operation (decorated function from @torch.library.custom_op) - configs: List of CustomOpConfig objects + configs: List of CustomOpConfig objects for static inputs. Mutually exclusive with config_generator. + config_generator: Dynamic config generator function that takes a dict mapping + parameter names to fake tensors, and returns list[CustomOpConfig] + based on input tensor properties. Mutually exclusive with configs. name: Operation name (default: "{op_name}_autotuned") input_gen_fns: Custom input generators for benchmarking Examples: + # Static configs @torch.library.custom_op("mylib::attention", mutates_args=()) def my_attention(query, key, value, head_dim=32): ... @@ -383,6 +429,20 @@ def my_attention(query, key, value, head_dim=32): "value": lambda fake: torch.randn_like(fake, device='cuda'), }, ) + + # Dynamic config generation based on input tensor properties + def generate_k_split_configs(fake_tensors: dict[str, torch.Tensor]) -> list[CustomOpConfig]: + # Access tensor shapes, dtypes, devices, etc. 
+ m, k = fake_tensors["mat1"].shape + _, n = fake_tensors["mat2"].shape + k_splits = ... # compute possible k splits based on tensor properties + return [CustomOpConfig(k_splits=k) for k in k_splits] + + register_custom_op_autotuning( + matmul_decomposeK_op, + config_generator=generate_k_split_configs, + input_gen_fns={...}, + ) """ from torch._library.custom_ops import CustomOpDef @@ -392,23 +452,36 @@ def my_attention(query, key, value, head_dim=32): f"got {type(custom_op)}." ) + # Validate configs and config_generator are mutually exclusive + if configs is not None and config_generator is not None: + raise ValueError( + "Cannot specify both 'configs' and 'config_generator'. " + "Use 'config_generator' for shape-dependent configs." + ) + + if configs is None and config_generator is None: + raise ValueError("Must specify either 'configs' or 'config_generator'") + op_overload = custom_op._opoverload default_impl = custom_op._init_fn - if not isinstance(configs, (list, tuple)): - raise TypeError(f"configs must be a list or tuple, got {type(configs)}") + # Process and validate static configs at registration time + static_configs = None + if configs is not None: + if not isinstance(configs, (list, tuple)): + raise TypeError(f"configs must be a list or tuple, got {type(configs)}") - processed_configs = [] - for cfg in configs: - if isinstance(cfg, CustomOpConfig): - processed_configs.append(cfg) - else: - raise TypeError( - f"Each config must be a CustomOpConfig object, got {type(cfg)}" - ) + static_configs = [] + for cfg in configs: + if isinstance(cfg, CustomOpConfig): + static_configs.append(cfg) + else: + raise TypeError( + f"Each config must be a CustomOpConfig object, got {type(cfg)}" + ) - if not processed_configs: - raise ValueError("At least one config must be provided") + if not static_configs: + raise ValueError("At least one config must be provided") if name is None: name = f"{op_overload._name}_autotuned" @@ -419,11 +492,20 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: # Extract tensor inputs and non-tensor parameters (runtime kwargs) tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs) - # Prepare decompositions and kwargs by merging config params with runtime kwargs + # Get configs: either generate dynamically or use static configs + if config_generator is not None: + configs_to_use = _generate_dynamic_configs( + tensor_inputs, config_generator, default_impl, name + ) + else: + assert static_configs is not None + configs_to_use = static_configs + + # Prepare decompositions and kwargs for autotuning decompositions = [] non_tensor_args = [] - for cfg in processed_configs: + for cfg in configs_to_use: decomp = cfg.get_decomposition(default_impl=default_impl) decompositions.append(decomp) diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 1a72e279aab79..d36b8d56cc711 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -7,12 +7,13 @@ import math from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, cast, Optional, TYPE_CHECKING, Union import sympy import torch from torch._inductor.virtualized import V +from torch.nn.attention.flex_attention import _Backend from ...ir import ComputedBuffer, ExternKernel, FixedLayout, TensorBox from ...lowering import empty, empty_strided, lowerings, register_lowering @@ -38,6 +39,8 @@ from 
.flex_decoding import _use_flex_decoding, create_flex_decoding_kernel from .flex_flash_attention import ( _use_flex_flash_attention, + _use_flex_flash_attention_backward, + create_flex_flash_attention_backward_kernel, create_flex_flash_attention_kernel, ) @@ -51,6 +54,17 @@ Expr = sympy.Expr +def _sanitize_kernel_options_for_triton( + kernel_options: dict[str, Any], +) -> tuple[dict[str, Any], _Backend]: + """We always strip quotes around str values, we only need this in lowering, so we pop it here + to avoid passing to triton constexpr dict + """ + sanitized = dict(kernel_options) + backend = cast(_Backend, sanitized.pop("BACKEND", "AUTO")) + return sanitized, backend + + @SymbolicGridFn def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv): """How is this kernel parallelized? @@ -93,7 +107,7 @@ def flex_attention( subgraph, block_mask, scale, - kernel_options, + kernel_options: dict[str, Any], score_mod_other_buffers, mask_mod_other_buffers, ): @@ -170,7 +184,7 @@ def flex_attention( ) freeze_irnodes(mask_graph_buffer) - kernel_options = dict(kernel_options) + kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. kernel_options = { k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v @@ -180,7 +194,19 @@ def flex_attention( enable_gqa = V.graph.sizevars.evaluate_expr( sympy.Ne(query.get_size()[1], key.get_size()[1]), ) - if _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa): + + can_use_decode = _use_flex_decoding( + query, kv_indices, value, kernel_options, enable_gqa + ) + use_decode = (backend == "TRITON_DECODE") or (backend == "AUTO" and can_use_decode) + + if backend == "TRITON_DECODE" and not can_use_decode: + raise RuntimeError( + "BACKEND='TRITON_DECODE' was specified but flex_decoding cannot be used for this input. " + "flex_decoding is only available for short sequence lengths with specific configurations." + ) + + if use_decode: return create_flex_decoding_kernel( query, key, @@ -227,6 +253,7 @@ def flex_attention( mask_graph, kernel_options, num_score_mod_placeholders=len(placeholder_inps), + backend=backend, ): return create_flex_flash_attention_kernel( query, @@ -635,7 +662,7 @@ def flex_attention_backward(*args, **kwargs): f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" ) - kernel_options = dict(kernel_options) + kernel_options, backend = _sanitize_kernel_options_for_triton(kernel_options) # Mark symbols in custom kernel options as static shapes and add guards. 
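For context on the selector consumed above: the lowering reads a "BACKEND" entry out of kernel_options (defaulting to "AUTO"). A hedged end-to-end sketch follows; the user-facing spelling is assumed from the key popped here, and the device, dtype, and shapes are purely illustrative.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Assumes a CUDA device; shapes and dtype are only illustrative.
q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)

compiled = torch.compile(flex_attention)

# "AUTO" (the default) keeps the existing heuristics; "TRITON", "TRITON_DECODE",
# and "FLASH" force a specific lowering and raise at lowering time when that
# implementation cannot handle the inputs (as in the checks above).
out = compiled(q, k, v, kernel_options={"BACKEND": "AUTO"})
```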
kernel_options = { k: V.graph.sizevars.guard_int(v) if isinstance(v, sympy.Symbol) else v @@ -698,6 +725,15 @@ def flex_attention_backward(*args, **kwargs): ) freeze_irnodes(mask_graph_buffer) + if _use_flex_flash_attention_backward( + fw_graph, + mask_graph, + backend=backend, + ): + return create_flex_flash_attention_backward_kernel( + query, key, value, out, logsumexp, grad_out, scale, kernel_options + ) + # Construct layout with stride order matching K key_size = [Bq, Hkv, seq_len_kv, qk_head_dim] key_strides = infer_dense_strides(key_size, key.get_stride()) diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 0d3721aa730a4..05d1290f0ab49 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -5,7 +5,7 @@ import importlib from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Optional +from typing import Any, Literal, Optional import sympy from sympy import Expr, Integer @@ -42,6 +42,10 @@ def ensure_flash_available() -> bool: flash_attention_cutedsl_template = CuteDSLTemplate( name="flash_attention_cutedsl", source=load_flex_template("flash_attention") ) +flash_attention_backward_cutedsl_template = CuteDSLTemplate( + name="flash_attention_backward_cutedsl", + source=load_flex_template("flash_attention_backward"), +) def _fixed_indexer_cute( @@ -101,6 +105,28 @@ def cutedsl_make_indexer(self): FixedLayout.make_indexer = original_make_indexer # type: ignore[assignment] +def wrap_choice_render_with_cutedsl_indexer(choice: Any) -> None: + """ + Wrap a template choice's kernel render to apply CuteDSL indexer patching. + + See Note [CuteDSL indexer patch]: + This wrapper allows the template to construct its closures normally, then + scopes the indexer patch to the actual render call that emits the kernel. + This ensures CuteDSL templates see colexicographic indexing while preserving + the template's setup logic. 
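In isolation, the scoping trick described in this note reduces to the generic sketch below (stub names, not the real helpers): the patch is a context manager used as a decorator, so it is active only while the inner render callable runs, never while the factory builds its closures.

```python
from contextlib import contextmanager

@contextmanager
def patched_indexing():
    # Stands in for the FixedLayout indexer swap; prints just mark the scope.
    print("indexer patch active")
    try:
        yield
    finally:
        print("indexer patch removed")

def make_kernel_render_stub():
    # Closures are constructed here, outside the patch.
    def render():
        return "<kernel source>"
    return render

render = make_kernel_render_stub()
render_with_patch = patched_indexing()(render)  # patch scoped to the render call only
print(render_with_patch())
```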
+ """ + original_make_kernel_render = choice.make_kernel_render + + def make_kernel_render_with_patch(*args, **kwargs): + render_kernel, render = original_make_kernel_render(*args, **kwargs) + # Let the template construct its closures, then scope the indexer patch + # to the actual render call that emits the kernel + render_with_patch = patch_fixed_layout_indexer_for_cutedsl()(render) + return render_kernel, render_with_patch + + choice.make_kernel_render = make_kernel_render_with_patch + + def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int): """Check if any of the input buffers (beyond the score mod placeholders) require gradients.""" inputs = [] @@ -117,6 +143,18 @@ def requires_grad(n): return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:]) +def is_trivial_score_graph(graph_module: GraphModule) -> bool: + """Backwards currently doesn't support score_mods, match against identity""" + graph = graph_module.graph + nodes = list(graph.nodes) + placeholders = [n for n in nodes if n.op == "placeholder"] + output = [n for n in nodes if n.op == "output"] + assert len(output) == 1, "Got graph w/ multiple outputs" + output_val = output[0].args[0] + # The identity graph just sends the score straight through + return output_val == placeholders[0] + + def is_trivial_mask_graph(graph_module: GraphModule) -> bool: """Mask graph is trivial when it only gates via the default full op.""" graph = graph_module.graph @@ -133,7 +171,7 @@ def is_trivial_mask_graph(graph_module: GraphModule) -> bool: @functools.lru_cache(maxsize=1) def _supports_nontrivial_mask_graphs() -> bool: """Currently only supported on Hopper (SM90) GPUs.""" - return torch.cuda.get_device_capability()[0] == 9 + return torch.cuda.get_device_capability()[0] in [9, 10] def _can_use_flex_flash_attention( @@ -171,20 +209,34 @@ def _use_flex_flash_attention( mask_graph: Subgraph, kernel_options: dict[str, Any], num_score_mod_placeholders: int, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], ) -> bool: - """Determine if we should use flex flash attention for the given inputs.""" - force_flash = kernel_options.get("force_flash", False) + """Determine if we should use flex flash attention for the given inputs. 
+ + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False can_use, reason = _can_use_flex_flash_attention( subgraph, mask_graph, num_score_mod_placeholders ) - if force_flash and not can_use: + if not can_use: raise RuntimeError( - f"force_flash=True but flash attention cannot be used: {reason}" + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" ) - return force_flash and can_use + return True def create_flex_flash_attention_kernel( @@ -273,29 +325,167 @@ def create_flex_flash_attention_kernel( NEEDS_BLOCK_MASK=needs_block_mask, ) - def wrap_choice_render(choice): - # See Note [CuteDSL indexer patch] - original_make_kernel_render = choice.make_kernel_render + for choice in choices: + wrap_choice_render_with_cutedsl_indexer(choice) + + if error or not choices: + # Fallback to original implementation + raise RuntimeError(f"CuteDSL template failed: {error}") + + # No autotune for now + template_output = choices[0].output_node() + + return (template_output, lse) - def make_kernel_render_with_patch(*args, **kwargs): - render_kernel, render = original_make_kernel_render(*args, **kwargs) - # Let the template construct its closures, then scope the indexer patch - # to the actual render call that emits the kernel - render_with_patch = patch_fixed_layout_indexer_for_cutedsl()(render) +def _can_use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, +) -> tuple[bool, str]: + if not ensure_flash_available(): + return False, "CUTE flash attention is not available" + + if not is_trivial_score_graph(fw_subgraph.graph_module): + return ( + False, + "NYI: Flex Flash Attention doesn't support score_mods in bwds yet.", + ) - return render_kernel, render_with_patch + if not is_trivial_mask_graph(mask_graph.graph_module): + return False, "NYI: Flex Flash Attention doesn't support block_sparsity yet." - choice.make_kernel_render = make_kernel_render_with_patch + return True, "" + + +def _use_flex_flash_attention_backward( + fw_subgraph: Subgraph, + mask_graph: Subgraph, + backend: Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"], +) -> bool: + """Determine if we should use flex flash attention for the given inputs. 
+ + Args: + subgraph: The score modification subgraph + mask_graph: The mask modification subgraph + kernel_options: Kernel configuration options + num_score_mod_placeholders: Number of placeholders in score_mod + backend: Implementation selector (AUTO, TRITON, FLASH, TRITON_DECODE) + + Returns: + True if flash attention should be used, False otherwise + """ + # Flash is experimental and must be explicitly requested + if backend != "FLASH": + return False + + can_use, reason = _can_use_flex_flash_attention_backward( + fw_subgraph, + mask_graph, + ) + + if not can_use: + raise RuntimeError( + f"BACKEND='FLASH' but flash attention cannot be used: {reason}" + ) + + return True + + +def create_flex_flash_attention_backward_kernel( + query: TensorBox, + key: TensorBox, + value: TensorBox, + out: TensorBox, + logsumexp: TensorBox, + grad_out: TensorBox, + scale: float, + kernel_options: dict[str, Any], + # TODO: will be needed + # grad_logsumexp, + # fw_graph: SubgraphResults, + # joint_graph: SubgraphResults, + # mask_graph: SubgraphResults, + # score_mod_other_buffers: list[TensorBox], + # mask_mod_other_buffers: list[TensorBox], + # kv_num_blocks: TensorBox | None, + # kv_indices: TensorBox | None, + # full_kv_num_blocks: TensorBox | None, + # full_kv_indices: TensorBox | None, +) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox, TensorBox, tuple]: + """Create a CuteDSL flash attention backward kernel for the default mod path.""" + if not ensure_flash_available(): + raise RuntimeError("CUTE flash attention not available") + + batch_size, num_heads, seq_len_q, head_dim = query.get_size() + v_head_dim = value.get_size()[-1] + device = query.get_device() + dtype = query.get_dtype() + assert device is not None + + grad_query_strides = infer_dense_strides( + [batch_size, num_heads, seq_len_q, head_dim], query.get_stride() + ) + grad_query = empty_strided( + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=grad_query_strides, + dtype=dtype, + device=device, + ) + + grad_key_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], head_dim], key.get_stride() + ) + grad_key = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], head_dim], + stride=grad_key_strides, + dtype=dtype, + device=device, + ) + + grad_value_strides = infer_dense_strides( + [batch_size, num_heads, value.get_size()[2], v_head_dim], value.get_stride() + ) + grad_value = empty_strided( + size=[batch_size, num_heads, value.get_size()[2], v_head_dim], + stride=grad_value_strides, + dtype=dtype, + device=device, + ) + + output_layout = FixedLayout( + device=device, + dtype=dtype, + size=[batch_size, num_heads, seq_len_q, head_dim], + stride=[sympy.sympify(s) for s in grad_query.get_stride()], + ) + + choices: list[Any] = [] + + input_nodes = [ + query, + key, + value, + out, + grad_out, + logsumexp, + grad_key, + grad_value, + ] + + error = flash_attention_backward_cutedsl_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=output_layout, + mutated_inputs=[grad_key, grad_value], + SM_SCALE=scale, + ) for choice in choices: - wrap_choice_render(choice) + wrap_choice_render_with_cutedsl_indexer(choice) if error or not choices: - # Fallback to original implementation raise RuntimeError(f"CuteDSL template failed: {error}") - # No autotune for now template_output = choices[0].output_node() - return (template_output, lse) + return (template_output, grad_key, grad_value, tuple()) diff --git a/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja 
b/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja new file mode 100644 index 0000000000000..2831ba6af5b60 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flash_attention_backward.py.jinja @@ -0,0 +1,28 @@ +{{def_kernel("Q", "K", "V", "OUT", "D_OUT", "LSE", "DK", "DV")}} + from flash_attn.cute.interface import _flash_attn_bwd + + q_transposed = Q.transpose(1, 2) + k_transposed = K.transpose(1, 2) + v_transposed = V.transpose(1, 2) + out_transposed = OUT.transpose(1, 2) + d_out_transposed = D_OUT.transpose(1, 2) + + dq_transposed, dk_transposed, dv_transposed = _flash_attn_bwd( + q_transposed, + k_transposed, + v_transposed, + out_transposed, + d_out_transposed, + LSE, + softmax_scale={{SM_SCALE}}, + ) + + dq = dq_transposed.transpose(1, 2) + dk = dk_transposed.transpose(1, 2) + dv = dv_transposed.transpose(1, 2) + + dq_out = {{get_output()}} + {# TODO: add out support to flash #} + dq_out.copy_(dq) + DK.copy_(dk) + DV.copy_(dv) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 986ceb4405a14..5b57c458f46e6 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -55,6 +55,7 @@ ) from .mm_common import ( _is_static_problem, + load_kernel_template, mm_args, mm_grid, persistent_mm_grid, @@ -75,162 +76,18 @@ aten = torch.ops.aten prims = torch.ops.prims +# We define each template kernel in a separate file which is the name of the input to load_kernel_template +# (e.g. triton_mm for templates/triton_mm.py.jinja). +# If you are adding a new template, please follow that pattern and add a new file with your implementation in the templates folder. mm_template = TritonTemplate( name="mm", grid=mm_grid, - source=( - r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0).to(INDEX_DTYPE) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): - offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - else: - offs_a_m = rm % M - if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): - offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - else: - offs_b_n = rn % N - offs_k = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for k_idx in range(0, tl.cdiv(K, BLOCK_K)): - {% if not EVEN_K %} - a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) - b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) - {% endif %} - a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) - b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) - - idx_m = offs_a_m[:, None] - idx_n = a_k_idx_vals - {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", - indent_width=8, index_shape=("BLOCK_M", 
"BLOCK_K"))}} - - idx_m = b_k_idx_vals - idx_n = offs_b_n[None, :] - {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", - indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} - - {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% endif %} - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} -""" - if (torch.version.hip is None) or triton_version >= "3.3.0" - # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 - # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. - # See more details in https://github.com/pytorch/pytorch/pull/146293 - else r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0).to(INDEX_DTYPE) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): - offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - else: - offs_a_m = rm % M - if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): - offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - else: - offs_b_n = rn % N - offs_k = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - for k_idx in range(0, tl.cdiv(K, BLOCK_K)): - {% if not EVEN_K %} - a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) - b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) - {% endif %} - a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) - b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) - - idx_m = offs_a_m[:, None] - idx_n = a_k_idx_vals - {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", - indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} - - idx_m = b_k_idx_vals - idx_n = offs_b_n[None, :] - {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", - indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} - {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% endif %} - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", 
"mask", val_shape=("BLOCK_M", "BLOCK_N"))}} -""" - ), + source=load_kernel_template("triton_mm") + if (torch.version.hip is None) or triton_version >= "3.3.0" + # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 + # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. + # See more details in https://github.com/pytorch/pytorch/pull/146293 + else load_kernel_template("triton_mm_rocm"), cache_codegen_enabled_for_template=True, prologue_loads_all_inputs=True, ) @@ -238,682 +95,27 @@ persistent_tma_mm_template = TritonTemplate( name="mm_persistent_tma", grid=persistent_mm_grid, - source=r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - start_pid = tl.program_id(0).to(INDEX_DTYPE) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = grid_m * grid_n - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - width = GROUP_M * grid_n - rk_for_mask = tl.arange(0, BLOCK_K) - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - - {%- if TMA_EXPERIMENTAL_API %} - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - global_size=[M, K] if A_ROW_MAJOR else [K, M], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - global_size=[K, N] if B_ROW_MAJOR else [N, K], - element_ty=B.dtype.element_ty, - ) - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - {%- else %} - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K] if A_ROW_MAJOR else [K, M], - strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], - block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[K, N] if B_ROW_MAJOR else [N, K], - strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], - block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - ) - {%- endif %} - - pid_m = 0 - pid_n = 0 - rm = 0 - rn = 0 - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - # re-order program ID for better L2 performance - group_id = tile_id // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (tile_id % group_size) - pid_n = (tile_id % width) // (group_size) - - rm = pid_m * BLOCK_M - rn = pid_n * BLOCK_N - - rk = ki * BLOCK_K - - {%- if TMA_EXPERIMENTAL_API %} - a = tl._experimental_descriptor_load( - a_desc_ptr, - [rm, rk] if A_ROW_MAJOR else [rk, rm], - [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - A.dtype.element_ty, - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, - [rk, rn] if B_ROW_MAJOR 
else [rn, rk], - [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - B.dtype.element_ty, - ) - {%- else %} - a = tl.load_tensor_descriptor( - a_desc, - [rm, rk] if A_ROW_MAJOR else [rk, rm], - ) - b = tl.load_tensor_descriptor( - b_desc, - [rk, rn] if B_ROW_MAJOR else [rn, rk], - ) - {%- endif %} - acc += tl.dot( - a if A_ROW_MAJOR else a.T, - b if B_ROW_MAJOR else b.T, - allow_tf32=ALLOW_TF32, - ) - - if ki == k_tiles - 1: - # inductor generates a suffix - {%- if TMA_EXPERIMENTAL_API %} - # rematerialize rm and rn to save registers - rcm = rm + tl.arange(0, BLOCK_M) - rcn = rn + tl.arange(0, BLOCK_N) - idx_m = rcm[:, None] - idx_n = rcn[None, :] - mask = (idx_m < M) & (idx_n < N) - {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} - {%- else %} - {{store_output(("rm", "rn"), "acc", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"), block_indexing=True)}} - {%- endif %} - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - -""", + source=load_kernel_template("triton_persistent_tma_mm"), ) -load_scales = r""" -@triton.jit -def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): - if SCALE_RECIPE == 0: - return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values - else: - return scale_ptr # For all other scaling recipes, we'll return the pointers -""" - - -apply_scaling = r""" -@triton.jit -def apply_scaling( - accumulator, - a_scale, - b_scale, - SCALE_RECIPE_A: tl.constexpr, - SCALE_RECIPE_B: tl.constexpr, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, -): - if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1: # (ScalingType.RowWise, ScalingType.RowWise) - # For row-wise scaling, we need to load the scales for each row/column - a_scales = tl.load( - a_scale + (offs_cm * stride_a_scale_m), - mask=offs_cm < M, - other=0.0, - ) - b_scales = tl.load( - b_scale + (offs_cn * stride_b_scale_n), - mask=offs_cn < N, - other=0.0, - ) - acc_scale = a_scales[:, None] * b_scales[None, :] - else: # (ScalingType.TensorWise, ScalingType.TensorWise) - # For per-tensor scaling, we can directly use the loaded scalar values - acc_scale = a_scale * b_scale - - return accumulator * acc_scale -""" - - -scaled_mm_device_tma_epilogue_scaling = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - if SCALE_RECIPE_A == 1: # ScalingType.RowWise - stride_a_scale_m = 1 - else: - stride_a_scale_m = 0 - - if SCALE_RECIPE_B == 1: # ScalingType.RowWise - stride_b_scale_n = 1 - else: - stride_b_scale_n = 0 - - start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - {%- if TMA_EXPERIMENTAL_API %} - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N, K], - 
element_ty=B.dtype.element_ty, - ) - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - {%- else %} - stride_am = {{stride("A", 0)}} - stride_bn = {{stride("B", 1)}} - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K], - strides=[stride_am, 1], - block_shape=[BLOCK_M, BLOCK_K], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[N, K], - strides=[stride_bn, 1], - block_shape=[BLOCK_N, BLOCK_K], - ) - {%- endif %} - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) - b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - {%- if TMA_EXPERIMENTAL_API %} - a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty - ) - {%- else %} - a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) - b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) - {%- endif %} - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if ki == k_tiles - 1: - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALE_RECIPE_A, - SCALE_RECIPE_B, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) - - # inductor generates a suffix - {%- if TMA_EXPERIMENTAL_API %} - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} - {%- else %} - {{store_output( - ("offs_am", "offs_bn"), - "accumulator", - indent_width=12, - val_shape=("BLOCK_M", "BLOCK_N"), - block_indexing=True, - )}} - {%- endif %} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" - scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate( name="scaled_mm_device_tma_epilogue_scaling", grid=persistent_mm_grid, - source=scaled_mm_device_tma_epilogue_scaling + load_scales + apply_scaling, + source=load_kernel_template("triton_epilogue_scaled_mm"), ) -blockwise1xTILESIZE_scaling = r""" -@triton.jit -def blockwise1xTILESIZE_scaling( - pid, - scale, - ki, - lhs_size, - lhs_blocks, - k_blocks, - BLOCK_lhs: tl.constexpr, - BLOCK_K: tl.constexpr, - MIN_BLOCK_TILE_K: tl.constexpr, - TILE_SIZE: tl.constexpr, -): - row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs) - col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) - ptrs = scale + row_offs_scale[:, None] * k_blocks + 
col_offs_scale[None, :] - mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks) - scale_block = tl.load(ptrs, mask=mask, other=1.0) - - scale_expanded = scale_block[:, :, None] - scale_expanded = tl.broadcast_to( - scale_expanded, - (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K) - ) - scale_expanded = scale_expanded.reshape( - BLOCK_lhs, - ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K - ) - - return scale_expanded -""" - -blockwise128x128_scaling = r""" -@triton.jit -def blockwise128x128_scaling( - pid, - scale, - ki, - lhs_blocks, - k_blocks, - BLOCK_lhs: tl.constexpr, - BLOCK_K: tl.constexpr, - MIN_BLOCK_TILE_lhs: tl.constexpr, - MIN_BLOCK_TILE_K: tl.constexpr, -): - row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128) - col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128) - ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] - mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks) - scale_block = tl.load(ptrs, mask=mask, other=1.0) - - scale_expanded = scale_block[:, :, None, None] - scale_expanded = tl.broadcast_to( - scale_expanded, - ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K) - ) - scale_expanded = scale_expanded.reshape( - ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs, - ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K - ) - - return scale_expanded -""" - -scaled_mm_device_tma_main_loop_scaling = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_bn = {{stride("B", 1)}} - - start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K], - strides=[stride_am, 1], - block_shape=[BLOCK_M, BLOCK_K], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[N, K], - strides=[stride_bn, 1], - block_shape=[BLOCK_N, BLOCK_K], - ) - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) - b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) - b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) - - am_blocks = tl.cdiv(M, TILE_SIZE_A) - ak_blocks = tl.cdiv(K, TILE_SIZE_A) - bn_blocks = tl.cdiv(N, TILE_SIZE_B) - bk_blocks = tl.cdiv(K, TILE_SIZE_B) - - {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 - scale_a_block = 
blockwise128x128_scaling( - pid_m, - a_scale, - ki, - am_blocks, - ak_blocks, - BLOCK_M, - BLOCK_K, - MIN_BLOCK_TILE_AM, - MIN_BLOCK_TILE_AK, - ) - {%- else %} # ScalingType.Blockwise1xTILESIZE - scale_a_block = blockwise1xTILESIZE_scaling( - pid_m, - a_scale, - ki, - M, - am_blocks, - ak_blocks, - BLOCK_M, - BLOCK_K, - MIN_BLOCK_TILE_AK, - TILE_SIZE_A, - ) - {%- endif %} - - {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 - scale_b_block = blockwise128x128_scaling( - pid_n, - b_scale, - ki, - bn_blocks, - bk_blocks, - BLOCK_N, - BLOCK_K, - MIN_BLOCK_TILE_BN, - MIN_BLOCK_TILE_BK, - ) - {%- else %} # ScalingType.Blockwise1xTILESIZE - scale_b_block = blockwise1xTILESIZE_scaling( - pid_n, - b_scale, - ki, - N, - bn_blocks, - bk_blocks, - BLOCK_N, - BLOCK_K, - MIN_BLOCK_TILE_BK, - TILE_SIZE_B, - ) - {%- endif %} - - a_scaled = a * scale_a_block - b_scaled = b * scale_b_block - accumulator = tl.dot(a_scaled, b_scaled.T, accumulator) - - if ki == k_tiles - 1: - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - - # inductor generates a suffix - {{store_output( - ("offs_am", "offs_bn"), - "accumulator", - indent_width=12, - val_shape=("BLOCK_M", "BLOCK_N"), - block_indexing=True, - )}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" scaled_mm_device_tma_main_loop_scaling_template = TritonTemplate( name="scaled_mm_device_tma_main_loop_scaling", grid=persistent_mm_grid, - source=scaled_mm_device_tma_main_loop_scaling - + load_scales - + blockwise1xTILESIZE_scaling - + blockwise128x128_scaling, + source=load_kernel_template("triton_main_loop_scaled_mm"), ) -_compute_blackwell_pid = r""" -@triton.jit -def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr): - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - GROUP_M = min(grid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % GROUP_M) - pid_n = (tile_id % num_pid_in_group) // GROUP_M - return pid_m, pid_n -""" - -_blackwell_ws_persistent_device_tma = r""" -{{def_kernel("A", "B")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - start_pid = tl.program_id(0) - grid_m = tl.cdiv(M, BLOCK_M) - grid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = grid_m * grid_n - - # Note: We require TMA_EXPERIMENTAL_API == False, which - # we will check before invoking this template. 
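Stepping back from the removed inline sources in this hunk: the mm templates now come from load_kernel_template, whose implementation is imported from mm_common and not shown in this diff. A plausible sketch of such a helper (an assumption, not the real code) is simply reading templates/<name>.py.jinja next to the kernel module:

```python
import functools
import pathlib

# Hypothetical sketch only; the real helper lives in mm_common.
@functools.lru_cache(maxsize=None)
def load_kernel_template(name: str) -> str:
    template_dir = pathlib.Path(__file__).resolve().parent / "templates"
    return (template_dir / f"{name}.py.jinja").read_text()
```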
- stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - a_desc = triton.language.make_tensor_descriptor( - base=A, - shape=[M, K] if A_ROW_MAJOR else [K, M], - strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], - block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], - ) - b_desc = triton.language.make_tensor_descriptor( - base=B, - shape=[K, N] if B_ROW_MAJOR else [N, K], - strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], - block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], - ) - - # tile_id_c is used in the epilogue to break the dependency between - # the prologue and the epilogue - tile_id_c = start_pid - NUM_SMS - num_pid_in_group = GROUP_M * grid_n - - for tile_id in tl.range( - start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE - ): - pid_m, pid_n = _compute_pid( - tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS - ) - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for ki in range(k_tiles): - offs_k = ki * BLOCK_K - a = tl.load_tensor_descriptor( - a_desc, - [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], - ) - b = tl.load_tensor_descriptor( - b_desc, - [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], - ) - accumulator += tl.dot( - a if A_ROW_MAJOR else a.T, - b if B_ROW_MAJOR else b.T, - allow_tf32=ALLOW_TF32, - ) - - tile_id_c += NUM_SMS - pid_m, pid_n = _compute_pid( - tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS - ) - offs_cm = pid_m * BLOCK_M - offs_cn = pid_n * BLOCK_N - {%- if EPILOGUE_SUBTILE %} - tl.static_assert(BLOCK_N % 2 == 0) - acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2)) - acc = tl.permute(acc, (0, 2, 1)) - acc0, acc1 = tl.split(acc) - {{store_output( - ("offs_cm", "offs_cn"), - "acc0", - indent_width=8, - val_shape=("BLOCK_M", "BLOCK_N // 2"), - block_indexing=True - )}} - offs_cn2 = offs_cn + BLOCK_N // 2 - {{store_output( - ("offs_cm", "offs_cn2"), - "acc1", - indent_width=8, - val_shape=("BLOCK_M", "BLOCK_N // 2"), - block_indexing=True - )}} - {%- else %} - {{store_output( - ("offs_cm", "offs_cn"), - "accumulator", - indent_width=8, - val_shape=("BLOCK_M", "BLOCK_N"), - block_indexing=True - )}} - {%- endif %} -""" - blackwell_ws_persistent_device_tma_mm_template = TritonTemplate( name="blackwell_ws_persistent_device_tma", grid=persistent_mm_grid, - source=_blackwell_ws_persistent_device_tma + _compute_blackwell_pid, + source=load_kernel_template("triton_blackwell_ws_persistent_device_tma_mm"), ) diff --git a/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja b/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja new file mode 100644 index 0000000000000..34ff2d69793c0 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja @@ -0,0 +1,107 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + + # Note: We require TMA_EXPERIMENTAL_API == False, which + # we will check before invoking this template. 
+ stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + + # tile_id_c is used in the epilogue to break the dependency between + # the prologue and the epilogue + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = GROUP_M * grid_n + + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=WARP_SPECIALIZE + ): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_K + a = tl.load_tensor_descriptor( + a_desc, + [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], + ) + b = tl.load_tensor_descriptor( + b_desc, + [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], + ) + accumulator += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid( + tile_id_c, num_pid_in_group, grid_m, GROUP_M, NUM_SMS + ) + offs_cm = pid_m * BLOCK_M + offs_cn = pid_n * BLOCK_N + {%- if EPILOGUE_SUBTILE %} + tl.static_assert(BLOCK_N % 2 == 0) + acc = tl.reshape(accumulator, (BLOCK_M, 2, BLOCK_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + {{store_output( + ("offs_cm", "offs_cn"), + "acc0", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + offs_cn2 = offs_cn + BLOCK_N // 2 + {{store_output( + ("offs_cm", "offs_cn2"), + "acc1", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N // 2"), + block_indexing=True + )}} + {%- else %} + {{store_output( + ("offs_cm", "offs_cn"), + "accumulator", + indent_width=8, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True + )}} + {%- endif %} + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, grid_m, GROUP_M: tl.constexpr, NUM_SMS: tl.constexpr): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + GROUP_M = min(grid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % GROUP_M) + pid_n = (tile_id % num_pid_in_group) // GROUP_M + return pid_m, pid_n diff --git a/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja b/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja new file mode 100644 index 0000000000000..56ef18b7a91e3 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja @@ -0,0 +1,194 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALE_RECIPE_A == 1: # ScalingType.RowWise + stride_a_scale_m = 1 + else: + stride_a_scale_m = 0 + + if SCALE_RECIPE_B == 1: # ScalingType.RowWise + stride_b_scale_n = 1 + else: + stride_b_scale_n 
= 0 + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A, + SCALE_RECIPE_B, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + {%- endif %} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, 
SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALE_RECIPE_A: tl.constexpr, + SCALE_RECIPE_B: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALE_RECIPE_A == 1 and SCALE_RECIPE_B == 1: # (ScalingType.RowWise, ScalingType.RowWise) + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: # (ScalingType.TensorWise, ScalingType.TensorWise) + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale diff --git a/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja b/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja new file mode 100644 index 0000000000000..171340a2c9233 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja @@ -0,0 +1,212 @@ +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_bn = {{stride("B", 1)}} + + start_pid = tl.program_id(axis=0).to(INDEX_DTYPE) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[stride_am, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[stride_bn, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale = load_scales(A_inverse_scale, SCALE_RECIPE_A) + b_scale = load_scales(B_inverse_scale, SCALE_RECIPE_B) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + + am_blocks = tl.cdiv(M, TILE_SIZE_A) + ak_blocks = tl.cdiv(K, TILE_SIZE_A) + bn_blocks = tl.cdiv(N, TILE_SIZE_B) + bk_blocks = tl.cdiv(K, TILE_SIZE_B) + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_a_block = blockwise128x128_scaling( + pid_m, + a_scale, + ki, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AM, + MIN_BLOCK_TILE_AK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_a_block = 
blockwise1xTILESIZE_scaling( + pid_m, + a_scale, + ki, + M, + am_blocks, + ak_blocks, + BLOCK_M, + BLOCK_K, + MIN_BLOCK_TILE_AK, + TILE_SIZE_A, + ) + {%- endif %} + + {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + scale_b_block = blockwise128x128_scaling( + pid_n, + b_scale, + ki, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BN, + MIN_BLOCK_TILE_BK, + ) + {%- else %} # ScalingType.Blockwise1xTILESIZE + scale_b_block = blockwise1xTILESIZE_scaling( + pid_n, + b_scale, + ki, + N, + bn_blocks, + bk_blocks, + BLOCK_N, + BLOCK_K, + MIN_BLOCK_TILE_BK, + TILE_SIZE_B, + ) + {%- endif %} + + a_scaled = a * scale_a_block + b_scaled = b * scale_b_block + accumulator = tl.dot(a_scaled, b_scaled.T, accumulator) + + if ki == k_tiles - 1: + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + + # inductor generates a suffix + {{store_output( + ("offs_am", "offs_bn"), + "accumulator", + indent_width=12, + val_shape=("BLOCK_M", "BLOCK_N"), + block_indexing=True, + )}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@triton.jit +def load_scales(scale_ptr, SCALE_RECIPE: tl.constexpr): + if SCALE_RECIPE == 0: + return tl.load(scale_ptr) # For tensor-wise scaling, we'll load the scalar values + else: + return scale_ptr # For all other scaling recipes, we'll return the pointers + + +@triton.jit +def blockwise1xTILESIZE_scaling( + pid, + scale, + ki, + lhs_size, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, + TILE_SIZE: tl.constexpr, +): + row_offs_scale = pid * BLOCK_lhs + tl.arange(0, BLOCK_lhs) + col_offs_scale = ki * tl.cdiv(BLOCK_K, TILE_SIZE) + tl.arange(0, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_size) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + (BLOCK_lhs, (BLOCK_K + TILE_SIZE - 1) // TILE_SIZE, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + BLOCK_lhs, + ((BLOCK_K + TILE_SIZE - 1) // TILE_SIZE) * MIN_BLOCK_TILE_K + ) + + return scale_expanded + + +@triton.jit +def blockwise128x128_scaling( + pid, + scale, + ki, + lhs_blocks, + k_blocks, + BLOCK_lhs: tl.constexpr, + BLOCK_K: tl.constexpr, + MIN_BLOCK_TILE_lhs: tl.constexpr, + MIN_BLOCK_TILE_K: tl.constexpr, +): + row_offs_scale = pid * tl.cdiv(BLOCK_lhs, 128) + tl.arange(0, (BLOCK_lhs + 128 - 1) // 128) + col_offs_scale = ki * tl.cdiv(BLOCK_K, 128) + tl.arange(0, (BLOCK_K + 128 - 1) // 128) + ptrs = scale + row_offs_scale[:, None] * k_blocks + col_offs_scale[None, :] + mask = (row_offs_scale[:, None] < lhs_blocks) & (col_offs_scale[None, :] < k_blocks) + scale_block = tl.load(ptrs, mask=mask, other=1.0) + + scale_expanded = scale_block[:, :, None, None] + scale_expanded = tl.broadcast_to( + scale_expanded, + ((BLOCK_lhs + 128 - 1) // 128, (BLOCK_K + 128 - 1) // 128, MIN_BLOCK_TILE_lhs, MIN_BLOCK_TILE_K) + ) + scale_expanded = scale_expanded.reshape( + ((BLOCK_lhs + 128 - 1) // 128) * MIN_BLOCK_TILE_lhs, + ((BLOCK_K + 128 - 1) // 128) * MIN_BLOCK_TILE_K + ) + + return scale_expanded diff --git a/torch/_inductor/kernel/templates/triton_mm.py.jinja b/torch/_inductor/kernel/templates/triton_mm.py.jinja new file mode 100644 index 0000000000000..2da348f3e767c --- /dev/null +++ 
b/torch/_inductor/kernel/templates/triton_mm.py.jinja @@ -0,0 +1,72 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja b/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja new file mode 100644 index 0000000000000..42b99c70d5cbd --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja @@ -0,0 +1,71 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid 
% group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja b/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja new file mode 100644 index 0000000000000..38fe092c25780 --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja @@ -0,0 +1,129 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + {%- if TMA_EXPERIMENTAL_API %} + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + 
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + {%- else %} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[stride_am, 1] if A_ROW_MAJOR else [stride_ak, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[stride_bk, 1] if B_ROW_MAJOR else [stride_bn, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # inductor generates a suffix + {%- if TMA_EXPERIMENTAL_API %} + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"))}} + {%- else %} + {{store_output(("rm", "rn"), "acc", indent_width=12, val_shape=("BLOCK_M", "BLOCK_N"), block_indexing=True)}} + {%- endif %} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d374be59c9446..d9890f1958edd 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -3918,15 +3918,6 @@ def _unsafe_index_put_(self, indices, values, accumulate=False): def index_put_impl_(self, indices, values, accumulate, check, may_realize=False): if may_realize: - def try_get_name(x): - if isinstance(x, ir.TensorBox): - x = x.data - if isinstance(x, ir.BaseView): - x = x.unwrap_view() - if isinstance(x, ir.StorageBox): - x = x.data - return x.get_name() if isinstance(x, ir.Buffer) else None - def indice_slice_from_randperm(indice): # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660 # For this specific pattern, indices is unique as coming from torch.randperm. 
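
# Editorial sketch (not part of this patch): host-side model of how the
# persistent TMA templates above hand out output tiles. Each of the NUM_SMS
# programs owns every NUM_SMS-th tile and runs one flattened loop of length
# k_tiles * tiles_per_SM, cycling ki through the K dimension and advancing
# tile_id whenever ki wraps back to 0. Example sizes are made up.
def persistent_schedule(num_tiles, k_tiles, NUM_SMS, start_pid):
    tiles_per_sm = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_sm += 1
    tile_id, ki = start_pid - NUM_SMS, -1
    work = []  # (tile_id, ki) pairs this program would execute, in order
    for _ in range(k_tiles * tiles_per_sm):
        ki = 0 if ki == k_tiles - 1 else ki + 1
        if ki == 0:
            tile_id += NUM_SMS  # move on to the next tile owned by this program
        work.append((tile_id, ki))
    return work

# Program 0 of 4, with 10 output tiles and 3 K steps, handles tiles 0, 4, 8:
# [(0, 0), (0, 1), (0, 2), (4, 0), (4, 1), (4, 2), (8, 0), (8, 1), (8, 2)]
print(persistent_schedule(num_tiles=10, k_tiles=3, NUM_SMS=4, start_pid=0))
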
@@ -3941,7 +3932,7 @@ def indice_slice_from_randperm(indice): ) return False - if try_get_name(self) in values.get_read_names() and not all( + if ir.try_get_name(self) in values.get_read_names() and not all( indice_slice_from_randperm(indice) for indice in indices ): # Fix issue: https://github.com/pytorch/pytorch/issues/138908 @@ -6370,7 +6361,7 @@ def pow_native(a, b): @register_lowering(aten.pow, broadcast=True) def pow(a, b): - if isinstance(b, float) and b == int(b): + if isinstance(b, float) and math.isfinite(b) and b == int(b): return pow(a, int(b)) elif isinstance(b, float) and b == 0.5: return sqrt(a) diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 169e105d10b03..b4e66378e85ae 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -47,6 +47,12 @@ def next_power_of_2(n: int) -> int: return n +def last_power_of_2(n: int) -> int: + """Return the largest power of 2 less than or equal to n""" + next_pow2 = next_power_of_2(n) + return next_pow2 // 2 if next_pow2 > n else next_pow2 + + def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: """ Return the total number of bytes the arguments of tensor type takes. diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d5851eeceeb24..175bf76bfc740 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2441,6 +2441,7 @@ def triton_config_reduction( waves_per_eu=None, dynamic_scale_rblock=True, reduction_hint=None, + min_num_warps=None, ) -> Config: """ Construct a reduction triton config with some adjustment heuristics @@ -2476,7 +2477,12 @@ def total_numel() -> int: num_warps = total_numel() // 128 max_num_warps = 16 if r <= 8192 else 32 - num_warps = _num_warps( + if min_num_warps is not None: + _num_warps_func = functools.partial(_num_warps, min_num_warps=min_num_warps) + else: + _num_warps_func = _num_warps + + num_warps = _num_warps_func( num_warps, max_num_warps=max_num_warps, register_intensive=register_intensive ) @@ -3291,9 +3297,6 @@ def _persistent_reduction_configs( ): xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) - loads_and_stores = inductor_meta.get("num_load", 0) + inductor_meta.get( - "num_store", 0 - ) MAX_PERSISTENT_BLOCK_NUMEL = 4096 @@ -3364,12 +3367,11 @@ def _persistent_reduction_configs( # TODO(jansel): we should be able to improve these heuristics elif not max_autotune_enabled: # Do not filter configs when tuning if reduction_hint == ReductionHint.INNER and rnumel >= 256: - if rnumel > 1024: + if rnumel > 1024 or xnumel // 8 < 128 or inductor_meta.get("RSPLIT_SIZE"): configs = configs[:1] else: - x_block = 8 - if xnumel // x_block < 128 or loads_and_stores >= 5: - x_block = 1 + num_warps, min_num_warps = 1, 1 + x_block = min(1024 // rnumel, 8) configs = [ triton_config_reduction( @@ -3377,6 +3379,9 @@ def _persistent_reduction_configs( x_block, rnumel, register_intensive=True, + num_warps=num_warps, + min_num_warps=min_num_warps, + reduction_hint=reduction_hint, ) ] @@ -3426,21 +3431,23 @@ def persistent_reduction( if inductor_meta.get("RSPLIT_SIZE"): new_configs = [] + rsplit_size = inductor_meta.get("RSPLIT_SIZE") + rnumel_hint = size_hints["r0_"] + min_x_block = 1 + if rnumel_hint <= 512: + min_x_block = 4 + x_block = min(max(rsplit_size // 32, min_x_block), 16) for c in configs: - c.kwargs["RSPLIT_SIZE"] = inductor_meta.get("RSPLIT_SIZE") - - 
c.kwargs["NUM_STAGES"] = 1 - + c.kwargs["RSPLIT_SIZE"] = rsplit_size # small XBLOCK to use less registers/smem - c.kwargs["XBLOCK"] = ( - torch._inductor.config.triton.mix_order_reduction_initial_xblock - ) + c.kwargs["XBLOCK"] = x_block - rnumel_hint = size_hints["r0_"] + num_iters = rsplit_size // x_block + c.kwargs["NUM_STAGES"] = min(max(num_iters // 4, 1), 3) if rnumel_hint <= 1024: c.num_warps //= 2 - c.num_warps = max(c.num_warps, 2) + c.num_warps = max(c.num_warps, 1) new_configs.append(c) # less warps so potentially each sm can run more thread blocks @@ -3621,13 +3628,24 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + + # Naive autotuning path for num_warps + if not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index b7f36aa306a43..e5bd34ea977e7 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -253,19 +253,23 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: # small workload. When a workload is small enough, data can be # fully cached by L2 size_thres = 5 * 2**20 - if not V.graph.sizevars.statically_known_geq(nrow * ncol, size_thres): + + # Call evaluate_expr rather than statically_known_geq since nrow can + # have dynamic shape in real models. + # Don't use hint directly since hint can be non-representative. + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow * ncol, size_thres)): return False # We require more more row than columns since # 1, we prefer doing persistent reduction for each row # 2, we will split the reduction across the rows - if not V.graph.sizevars.statically_known_geq(nrow, ncol * 2): + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow, ncol * 2)): return False # When nrow is small, ncol should also be small (due to the check # above). Thus the entire tensor should be well cached in L2. # Mix order reduction is less beneficial. - if not V.graph.sizevars.statically_known_geq(nrow, 4096): + if not V.graph.sizevars.evaluate_expr(sympy.Ge(nrow, 4096)): return False contiguous_node, other_node = ( @@ -301,6 +305,8 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: return False # rnumel so large that we will not generated persistent reduction + # We don't see real use cases with dynamic ncol. But if we do, + # we should call evaluete_expr here which adds guards. 
if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16): return False @@ -2714,12 +2720,22 @@ def _init(self, nodes: list[ir.Operation]) -> None: if ( used_non_deterministic_runtime_estimations() and config_comms.runtime_estimations_align_across_all_distributed_ranks - ): - from .comms import ( - align_runtime_estimations_across_all_distributed_ranks, + and ( + config.runtime_estimations_mms_benchmark + or config_comms.runtime_estimations_use_nccl_lib_estimations ) + ): + has_collectives = False + for node in self.nodes: + if is_collective(node.node): + has_collectives = True + break + if has_collectives: + from .comms import ( + align_runtime_estimations_across_all_distributed_ranks, + ) - align_runtime_estimations_across_all_distributed_ranks(self.nodes) + align_runtime_estimations_across_all_distributed_ranks(self.nodes) from torch._logging import trace_structured @@ -2742,8 +2758,11 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.process_grouped_nodes() if ( + # pyrefly: ignore[unbound-name] config.graph_partition + # pyrefly: ignore[unbound-name] and config.triton.cudagraphs + # pyrefly: ignore[unbound-name] and config.triton.reorder_for_reducing_graph_partitions ): self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes) @@ -2755,6 +2774,7 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.insert_memory_check_nodes() log_ir_post_fusion(self.nodes) + # pyrefly: ignore[unbound-name] V.debug.graph_diagram(self.nodes) self.debug_draw_graph() diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index d6893b07ee3d9..493ca1179fad8 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -777,7 +777,7 @@ def stride(self, name, index=None): val = self.output_node.get_stride() else: assert isinstance(name, str) - val = self.named_input_nodes[name].get_stride() + val = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name]) if isinstance(index, int): return texpr(self.rename_indexing(val[index])) @@ -955,7 +955,6 @@ def load_input( self.template_mask = mask if mask is not None else "None" self.template_out_shape = index_shape if index_shape else "xindex" self.template_indices = indices - self.named_input_nodes[input_name].data.freeze_layout() self.cse.invalidate(OrderedSet()) template_mask = self.template_mask @@ -1412,7 +1411,7 @@ def make_load(self, name, indices, mask): assert isinstance(indices, (list, tuple)) assert isinstance(name, str) assert isinstance(mask, str) - stride = self.named_input_nodes[name].get_stride() + stride = self.get_stride_and_maybe_freeze_layout(self.named_input_nodes[name]) indices = list(map(OpOverrides.paren, indices)) assert len(indices) == len(stride) index = " + ".join( @@ -1502,6 +1501,10 @@ def kernel_benchmark_extra_args(self) -> list[str]: ) ] + def get_stride_and_maybe_freeze_layout(self, node) -> list[int]: + node.data.freeze_layout() + return node.get_stride() + @functools.cache def _jinja2_env(): diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index e0a87ebac3d87..aa1b4d2db025d 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -121,7 +121,6 @@ def call_function( raise SubgraphLoweringException( f"{target} not supported in subgraph, (missing lowering)" ) - return lowerings[target](*args, **kwargs) def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override] diff --git a/torch/_meta_registrations.py 
b/torch/_meta_registrations.py index 5a629b371c766..2ed88a4ec2344 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3741,6 +3741,7 @@ def kai_roundup(a: int, b: int) -> int: def get_kai_packed_weight_size(n_bits, N, K, groupsize): if n_bits == 4: + # Works for both fp32 and bf16 Kernels if groupsize == K: # channelwise # dotprod params only [1x8x32_neon_dotprod] kai_nr = 8 @@ -3870,6 +3871,8 @@ def meta__dyn_quant_pack_4bit_weight( ) return weights.new_empty(int(packed_weight_size), dtype=torch.uint8) packed_weight_size = weights.numel() + scales_zeros.numel() + if bias is not None: + packed_weight_size += bias.numel() return weights.new_empty(packed_weight_size, dtype=torch.float) @@ -3883,8 +3886,12 @@ def meta__dyn_quant_matmul_4bit( ): torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor") torch._check( - inp.dtype == torch.float32, - lambda: f"expected input to be f32, got {inp.dtype}", + (inp.dtype == torch.float32) + or (inp.dtype == torch.bfloat16 and block_size == in_features), + lambda: ( + f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), " + f"got {inp.dtype} with block_size={block_size} and in_features={in_features}" + ), ) M = inp.size(0) return inp.new_empty(M, out_features, dtype=inp.dtype) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 9ba46e8c5310c..019e9c59f2423 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1895,8 +1895,6 @@ def set_correction( # NB: we don't actually support symint here, but it's harmless to accept if not isinstance(correction, (IntLike, FloatLike)): raise ValueError("correction argument should be integer or float") - if correction < 0: - raise ValueError("correction argument should be non-negative") return sym_float(correction) diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index f4281674bd118..4d194f773f859 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -269,12 +269,24 @@ def matrix_norm( max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + def _max_min_wrapper(A, dim): + # pyrefly: ignore [unsupported-operation] + if A.size(dim) == 0 and ord > 0.0: + new_size = list(A.size()) + if keepdim: + new_size[dim] = 1 + else: + del new_size[dim] + return torch.zeros(new_size, dtype=A.dtype, device=A.device) + else: + return max_min(A, dim) + if abs_ord == 2.0: if dtype is not None: A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] # pyrefly: ignore [index-error] perm = _backshift_permutation(dim[0], dim[1], A.ndim) - result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + result = _max_min_wrapper(svdvals(prims.transpose(A, perm)), dim=-1) if keepdim: inv_perm = _inverse_permutation(perm) result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) @@ -286,7 +298,7 @@ def matrix_norm( dim0, dim1 = dim1, dim0 if not keepdim and (dim0 < dim1): dim1 -= 1 - return max_min( + return _max_min_wrapper( vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 ) diff --git a/torch/_subclasses/complex_tensor/__init__.py b/torch/_subclasses/complex_tensor/__init__.py new file mode 100644 index 0000000000000..1ab4a816261dc --- /dev/null +++ b/torch/_subclasses/complex_tensor/__init__.py @@ -0,0 +1,9 @@ +from ._core import ComplexTensor +from ._ops import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"] + +ComplexTensor.__module__ = __name__ 
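
# Editorial sketch (not part of the new files below): the real/imaginary
# identities the ComplexTensor ops decompose into, written for plain floats.
# Multiplication is the standard (a+bi)(c+di) = (ac - bd) + (ad + bc)i;
# division scales by the conjugate of the denominator, matching div_impl.
def cmul(a, b):
    ar, ai = a
    br, bi = b
    return (ar * br - ai * bi, ar * bi + ai * br)

def cdiv(a, b):
    ar, ai = a
    br, bi = b
    den = br * br + bi * bi
    return ((ar * br + ai * bi) / den, (ai * br - ar * bi) / den)

assert cmul((1.0, 2.0), (3.0, 4.0)) == (-5.0, 10.0)   # (1+2j)*(3+4j) = -5+10j
re, im = cdiv((1.0, 2.0), (3.0, 4.0))                 # (1+2j)/(3+4j) = 0.44+0.08j
assert abs(re - 0.44) < 1e-12 and abs(im - 0.08) < 1e-12
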
+ComplexTensorMode.__module__ = __name__ +is_complex_tensor.__module__ = __name__ diff --git a/torch/_subclasses/complex_tensor/_core.py b/torch/_subclasses/complex_tensor/_core.py new file mode 100644 index 0000000000000..edd7568b2ef06 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_core.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +from torch import Tensor +from torch.autograd import Function + + +if TYPE_CHECKING: + from torch._ops import OpOverload + from torch._prims_common import DeviceLikeType + from torch.autograd.function import FunctionCtx + + +class ComplexTensor(Tensor): + """A class that decomposes all ops on complex Tensors into their real and imaginary parts.""" + + _re: Tensor + _im: Tensor + + def __new__(cls, real: Tensor, imag: Tensor) -> Self: + """Initialize a ComplexTensor from its real and imaginary parts.""" + from ._ops.common import REAL_TO_COMPLEX + + shape = real.shape + device = real.device + + # TODO (hameerabbasi): `torch.compile` sometimes fails here without making these + # contiguous. Why? + real = real.contiguous() + imag = imag.contiguous() + + # TODO (hameerabbasi): + # What should we do with dtype? + # We could convert to the complex type (float32 -> complex64), but we + # can't use that model for say `bfloat16` which does not have a + # corresponding complex dtype. + # If we want to support this complex rep using any float type (see + # https://github.com/pytorch/pytorch/issues/95100) + # We either need to: + # 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere + # else. + # 2) We use the real float dtype here, and it is up to the user to know + # that dtype=float here really means complex<2xSize> with dtype + # matching that of re/im parts alone + # I'm going with 1 for now, so that I can make gradcheck and some complex + # ops work properly, but might want to discuss this in the RFP. + dtype = REAL_TO_COMPLEX.get(real.dtype) + if dtype is None: + raise TypeError( + "Unsupported dtype for constituent tensors. Supported dtypes are: " + f"{set(REAL_TO_COMPLEX.keys())!r}." 
+ ) + storage_offset = real.storage_offset() + strides = real.stride() + layout = real.layout + pin_memory = real.is_pinned() + + assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}" + assert device == imag.device, ( + f"Expected imag device {device}, got {imag.device}" + ) + assert real.dtype == imag.dtype, ( + f"Expected imag dtype {real.dtype}, got {imag.dtype}" + ) + assert pin_memory == imag.is_pinned(), ( + f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}" + ) + + res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + shape, + device=device, + dtype=dtype, + storage_offset=storage_offset, + strides=strides, + pin_memory=pin_memory, + layout=layout, + requires_grad=False, + ) + res._re = real.clone().detach() + res._im = imag.clone().detach() + + return res + + @property + def re(self) -> Tensor: + return self._re + + @property + def im(self) -> Tensor: + return self._im + + @classmethod + def __torch_dispatch__( + cls, + func: OpOverload, + types: tuple[type, ...], + args: tuple = (), + kwargs: dict | None = None, + ): + from ._ops.common import lookup_complex + + kwargs = {} if kwargs is None else kwargs + + impl = lookup_complex(func, *args, **kwargs) + if impl is None: + return NotImplemented + + return impl(*args, **kwargs) + + @staticmethod + def from_interleaved(t: Tensor) -> ComplexTensor: + t_real = torch.real(t) + t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real) + return Complex.apply(t_real, t_imag) + + def as_interleaved(self) -> Tensor: + return torch.complex(self.real, self.imag) + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Tensor], + meta: Any, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + ) -> ComplexTensor: + assert meta is None + re, im = inner_tensors["re"], inner_tensors["im"] + return ComplexTensor(re, im) + + def __tensor_flatten__(self) -> tuple[list[str], Any]: + return ["re", "im"], None + + def __repr__(self, *, tensor_contents=None) -> str: + return f"ComplexTensor(real={self.re!r}, imag={self.im!r})" + + def is_pinned(self, device: DeviceLikeType | None = None) -> bool: + return self.re.is_pinned(device) + + +class Complex(Function): + @staticmethod + def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override] + return ComplexTensor(real, imag) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override] + return grad_output.real, grad_output.imag diff --git a/torch/_subclasses/complex_tensor/_ops/__init__.py b/torch/_subclasses/complex_tensor/_ops/__init__.py new file mode 100644 index 0000000000000..c07bdf6099b65 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/__init__.py @@ -0,0 +1,5 @@ +from . 
import aten, prims +from .common import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"] diff --git a/torch/_subclasses/complex_tensor/_ops/aten.py b/torch/_subclasses/complex_tensor/_ops/aten.py new file mode 100644 index 0000000000000..15e09c3b314f0 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/aten.py @@ -0,0 +1,921 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from .._core import ComplexTensor +from .common import ( + _get_func_name, + COMPLEX_TO_REAL, + complex_to_real_dtype, + is_complex, + OpType, + promote_tensors, + register_binary_nonlinear, + register_complex, + register_error, + register_force_test, + register_simple, + split_complex_arg, + split_complex_tensor, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import Any + +aten = torch.ops.aten + + +def register_binary_linear(op: OpType): + def impl_with_alpha( + lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs + ) -> ComplexTensor: + return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, **kwargs) + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs) + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + u = op(a_r, b_r, *args, **kwargs) + v = op(a_i, b_i, *args, **kwargs) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + return register_complex(op, impl) + + +@register_complex(aten.real) +def real_impl(self: ComplexTensor) -> torch.Tensor: + re, _ = split_complex_tensor(self) + return re + + +@register_complex(aten.imag) +def imag_impl(self: ComplexTensor) -> torch.Tensor: + _, im = split_complex_tensor(self) + return im + + +@register_complex(aten.is_pinned) +def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool: + return self.is_pinned(device) + + +SIMPLE_OPS_LIST = [ + aten.slice, + aten.flatten, + aten.view, + aten.diagonal, + aten.expand, + aten.unsqueeze, + aten.unsqueeze_, + aten.mean, + aten.sum, + aten.clone, + aten.neg, + aten.flip, + aten.permute, + aten.repeat, + aten.index_select, + aten.split, + aten.split_with_sizes, + aten.cumsum, + aten.detach, + aten.select, + aten.squeeze, + aten.zero_, + aten.transpose, + aten.t, + aten.gather, +] + +for simple_op in SIMPLE_OPS_LIST: + globals()[_get_func_name(simple_op)] = register_simple(simple_op) + +# TODO (hameerabbasi): Not being tested +SIMPLE_FORCE_TESTED_OPS = [ + aten.copy, + aten.col2im, + aten.alias, + aten.lift_fresh, + aten._unsafe_view, + aten.index, + aten._neg_view, + aten.avg_pool2d, + aten.avg_pool3d, + aten.avg_pool2d_backward, + aten.avg_pool3d_backward, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, + aten.embedding, +] + +for simple_op in SIMPLE_FORCE_TESTED_OPS: + globals()[_get_func_name(simple_op)] = register_force_test( + simple_op, register_simple(simple_op) + ) + +del simple_op + +# some binary ops which we can stamp out +mul_impl = register_binary_nonlinear(aten.mul) +mul__impl = register_binary_nonlinear(aten.mul_) +mm_impl = register_binary_nonlinear(aten.mm) +dot_impl = register_binary_nonlinear(aten.dot) +bmm_impl = register_binary_nonlinear(aten.bmm) + +# TODO (hameerabbasi): Not being tested +convolution_impl = register_force_test( + 
aten.convolution, register_binary_nonlinear(aten.convolution) +) + +slice_scatter_impl = register_force_test( + aten.slice_scatter, register_binary_linear(aten.slice_scatter) +) +select_scatter_impl = register_force_test( + aten.select_scatter, register_binary_linear(aten.select_scatter) +) + +add_impl = register_binary_linear(aten.add) +add__impl = register_binary_linear(aten.add_) +sub_impl = register_binary_linear(aten.sub) +sub__impl = register_binary_linear(aten.sub_) +diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter) +fill__impl = register_binary_linear(aten.fill_) + + +@register_complex(aten.rsub) +def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor: + if alpha is None: + return torch.sub(rhs, lhs) # type: ignore[bad-return] + return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return] + + +@register_complex(aten.div) +@register_complex(aten.true_divide) +def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None): + if rounding_mode is not None: + raise NotImplementedError( + "`rounding_mode` other than `None` not implemented for`ComplexTensor`." + ) + a_r, a_i = split_complex_tensor(lhs) + if not is_complex(rhs): + return ComplexTensor(a_r / rhs, a_i / rhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + num_r = a_r * b_r + a_i * b_i + num_i = a_i * b_r - a_r * b_i + den = b_r * b_r + b_i * b_i + return ComplexTensor( + (num_r / den).to(out_dt), + (num_i / den).to(out_dt), + ) + + +@register_complex(aten.reciprocal) +def reciprocal_impl(self: ComplexTensor): + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + den = self_r * self_r + self_i * self_i + return ComplexTensor( + aten.div(self_r, den).to(out_dt), + aten.div(-self_i, den).to(out_dt), + ) + + +# reductions +@register_complex(aten.prod) +def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + dtype = kwargs.pop("dtype", out_dt) + kwargs["dtype"] = complex_to_real_dtype(self.dtype) + + prod_r = torch.prod(torch.abs(self), *args, **kwargs) + sum_phi = torch.sum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return] + + +@register_complex(aten.pow) +def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor: + out_dt, (self, exponent) = promote_tensors(self, exponent) + return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.cumprod) +def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + dtype = kwargs.pop("dtype", self.dtype) + kwargs["dtype"] = complex_to_real_dtype(dtype) + + prod_r = torch.cumprod(torch.abs(self), *args, **kwargs) + sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v) + + +# unary funcs, +# most of these are simple or require some kind of identity +@register_complex(aten.abs) +def abs_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + result = torch.hypot(x, y) + return result.to(out_dt) + + +@register_complex(aten.angle) +def angle_impl(self: ComplexTensor) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.atan2(y, x) + + +@register_complex(aten.acos) +def 
acos_impl(self: ComplexTensor) -> ComplexTensor: + _, y = split_complex_tensor(self) + acosh_z = torch.acosh(self) + assert isinstance(acosh_z, ComplexTensor) + acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z) + sign_im = 2 * torch.signbit(y) - 1 + return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re)) + + +@register_complex(aten.asin) +def asin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + asinh_iz = torch.asinh(ComplexTensor(-y, x)) + assert isinstance(asinh_iz, ComplexTensor) + asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz) + return ComplexTensor(asinh_iz_im, -asinh_iz_re) + + +@register_complex(aten.atan) +def atan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.atanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.asinh) +def asinh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.acosh) +def acosh_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.atanh) +def atanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + ret = 0.5 * ( + torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y)) + ) + assert isinstance(ret, ComplexTensor) + ret_re, ret_im = split_complex_tensor(ret) + + return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt)) + + +@register_complex(aten.cos) +def cos_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + return torch.cosh(ComplexTensor(-y, x)) # type: ignore[bad-return] + + +@register_complex(aten.cosh) +def cosh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.cosh(x) * torch.cos(y) + v = torch.sinh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.sin) +def sin_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + sinh_iz = torch.sinh(ComplexTensor(-y, x)) + assert isinstance(sinh_iz, ComplexTensor) + sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz) + return ComplexTensor(sinh_iz_im, -sinh_iz_re) + + +@register_complex(aten.sinh) +def sinh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + u = torch.sinh(x) * torch.cos(y) + v = torch.cosh(x) * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.tan) +def tan_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + tanh_iz = torch.tanh(ComplexTensor(-y, x)) + assert isinstance(tanh_iz, ComplexTensor) + tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz) + return ComplexTensor(tanh_iz_im, -tanh_iz_re) + + +@register_complex(aten.tanh) +def tanh_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + + _2x = 2 * x + _2y = 2 * y + _d = torch.cosh(_2x) + torch.cos(_2y) + _2xsh = torch.sinh(_2x) + + out_re = _2xsh / _d + out_im = torch.sin(_2y) / _d + + 
return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt)) + + +@register_complex(aten.exp) +def exp_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + ex = torch.exp(x) + u = ex * torch.cos(y) + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.expm1) +def expm1_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + out_dt, (x, y) = promote_tensors(x, y) + # TODO (hameerabbasi): The two lines below may have numerical issues + ex = torch.exp(x) + u = ex * torch.cos(y) - 1 + v = ex * torch.sin(y) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + +@register_complex(aten.log) +def log_impl(self: ComplexTensor) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + re = torch.log(torch.abs(self)) + im = torch.angle(self) + return ComplexTensor(re, im).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.log1p) +def log1p_impl(self: ComplexTensor) -> ComplexTensor: + x, y = split_complex_tensor(self) + # TODO (hameerabbasi): The line below may have numerical issues + return torch.log(ComplexTensor(x + 1, y)) # type: ignore[bad-return] + + +@register_complex(aten.any) +def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs) + + +@register_complex(aten.all) +def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + x, y = split_complex_tensor(self) + return torch.any(x, *args, **kwargs) & torch.any(y, *args, **kwargs) + + +@register_complex(aten.eq) +def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_arg(self) + b_r, b_i = split_complex_arg(rhs) + return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.ne) +def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs) + + +@register_complex(aten.isnan) +def isnan_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isnan(re) | torch.isnan(im) + + +@register_complex(aten.isinf) +def isinf_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isinf(re) | torch.isinf(im) + + +@register_complex(aten.isfinite) +def isfinite_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.isfinite(re) & torch.isfinite(im) + + +@register_complex(aten.isclose) +def isclose_impl( + self: ComplexTensor, + rhs: ComplexTensor, + rtol=1e-5, + atol=1e-8, + equal_nan: bool = False, +) -> torch.Tensor: + abs_diff = torch.abs(self - rhs) + abs_other = torch.abs(rhs) + basic_condition = abs_diff <= (rtol * abs_other + atol) + + # This is the nontrivial part + if equal_nan: + a_r, a_i = split_complex_tensor(self) + b_r, b_i = split_complex_arg(rhs) + + a_r_nan = torch.isnan(a_r) + b_r_nan = torch.isnan(b_r) + a_i_nan = torch.isnan(a_i) + b_i_nan = torch.isnan(b_i) + a_nan = a_r_nan | a_i_nan + + # This logical expression makes sure that the isnan of both the real and imaginary parts + # matches (so 1 + nan*i doesn't equal nan + 1*i) + equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan + return basic_condition | equal_nan_condition + 
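
# Editorial sketch (not part of this patch): the closeness test isclose_impl
# above performs, on plain Python complex numbers. It checks
# |a - b| <= rtol * |b| + atol and, when equal_nan is set, additionally
# accepts pairs whose real/imag NaN pattern matches (and a has a NaN at all),
# mirroring the tensor logic above.
import math

def complex_isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    basic = abs(a - b) <= rtol * abs(b) + atol
    if not equal_nan:
        return basic
    nan_pattern_matches = (
        math.isnan(a.real) == math.isnan(b.real)
        and math.isnan(a.imag) == math.isnan(b.imag)
    )
    has_nan = math.isnan(a.real) or math.isnan(a.imag)
    return basic or (nan_pattern_matches and has_nan)

print(complex_isclose(1 + 2j, 1 + 2.0000001j))                            # True
nan = float("nan")
print(complex_isclose(complex(nan, 1), complex(nan, 1), equal_nan=True))  # True
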
+ return basic_condition + + +ERROR_OPS_LIST = [ + aten.lt, + aten.le, + aten.gt, + aten.ge, + aten.amin, + aten.amax, + aten.clamp, + aten.ceil, + aten.floor, + aten.minimum, + aten.maximum, + aten.trunc, + aten.sign, + aten.argmax, + aten.argmin, + aten.sort, + aten.topk, + aten.round, + aten.fmod, +] + + +ERROR_TYPES = { + aten.minimum: RuntimeError, + aten.maximum: RuntimeError, + aten.argmax: RuntimeError, + aten.argmin: RuntimeError, + aten.sort: RuntimeError, + aten.topk: RuntimeError, +} + + +for err_op in ERROR_OPS_LIST: + globals()[_get_func_name(err_op)] = register_error( + err_op, ERROR_TYPES.get(err_op, NotImplementedError) + ) + +del err_op + + +@register_complex(aten.masked_scatter) +def masked_scatter_impl( + self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor +) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + source_r, source_i = split_complex_arg(source) + ret_r = torch.masked_scatter(self_r, mask, source_r) + ret_i = torch.masked_scatter(self_i, mask, source_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.where) +def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor: + x_r, x_i = split_complex_arg(x) + y_r, y_i = split_complex_arg(y) + + ret_r = torch.where(mask, x_r, y_r) + ret_i = torch.where(mask, x_i, y_i) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.full_like) +def full_like_impl( + input: ComplexTensor, + fill_value: complex, + *args, + dtype: torch.dtype | None = None, + **kwargs, +) -> torch.Tensor | ComplexTensor: + # Note: Cannot be merged with the cases below due to the `fill_value` argument + input_r, input_i = split_complex_tensor(input) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + fv_r, fv_i = split_complex_arg(fill_value) + ret_r = torch.full_like(input_r, fv_r, *args, **kwargs) + ret_i = torch.full_like(input_i, fv_i, *args, **kwargs) + + return ComplexTensor(ret_r, ret_i) + + +def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]: + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> torch.Tensor | ComplexTensor: + self_re, self_im = split_complex_tensor(self) + + if dtype is not None and dtype not in COMPLEX_TO_REAL: + return op(self_re, *args, dtype=dtype, **kwargs) + + if dtype is not None: + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + ret_re = op(self_re, *args, **kwargs) + ret_im = op(self_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +LIKE_OPS_LIST = [ + aten.empty_like, + aten.zeros_like, + aten.randn_like, + aten.new_zeros, +] + +for like_op in LIKE_OPS_LIST: + globals()[_get_func_name(like_op)] = register_like(like_op) + +del like_op + + +@register_complex(aten.cat) +def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor: + tensors_r = [] + tensors_i = [] + + for t in tensors: + t_r, t_i = split_complex_arg(t) + tensors_r.append(t_r) + tensors_i.append(t_i) + + ret_r = torch.cat(tensors_r, dim=dim) + ret_i = torch.cat(tensors_i, dim=dim) + + return ComplexTensor(ret_r, ret_i) + + +@register_complex(aten.sgn) +def sgn_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = 
promote_tensors(self_r, self_i) + abs_self = torch.abs(ComplexTensor(self_r, self_i)) + mask = (self_r != 0) | (self_i != 0) + masked_sgn = ComplexTensor( + (self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt) + ) + return torch.where(mask, masked_sgn, 0) # type: ignore[bad-return] + + +@register_complex(aten.sqrt) +def sqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_sqrt = torch.sqrt(torch.abs(self)) + self_half_angle = 0.5 * torch.angle(self) + + ret_r = self_abs_sqrt * torch.cos(self_half_angle) + ret_i = self_abs_sqrt * torch.sin(self_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.rsqrt) +def rsqrt_impl(self: ComplexTensor) -> ComplexTensor: + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + self = ComplexTensor(self_r, self_i) + self_abs_rsqrt = torch.rsqrt(torch.abs(self)) + self_neg_half_angle = -0.5 * torch.angle(self) + + ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle) + ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle) + + return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt)) + + +@register_complex(aten.addmm) +def addmm_impl( + input: ComplexTensor, + mat1: ComplexTensor, + mat2: ComplexTensor, + out_dtype: torch.dtype | None = None, + beta: complex = 1, + alpha: complex = 1, +) -> ComplexTensor: + ret = beta * input + alpha * torch.mm(mat1, mat2) + assert isinstance(ret, ComplexTensor) + ret_r, ret_i = split_complex_tensor(ret) + if out_dtype is not None: + out_dtype = COMPLEX_TO_REAL[out_dtype] + ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype) + return ComplexTensor(ret_r, ret_i) + + +def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return (re != 0) | (im != 0) + + +def register_nonzero_impl(op: OpType): + def nonzero_impl( + self: ComplexTensor, other: ComplexTensor, *args, **kwargs + ) -> torch.Tensor: + return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs) + + func_name = _get_func_name(op) + nonzero_impl.__name__ = func_name + nonzero_impl.__qualname__ = func_name + + return register_complex(op, nonzero_impl) + + +logical_and_impl = register_nonzero_impl(aten.logical_and) +logical_or_impl = register_nonzero_impl(aten.logical_or) +logical_xor_impl = register_nonzero_impl(aten.logical_xor) + + +@register_complex(aten.logical_not) +def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.logical_not(elemwise_nonzero(self), *args, **kwargs) + + +@register_complex(aten.view_as_real) +def view_as_real_impl(self: ComplexTensor) -> torch.Tensor: + re, im = split_complex_tensor(self) + return torch.stack([re, im], dim=-1) + + +@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl(self: ComplexTensor, src, *args, **kwargs): + self_re, self_im = split_complex_tensor(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = 
split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + return torch.all( + torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + ).item() # type: ignore[bad-return] + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +@register_complex(aten.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, torch._neg_view(im)) + + +@register_complex(aten.index_add) +def index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl( + self: ComplexTensor, pad, value: complex | None = None +) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + if value is None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = 
aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) + + +@register_complex(aten.scatter_add) +def scatter_add_impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = torch.scatter_add(self_re, dim, index, src_re) + ret_im = torch.scatter_add(self_im, dim, index, src_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.scatter_add_) +def scatter_add__impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + out_re = self_re.scatter_add_(dim, index, src_re) + out_im = self_im.scatter_add_(dim, index, src_im) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.index_put_) +def index_put__impl( + self: ComplexTensor, + indices: tuple[torch.Tensor, ...], + values: ComplexTensor, + accumulate: bool = False, +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + values_re, values_im = split_complex_arg(values) + + out_re = self_re.index_put_(indices, values_re, accumulate=accumulate) + out_im = self_im.index_put_(indices, values_im, accumulate=accumulate) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.tanh_backward) +def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor): + return out_grad * (1.0 - y * y).conj_physical() + + +@register_complex(aten.diagonal_backward) +def diagonal_backward( + grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any: + if not isinstance(dt, torch.dtype): + return dt + + return COMPLEX_TO_REAL[dt] + + +def register_to_impl(op: OpType): + """Register an op similar to `aten.to`, but may have different signatures.""" + + def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: + x, y = split_complex_tensor(self) + try: + args = tuple(_dt_to_real(a) for a in args) + kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()} + except KeyError: + return op(x, *args, **kwargs) + + return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +to_impl = register_to_impl(aten.to) +_to_copy_impl = register_to_impl(aten._to_copy) diff --git a/torch/_subclasses/complex_tensor/_ops/common.py b/torch/_subclasses/complex_tensor/_ops/common.py new file mode 100644 index 0000000000000..88532efe224bb --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/common.py @@ -0,0 +1,317 @@ +from collections.abc import Callable +from typing import Any, overload, TypeAlias +from typing_extensions import TypeIs + +import torch +from torch import Tensor +from torch._decomp import get_decompositions +from torch._ops import OpOverload, OpOverloadPacket +from torch._refs import is_complex as _is_complex +from torch.types import Number +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, 
tree_map, tree_unflatten + +from .._core import ComplexTensor + + +OpType: TypeAlias = OpOverloadPacket | OpOverload + +TableType: TypeAlias = dict[OpType, Callable] + +# Mapping from ops to implementations +COMPLEX_OPS_TABLE: TableType = {} + +COMPLEX_TO_REAL = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()} + +# Used to promote dtypes in `promote_real_cpu_tensors` +PROMOTE_TYPES = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.complex32: torch.complex64, +} + + +def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]: + r"""Returns True if the input is a ComplexTensor, else False + + Args: + a: any input + + Examples: + + >>> # xdoctest: +SKIP + >>> from torch.complex import ComplexTensor + >>> data = torch.zeros((3, 2), dtype=torch.complex64) + >>> ct = ComplexTensor.from_interleaved(data) + >>> is_complex_tensor(ct) + True + """ + return isinstance(obj, ComplexTensor) + + +@overload +def promote_tensors( + *tensors: ComplexTensor, +) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ... + + +@overload +def promote_tensors( + *tensors: Tensor, +) -> tuple[torch.dtype, tuple[Tensor, ...]]: ... + + +def promote_tensors( + *tensors: Tensor | ComplexTensor, +) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]: + """ + Promotes all tensors to a common dtype. + Additionally promotes CPU tensors to at least `float32`. + """ + tensor = next(t for t in tensors if isinstance(t, Tensor)) + out_dt = tensor.dtype + for t in tensors: + if isinstance(t, Tensor): + out_dt = torch.promote_types(out_dt, t.dtype) + + prom_dt = PROMOTE_TYPES.get(out_dt, out_dt) + return out_dt, tuple( + t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt) + for t in tensors + ) + + +def register_complex( + op: OpType, + func_impl: Callable | None = None, +): + """Decorator to register an implementation for some ops in some dispatch tables""" + + def inner(func): + if COMPLEX_OPS_TABLE.get(op, func) is not func: + raise RuntimeError(f"Attempted to register multiple functions for {op}") + COMPLEX_OPS_TABLE[op] = func + return func + + if func_impl is None: + return inner + + return inner(func_impl) + + +FORCE_TEST_LIST: list[OpType] = [] + + +def register_force_test(op: OpType, *args, **kwargs): + """Will attempt to test these ops even if they err on "normal" inputs""" + FORCE_TEST_LIST.append(op) + return register_complex(op, *args, **kwargs) + + +DECOMPOSITIONS = get_decompositions(list(torch.ops.aten)) # type: ignore[no-matching-overload] + + +def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None: + """ + Lookup an impl from the table. + + Try the particular overload first, then the overload packet. + + If nothing is found, try the decompositions with both. + """ + return COMPLEX_OPS_TABLE.get( + func, + COMPLEX_OPS_TABLE.get( + func.overloadpacket, + DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)), + ), + ) + + +def is_complex(x: Any, /) -> bool: + """Utility to detect if a given object is (known) to be complex.""" + return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex) + + +@overload +def split_complex_arg( + arg: Tensor | ComplexTensor, +) -> tuple[Tensor, Tensor]: ... + + +@overload +def split_complex_arg( + arg: complex | Number, +) -> tuple[Number, Number]: ... 
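+# Illustrative sketch of the two overloads above (not part of the module): a
+# Tensor argument yields a pair of Tensors and a Python number yields a pair
+# of Numbers, with a zero imaginary part filled in for purely real inputs, e.g.
+#   split_complex_arg(torch.tensor(1.0 + 2.0j))  # -> (tensor(1.), tensor(2.))
+#   split_complex_arg(3.0)                       # -> (3.0, 0.0)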
+ + +def split_complex_arg( + arg: Tensor | ComplexTensor | complex | Number, +) -> tuple[Tensor, Tensor] | tuple[Number, Number]: + """ + Split a complex argument into a real/imaginary component. + + If real, use zero for the imaginary part. + """ + if isinstance(arg, ComplexTensor): + return split_complex_tensor(arg) + if isinstance(arg, Tensor): + if is_complex(arg): + return arg.real, arg.imag + return arg, torch.zeros_like(arg) + # TODO (hameerabbasi): Should there be a `torch.SymComplex`? + if isinstance(arg, complex): + return arg.real, arg.imag + if isinstance(arg, float | torch.SymFloat): + return arg, 0.0 + if isinstance(arg, int | torch.SymInt): + return arg, 0 + if isinstance(arg, bool | torch.SymBool): + return arg, False + raise TypeError(f"Expected tensor or number got, {type(arg)}") + + +def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]: + """Split a ComplexTensor into its real and imaginary parts.""" + return complex_tensor.re, complex_tensor.im + + +def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype: + """Convert a complex dtype to the dtype of its real part. Return other dtypes as-is.""" + return COMPLEX_TO_REAL.get(dtype, dtype) + + +def _get_op_name(op: OpType) -> str: + """Get the op name from the op.""" + if isinstance(op, OpOverload): + op = op.overloadpacket + return str(op).split(".", 1)[1] + + +def _get_func_name(op: OpType) -> str: + """Get the name of the implementation function from the op.""" + return f"{_get_op_name(op)}_impl" + + +def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError): + msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`." + + def ordered_impl(*args, **kwargs): + raise exc_type(msg) + + func_name = _get_func_name(op) + ordered_impl.__name__ = func_name + ordered_impl.__qualname__ = func_name + + return register_force_test(op, ordered_impl) + + +def register_binary_nonlinear(op: OpType) -> Callable: + """Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ...""" + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) + imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs) + return ComplexTensor(real.to(out_dt), imag.to(out_dt)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def register_simple(op: OpType): + """Register an op which can be applied independently to the real and complex parts to get the result.""" + + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: + x, y = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + raise RuntimeError( + "Non-complex `dtype` specified, please write custom impl." 
+ ) + + if dtype in COMPLEX_TO_REAL: + assert dtype is not None + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + + u_flat, u_spec = tree_flatten(u) + v_flat, v_spec = tree_flatten(v) + assert u_spec == v_spec + out_flat = [ + ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False) + ] + return tree_unflatten(out_flat, u_spec) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any: + """Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is.""" + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + return ComplexTensor.from_interleaved(arg) + return arg + + +def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any: + """Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is.""" + if isinstance(arg, ComplexTensor): + return arg.as_interleaved() + return arg + + +class ComplexTensorMode(TorchDispatchMode): + _compile: bool + + """ A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation. """ + + def __init__(self, _dispatch_key=None, *, _compile: bool = False): + """Initialize a ComplexTensorMode. + + Args: + _dispatch_key: passed on to TorchDispatchMode + _compile: Compile the op before the computation + """ + super().__init__(_dispatch_key) + self._compile = _compile + + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[type], + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + + # TODO (hameerabbasi): Test perf with `_compile` set to `True` + if self._compile: + func = torch.compile(func) # type: ignore[bad-assignment] + + args = tree_map(_as_complex_tensor, args) + kwargs = tree_map(_as_complex_tensor, kwargs) + + return tree_map(_as_interleaved, func(*args, **kwargs)) diff --git a/torch/_subclasses/complex_tensor/_ops/prims.py b/torch/_subclasses/complex_tensor/_ops/prims.py new file mode 100644 index 0000000000000..9a237b32d9904 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/prims.py @@ -0,0 +1,34 @@ +import torch + +from .._core import ComplexTensor +from .common import ( + complex_to_real_dtype, + register_complex, + register_force_test, + split_complex_tensor, +) + + +prims = torch.ops.prims +aten = torch.ops.aten + + +# TODO (hameerabbasi): Not being tested +@register_force_test(prims.convert_element_type) +def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: + dtype = complex_to_real_dtype(dtype) + u, v = split_complex_tensor(x) + u_out = prims.convert_element_type(u, dtype) + v_out = prims.convert_element_type(v, dtype) + + return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) diff --git a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py index 0054e996e33ce..1e49a274e129c 100644 --- a/torch/ao/nn/intrinsic/qat/modules/conv_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/conv_fused.py @@ -112,9 +112,6 @@ def reset_bn_parameters(self): bound = 1 / math.sqrt(fan_in) init.uniform_(self.bias, 
-bound, bound) - def reset_parameters(self): - super().reset_parameters() - def update_bn_stats(self): self.freeze_bn = False self.bn.training = True @@ -534,44 +531,6 @@ class ConvBnReLU1d(ConvBn1d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nn.Module] | None] = nni.ConvReLU1d - def __init__( - self, - # Conv1d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm1d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(self._forward(input)) @@ -735,44 +694,6 @@ class ConvBnReLU2d(ConvBn2d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU2d] | None] = nni.ConvReLU2d - def __init__( - self, - # Conv2d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm2d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(self._forward(input)) @@ -935,44 +856,6 @@ class ConvBnReLU3d(ConvBn3d): # module class after fusing bn into conv _FUSED_FLOAT_MODULE: ClassVar[type[nni.ConvReLU3d] | None] = nni.ConvReLU3d - def __init__( - self, - # Conv3d args - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=None, - padding_mode="zeros", - # BatchNorm3d args - # num_features: out_channels - eps=1e-05, - momentum=0.1, - # affine: True - # track_running_stats: True - # Args for this module - freeze_bn=False, - qconfig=None, - ): - super().__init__( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation, - groups, - bias, - padding_mode, - eps, - momentum, - freeze_bn, - qconfig, - ) - def forward(self, input): return F.relu(ConvBn3d._forward(self, input)) diff --git a/torch/ao/nn/quantized/reference/modules/utils.py b/torch/ao/nn/quantized/reference/modules/utils.py index 8311b0e5697b0..7bdbcd4a6739e 100644 --- a/torch/ao/nn/quantized/reference/modules/utils.py +++ b/torch/ao/nn/quantized/reference/modules/utils.py @@ -202,7 +202,7 @@ def _quantize_weight_decomposed( _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { torch.uint8: (0, 255), torch.int8: (-128, 127), - torch.int32: ((-(2**31)), (2**31 - 1)), + torch.int32: (-2147483648, 2147483647), # torch.jit interprets 2**31 as a float } # TODO: add an util function for converting qdtype to dtype @@ -265,7 +265,7 @@ def _dequantize_weight_decomposed( _DTYPE_TO_QVALUE_BOUNDS: dict[torch.dtype, tuple[int, int]] = { torch.uint8: (0, 255), torch.int8: (-128, 127), - torch.int32: ((-(2**31)), (2**31 - 1)), + torch.int32: (-2147483648, 2147483647), # torch.jit interprets 2**31 as a float } # TODO: add an util function for converting qdtype to dtype _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = { diff --git 
a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py index 3c146c55947a0..e2b31e0e563bf 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/benchmarks/dlrm_utils.py @@ -19,9 +19,6 @@ class SparseDLRM(DLRM_Net): layer of the top layer. """ - def __init__(self, **args): - super().__init__(**args) - def forward(self, dense_x, lS_o, lS_i): # pyrefly: ignore [missing-attribute] x = self.apply_mlp(dense_x, self.bot_l) # dense features diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index b0f1b823b7fdb..3b8ef1030bfdc 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -41,9 +41,6 @@ def get_embedding_operators_config() -> OperatorConfig: class EmbeddingQuantizer(Quantizer): - def __init__(self) -> None: - super().__init__() - @classmethod def get_supported_quantization_configs(cls) -> list[QuantizationConfig]: op_configs: set[QuantizationConfig] = { diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index d19968c2787f4..1c0fc48fd54fa 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -75,9 +75,6 @@ class XPUInductorQuantizer(X86InductorQuantizer): of the optimized kernels in oneDNN library. """ - def __init__(self) -> None: - super().__init__() - """ Following annotate_xx overrides the impls in base class, as no XPU implementation for these operators currently. 
We would diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 5c8d2664ed7db..cfab4fa5e2d5f 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -532,7 +532,7 @@ def vjp(gO): result = tuple( output if output is not None - else torch.zeros_like(input, requires_grad=True) + else torch.zeros_like(input, requires_grad=create_graph) for (output, input) in zip(result, inputs) ) return result diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index c02a8c36fd08b..f54a3fd6820c7 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -113,9 +113,6 @@ def inner(precision): class GenericModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - fp32_precision = ContextProp( _get_fp32_precision_getter("generic", "all"), _set_fp32_precision_setter("generic", "all"), diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 697783c01cb64..267594531db3d 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -198,9 +198,6 @@ def flags( class CudnnModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = ContextProp( torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic diff --git a/torch/backends/miopen/__init__.py b/torch/backends/miopen/__init__.py index 93453cc11592d..1b270b658e31a 100644 --- a/torch/backends/miopen/__init__.py +++ b/torch/backends/miopen/__init__.py @@ -37,9 +37,6 @@ def flags( class MiopenModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - immediate = ContextProp( torch._C._get_miopen_immediate, torch._C._set_miopen_immediate ) diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 2d1ce8f3bb997..58e6b2c595e98 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -110,9 +110,6 @@ def flags(enabled=False, deterministic=False, allow_tf32=True, fp32_precision="n class MkldnnModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - def is_available(self): return is_available() diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index 797d847e31e5c..264be78aa9a1c 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -101,9 +101,6 @@ def flags(enabled=None, strategy=None): class OptEinsumModule(PropModule): - def __init__(self, m, name): - super().__init__(m, name) - global enabled enabled = ContextProp(_get_enabled, _set_enabled) global strategy diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index d580809460811..adba98beb2724 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -138,7 +138,7 @@ inline void PyErr_SetString(PyObject* type, const std::string& message) { throw; \ } \ } \ - catch (const std::exception& e) { \ + catch (const std::exception&) { \ torch::translate_exception_to_python(std::current_exception()); \ return retval; \ } diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index 62bc48fa9b983..ea39424cf8ea7 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -219,6 +219,23 @@ static PySequenceMethods THPSize_as_sequence = { nullptr /* sq_contains */ }; +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 14 +static Py_hash_t THPSize_hash(PyObject* self) { + /* + Python 3.14 introduce a caching mechanism for tuple hashing which is 
stored + in the `ob_hash` field. The caching mechanism relies on a sentinel value (-1) + to indicate the hash has not yet been computed. + For some unknown reason, this field is initialized with 0 when Size is + created, which causes the caching logic to behave incorrectly. + */ + PyTupleObject* v = _PyTuple_CAST(self); + // reset ob_hash and force hash to be recomputed + Py_hash_t sentinel = -1; + v->ob_hash = sentinel; + return PyTuple_Type.tp_hash(self); +} +#endif + static PyMappingMethods THPSize_as_mapping = { nullptr, /* mp_length */ wrap_tuple_fn, @@ -284,7 +301,11 @@ PyTypeObject THPSizeType = { &THPSize_as_number, /* tp_as_number */ &THPSize_as_sequence, /* tp_as_sequence */ &THPSize_as_mapping, /* tp_as_mapping */ - nullptr, /* tp_hash */ +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 14 + &THPSize_hash, /* tp_hash */ +#else + nullptr, /* tp_hash */ +#endif nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ @@ -294,7 +315,12 @@ PyTypeObject THPSizeType = { nullptr, /* tp_doc */ nullptr, /* tp_traverse */ nullptr, /* tp_clear */ +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 14 + // if tp_hash is defined, one must also defines tp_richcompare + PyTuple_Type.tp_richcompare, /* tp_richcompare */ +#else nullptr, /* tp_richcompare */ +#endif 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 671c28adef3e3..33dfa3132cb45 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -245,7 +245,7 @@ static PyObject* THPStorage_pynew( storage_set(storage, i, value); } } - } catch (const std::exception& e) { + } catch (const std::exception&) { TORCH_CHECK( THPStorageStr "(): tried to construct a storage from a sequence (", THPUtils_typename(sequence), diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index c16cbb2331f07..a4a9afec1a7cc 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -51,6 +51,20 @@ void _foreach_tensor( } } +[[maybe_unused]] +size_t expected_fresh_use_count(const at::Tensor& self) { + if (!self.defined()) { + // An UndefinedTensorImpl always has a use count of 0 + return 0; + } + if (self.unsafeGetTensorImpl()->pyobj_slot()->load_pyobj() != nullptr) { + // A TensorImpl with a Python object has a use count of 2 + return 2; + } + // A fresh TensorImpl (with no PyObject) has a use count of 1 + return 1; +} + AutogradFallbackMode kAutogradFallbackMode = AutogradFallbackMode::Warn; } // namespace @@ -420,8 +434,7 @@ static void autogradNotImplementedFallbackImpl( op_name == "aten::_test_optional_floatlist") return; if (!is_inplace_output[idx_ret]) - TORCH_INTERNAL_ASSERT( - t.use_count() <= 1, op_name); // Okay to return undefined tensor + TORCH_INTERNAL_ASSERT(t.use_count() == expected_fresh_use_count(t)); // note(crcrpar): `_foreach_norm` returns a list of scalar Tensors and // each Tensor shares a storage of a hidden, intermediate 1D Tensor // created inside the CUDA implementation. 
This is because the diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index a13cc70270ccb..36a8806d281ed 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -390,31 +390,25 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ torch::profiler::impl::ActivityType::CPU}; -#if defined(USE_KINETO) && \ - (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) - if (at::hasMTIA()) { - activities.insert(torch::profiler::impl::ActivityType::MTIA); - } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } +#if defined(USE_KINETO) +#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) if (at::getNumGPUs() > 0) { activities.insert(torch::profiler::impl::ActivityType::CUDA); } -#elif defined(USE_KINETO) +#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) if (at::hasXPU()) { activities.insert(torch::profiler::impl::ActivityType::XPU); } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } if (at::hasMTIA()) { activities.insert(torch::profiler::impl::ActivityType::MTIA); } + if (at::hasHPU()) { + activities.insert(torch::profiler::impl::ActivityType::HPU); + } if (c10::get_privateuse1_backend() != "privateuseone") { activities.insert(torch::profiler::impl::ActivityType::PrivateUse1); } -#endif +#endif // defined(USE_KINETO) return activities; }); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 6d0bf5d0a8579..150512c972684 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1200,25 +1200,27 @@ get_thread_local_native_sharding_propagator_cache() { py::reinterpret_borrow(PyThreadState_GetDict()); // We need to clean up before Python detaches from the thread if // the thread is being destroyed. - thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = - py::capsule(new std::thread::id(this_thread_id), [](void* p) { - auto* ptid = reinterpret_cast(p); - { - std::lock_guard inner_lock( - native_sharding_propagator_cache_cleanup_mutex); - auto it = all_thread_caches.find(*ptid); - if (it != all_thread_caches.end()) { - // We need to both: - // 1) free python objects, and - it->second->reset(); - // 2) make sure we don't try to come back and mess with - // a destroyed thread-local at module unload (e.g., - // process exit) time. - all_thread_caches.erase(it); + if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) { + thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = + py::capsule(new std::thread::id(this_thread_id), [](void* p) { + auto* ptid = reinterpret_cast(p); + { + std::lock_guard inner_lock( + native_sharding_propagator_cache_cleanup_mutex); + auto it = all_thread_caches.find(*ptid); + if (it != all_thread_caches.end()) { + // We need to both: + // 1) free python objects, and + it->second->reset(); + // 2) make sure we don't try to come back and mess with + // a destroyed thread-local at module unload (e.g., + // process exit) time. 
+ all_thread_caches.erase(it); + } } - } - delete ptid; - }); + delete ptid; + }); + } } return native_sharding_propagator_cache_DO_NOT_USE.value(); } @@ -1424,6 +1426,17 @@ py::object dispatchDTensorOp( torch::jit::Stack saved_args = *stack; NativeShardingPropagatorCache* native_sharding_propagator_cache = nullptr; + // In the original Python implementation of DTensor dispatch, the creation + // of OpInfo (which includes the OpSchema computed here) never fails. However, + // C++ support for all the features of OpSchema are not supported; in this + // case opt_native_op_schema is nullopt. In this case, we need to fallback + // to the Python logic for doing so. If you are comparing against the old + // Python code, this is a bit tricky, since the Python 'dispatch' function + // has been completely deleted. + + // First, we will try to short-circuit Python entirely using the fast path. + // Here, we never materialize OpInfo, we generate a gimped NativeOpSchema + // object which has exactly the information you need to do a hash lookup. auto opt_native_op_schema = create_native_op_schema(op, py_op, stack); if (opt_native_op_schema.has_value()) { native_sharding_propagator_cache = @@ -1431,21 +1444,30 @@ py::object dispatchDTensorOp( cached_sharding = native_sharding_propagator_cache->find(opt_native_op_schema->first); } + py::object py_op_info; if (!cached_sharding) { + // OK, the C++ fastpath failed. Let's use the Python path to generate the + // OpInfo (which is guaranteed to work), which we will need to either + // redo the cache lookup or compute the value for real. py_op_info = checked_vectorcall( op_dispatcher.attr("unwrap_to_op_info").ptr(), py_op.ptr(), args.ptr(), kwargs.ptr()); + py::object sharding = checked_vectorcall( - op_dispatcher - .attr("_propagate_op_sharding_non_cached_dispatch_slow_path") - .ptr(), + op_dispatcher.attr("_propagate_op_sharding_dispatch_slow_path").ptr(), py_op.ptr(), args.ptr(), kwargs.ptr(), - py_op_info.ptr()); + py_op_info.ptr(), + /*try_cache*/ !opt_native_op_schema.has_value() ? Py_True : Py_False); + // This is a hack, because the dispatch slow path sometimes returns + // a sharding result (in which case we need to keep going) but it + // will sometimes just decompose and directly return a Tensor result, + // in which case we should return immediately. In this case, sharding + // is not a sharding at all; it's the real result! 
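+    // In other words: anything that is not an OutputSharding instance is the
+    // final result of the op and is handed straight back to the caller below.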
if (!py::isinstance(sharding, get_output_sharding_class())) { stack->clear(); return sharding; diff --git a/torch/csrc/cuda/shim_common.cpp b/torch/csrc/cuda/shim_common.cpp new file mode 100644 index 0000000000000..cb5f28dba0152 --- /dev/null +++ b/torch/csrc/cuda/shim_common.cpp @@ -0,0 +1,9 @@ +#include +#include +#include + +AOTITorchError torch_get_current_cuda_blas_handle(void** ret_handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + *(cublasHandle_t*)(ret_handle) = at::cuda::getCurrentCUDABlasHandle(); + }); +} diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 156c9efd5ca98..2104e3030d445 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -448,7 +448,7 @@ c10::intrusive_ptr DistEngine:: const variable_list& grads = futureGrads.constValue().toTensorVector(); TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); accumulateGradFuture->markCompleted(c10::IValue()); - } catch (std::exception& e) { + } catch (std::exception&) { accumulateGradFuture->setErrorIfNeeded(std::current_exception()); } }); @@ -527,7 +527,7 @@ c10::intrusive_ptr DistEngine::executeSendFunctionAsync( // Perform cleanup at the end of the backward pass (before // we mark the future as completed). DistEngine::getInstance().cleanupBackwardPass(autogradContext); - } catch (std::exception& e) { + } catch (std::exception&) { callbackFuture->setErrorIfNeeded(std::current_exception()); return; } @@ -539,7 +539,7 @@ c10::intrusive_ptr DistEngine::executeSendFunctionAsync( callbackFuture->setError(rpcFuture.exception_ptr()); } }); - } catch (std::exception& e) { + } catch (std::exception&) { callbackFuture->setErrorIfNeeded(std::current_exception()); } }); diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 9f79a09d236e5..b888e315021ac 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -81,7 +81,7 @@ c10::intrusive_ptr ProcessGroup::getBackend( ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED}; try { backendType = deviceTypeToBackendType_.at(deviceType); - } catch (const std::out_of_range& e) { + } catch (const std::out_of_range&) { TORCH_CHECK( false, "No backend type associated with device type ", deviceType); } diff --git a/torch/csrc/distributed/c10d/PyProcessGroup.hpp b/torch/csrc/distributed/c10d/PyProcessGroup.hpp index afec6bbe11a9a..39474c49052fe 100644 --- a/torch/csrc/distributed/c10d/PyProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/PyProcessGroup.hpp @@ -339,7 +339,7 @@ class TORCH_PYTHON_API PythonOnCompletionHook { eptr = std::make_exception_ptr(std::runtime_error(e.what())); e.restore(); PyErr_Clear(); - } catch (std::exception& e) { + } catch (std::exception&) { eptr = std::current_exception(); } } diff --git a/torch/csrc/distributed/c10d/TCPStore.cpp b/torch/csrc/distributed/c10d/TCPStore.cpp index 0c1cf581887d1..9f566032b5b3c 100644 --- a/torch/csrc/distributed/c10d/TCPStore.cpp +++ b/torch/csrc/distributed/c10d/TCPStore.cpp @@ -270,7 +270,7 @@ TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) // server successfully started C10D_DEBUG("The server has started on port = {}.", server_->port()); addr_.port = server_->port(); - } catch (const SocketError& e) { + } catch (const SocketError&) { bool useAgentStore = getCvarBool({"TORCHELASTIC_USE_AGENT_STORE"}, false); int 
masterPort = getCvarInt({"MASTER_PORT"}, 0); if (useAgentStore && masterPort == opts.port) { diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index f3ff9e623043e..7427848b8445b 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -246,7 +246,7 @@ class UvTcpServer : public UvTcpSocket { uv_err_name(uv_res), uv_strerror(uv_res))); res->cacheSocketPort(); - } catch (std::exception& ex) { + } catch (std::exception&) { res->close(); throw; } @@ -322,7 +322,7 @@ class UvTcpServer : public UvTcpSocket { uv_err_name(uv_res), uv_strerror(uv_res))); res->cacheSocketPort(); - } catch (std::exception& ex) { + } catch (std::exception&) { res->close(); throw; } diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 10274d053b995..5e5c3195046cb 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -9,6 +11,8 @@ #include #include +#include + namespace c10d::control_plane { namespace { @@ -63,6 +67,30 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { res.setStatus(200); }}; +RegisterHandler frTracehandler( + "fr_trace_json", + [](const Request&, Response& res) { + auto trace = ::c10d::dump_fr_trace_json(true, true); + res.setContent(std::move(trace), "application/json"); + res.setStatus(200); + }); + +RegisterHandler waitCounterHandler{ + "wait_counter_values", + [](const Request&, Response& res) { + // Get all wait counter values from our tracking backend + res.setContent(getWaitCounterValuesJson(), "application/json"); + res.setStatus(200); + }}; + +#if !defined(FBCODE_CAFFE2) +// Initialize the wait counter backend +[[maybe_unused]] static bool init_backend = []() { + ensureWaitCounterBackendRegistered(); + return true; +}(); +#endif + } // namespace void registerHandler(const std::string& name, HandlerFunc f) { diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 70333a3a4844c..58ae9368ea212 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -18,6 +18,14 @@ class TORCH_API Request { virtual const std::string& body() const = 0; virtual const std::multimap& params() const = 0; + + std::string getParam(const std::string& key) const { + auto it = params().find(key); + if (it != params().end()) { + return it->second; + } + return ""; + } }; // Response represents a response to the handler. 
This conceptually maps to an diff --git a/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp new file mode 100644 index 0000000000000..194901cea6837 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.cpp @@ -0,0 +1,138 @@ +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace c10d::control_plane { + +namespace { + +// Data structure to hold counter metrics +struct CounterData { + std::atomic active_count{0}; + std::atomic total_calls{0}; + std::atomic total_time_us{0}; + std::atomic max_time_us{0}; +}; + +// Holder struct for the counter data map +struct CounterDataMapHolder { + c10::Synchronized< + std::unordered_map>> + map; +}; + +// Leaky singleton to avoid static destruction order issues +CounterDataMapHolder* getCounterDataMapHolder() { + static CounterDataMapHolder* holder = new CounterDataMapHolder(); + return holder; +} + +// Backend implementation that tracks counter metrics +class TrackingBackend : public c10::monitor::detail::WaitCounterBackendIf { + public: + explicit TrackingBackend(std::string key) : key_(std::move(key)) { + // Get or create counter data for this key + getCounterDataMapHolder()->map.withLock([&](auto& map) { + auto it = map.find(key_); + if (it == map.end()) { + data_ = std::make_shared(); + map[key_] = data_; + } else { + data_ = it->second; + } + }); + } + + intptr_t start(std::chrono::steady_clock::time_point now) noexcept override { + data_->active_count.fetch_add(1, std::memory_order_relaxed); + data_->total_calls.fetch_add(1, std::memory_order_relaxed); + // Return the start time as the context + return static_cast( + std::chrono::duration_cast( + now.time_since_epoch()) + .count()); + } + + void stop(std::chrono::steady_clock::time_point now, intptr_t ctx) noexcept + override { + // Calculate duration from the stored start time + auto start_ns = std::chrono::nanoseconds(ctx); + auto start_time = std::chrono::steady_clock::time_point(start_ns); + auto duration_us = + std::chrono::duration_cast(now - start_time) + .count(); + + data_->active_count.fetch_sub(1, std::memory_order_relaxed); + data_->total_time_us.fetch_add(duration_us, std::memory_order_relaxed); + + // Update max_time_us using compare-and-swap + int64_t current_max = data_->max_time_us.load(std::memory_order_relaxed); + while (duration_us > current_max) { + if (data_->max_time_us.compare_exchange_weak( + current_max, duration_us, std::memory_order_relaxed)) { + break; + } + } + } + + private: + std::string key_; + std::shared_ptr data_; +}; + +// Factory for creating tracking backends +class TrackingBackendFactory + : public c10::monitor::detail::WaitCounterBackendFactoryIf { + public: + std::unique_ptr create( + std::string_view key) noexcept override { + return std::make_unique(std::string(key)); + } +}; + +} // namespace + +// Ensures the wait counter backend is registered +// NOTE: This function is in the c10d::control_plane namespace, NOT anonymous +void ensureWaitCounterBackendRegistered() { + static c10::once_flag once; + c10::call_once(once, []() { + c10::monitor::detail::registerWaitCounterBackend( + std::make_unique()); + }); +} + +// Returns all wait counter values as a JSON string +// NOTE: This function is in the c10d::control_plane namespace, NOT anonymous +std::string getWaitCounterValuesJson() { + nlohmann::json j = nlohmann::json::object(); + + getCounterDataMapHolder()->map.withLock([&](const auto& 
map) { + for (const auto& [name, data] : map) { + nlohmann::json counter_obj = nlohmann::json::object(); + counter_obj["active_count"] = + data->active_count.load(std::memory_order_relaxed); + counter_obj["total_calls"] = + data->total_calls.load(std::memory_order_relaxed); + counter_obj["total_time_us"] = + data->total_time_us.load(std::memory_order_relaxed); + counter_obj["max_time_us"] = + data->max_time_us.load(std::memory_order_relaxed); + j[name] = std::move(counter_obj); + } + }); + + return j.dump(); +} + +} // namespace c10d::control_plane diff --git a/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp new file mode 100644 index 0000000000000..417e4d21edbd0 --- /dev/null +++ b/torch/csrc/distributed/c10d/control_plane/WaitCounterHandler.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace c10d { +namespace control_plane { + +// Returns all wait counter values as a JSON string +std::string getWaitCounterValuesJson(); + +// Ensures the wait counter backend is registered +void ensureWaitCounterBackendRegistered(); + +} // namespace control_plane +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 8bbe857620790..eda6ee3a91488 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { TORCH_CHECK( server_.bind_to_port(hostOrFile, 80), fmt::format("Error binding to {}", hostOrFile)); + } else if (port == 0) { + C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); + port_ = server_.bind_to_any_port(hostOrFile); + TORCH_CHECK( + port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port)); } else { C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); TORCH_CHECK( server_.bind_to_port(hostOrFile, port), fmt::format("Error binding to {}:{}", hostOrFile, port)); + port_ = port; } serverThread_ = std::thread([this]() { diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp index 41c1356fc01f3..20d05b7509e92 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target { void shutdown(); + int port() { + return port_; + } + private: httplib::Server server_; std::thread serverThread_; + int port_; }; } // namespace c10d::control_plane diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 94a8c0bbe228b..255e793eaa4df 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -46,6 +46,7 @@ #include #include +#include #include #include #include @@ -4209,7 +4210,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. }), py::arg("host_or_file"), py::arg("port") = -1) - .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown) + .def_property_readonly( + "port", &::c10d::control_plane::WorkerServer::port); module.def( "_get_handler", @@ -4225,6 +4228,25 @@ such as `dist.all_reduce(tensor, async_op=True)`. Returns the handler with the specified name. 
)"); + module.def( + "_register_handler", + [](const std::string& name, const py::function& handler) { + ::c10d::control_plane::registerHandler( + name, + [handler]( + const ::c10d::control_plane::Request& req, + ::c10d::control_plane::Response& res) { + py::gil_scoped_acquire acquire; + handler(std::ref(req), std::ref(res)); + }); + }, + + py::arg("name"), + py::arg("handler"), + R"( + Registers a handler by name. + )"); + module.def( "_get_handler_names", &::c10d::control_plane::getHandlerNames, @@ -4242,12 +4264,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. // Default constructor. .def(py::init<>()) .def("body", &::c10d::control_plane::Request::body) - .def("params", &::c10d::control_plane::Request::params); + .def("get_param", &::c10d::control_plane::Request::getParam); - py::class_< - ::c10d::control_plane::Response, - std::shared_ptr<::c10d::control_plane::Response>, - PythonResponse>( + py::class_<::c10d::control_plane::Response, PythonResponse>( module, "_Response", R"( diff --git a/torch/csrc/distributed/c10d/python_callback_work.cpp b/torch/csrc/distributed/c10d/python_callback_work.cpp index 47bef1831a480..685b3cceeaa4c 100644 --- a/torch/csrc/distributed/c10d/python_callback_work.cpp +++ b/torch/csrc/distributed/c10d/python_callback_work.cpp @@ -40,14 +40,14 @@ bool PythonCallbackWork::wait(std::chrono::milliseconds timeout) { } return success; - } catch (py::error_already_set& e) { + } catch (py::error_already_set&) { // Capture the Python exception and store it finish(std::current_exception()); if (!future_->completed()) { future_->setErrorIfNeeded(std::current_exception()); } throw; - } catch (const std::exception& e) { + } catch (const std::exception&) { // Capture any C++ exception and store it finish(std::current_exception()); if (!future_->completed()) { diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index a1c9b4a3039d5..c4af19ef44209 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -341,7 +341,7 @@ void Reducer::check_grad_layout( grad.sizes(), ", strides() = ", grad.strides(), - "\n", + '\n', "bucket_view.sizes() = ", bucket_view.sizes(), ", strides() = ", diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index f83d42df4ac68..6352330c3872c 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -517,11 +517,6 @@ static void init_multicast_for_block( using McHandleType = std::conditional_t; - McHandleType invalidator; - std::memset(&invalidator, UINT8_MAX, sizeof(McHandleType)); - - // Phase 1: export handle (rank 0 only) - McHandleType mc_exported_handle{}; if (rank == 0) { CUmulticastObjectProp mc_prop{}; mc_prop.numDevices = world_size; @@ -530,82 +525,68 @@ static void init_multicast_for_block( // create a multicast object, which acts as a handle that allows multiple // devices or processes to access the same memory allocation coherently. - try { - C10_CUDA_DRIVER_CHECK( - driver_api->cuMulticastCreate_(&mc_handle, &mc_prop)); - // using the CUDA Driver API to export a multicast object into a POSIX file - // descriptor. 
- C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &mc_exported_handle, mc_handle, handleType, 0)); - } catch (const std::exception& e) { - // Allow peers gracefully skip multicast initialization by sending -1 - mc_exported_handle = invalidator; + auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop); + if (err != CUDA_SUCCESS) { + const char* err_str; + CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str); + if (get_error_str_err != CUDA_SUCCESS) { + err_str = "unknown cuda driver error"; + } LOG(WARNING) - << "SymmetricMemory: fail to export multicast handle.\n" - << e.what(); + << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str + << "\". Gracefully skipping multicast initialization. " + << "However, this is unexpected. Please report the issue on GitHub."; + // Allow peers gracefully skip multicast initialization by sending -1 + // TODO: allow graceful skip for fabric + if constexpr (!use_fabric_handle) { + ipc_channel.broadcast_fds(rank, 0, pids, -1); + } + return; } - } - - // Phase 2: Exchange handle - McHandleType recv_handle; - if constexpr (!use_fabric_handle) { - recv_handle = ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle); - } else { - // TODO implement storeExchange.broadcast - auto gathered_handles = storeExchange.all_gather(store, rank, world_size, mc_exported_handle); - recv_handle = std::move(gathered_handles[0]); - } - - // Check exchange result - if (memcmp(&recv_handle, &invalidator, sizeof(McHandleType)) == 0) { - LOG(WARNING) << "Gracefully skipping multicast initialization."; - return; - } - // Flip to true after all CUDA steps finish - bool success_end = false; + McHandleType mc_exported_handle; + // using the CUDA Driver API to export a multicast object into a POSIX file + // descriptor. + C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &mc_exported_handle, mc_handle, handleType, 0)); + if constexpr (!use_fabric_handle) { + ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle); + // Ref count is incremented as soon as SCM_RIGHTS send happens + close(mc_exported_handle); + } else { + // TODO implement storeExchange.broadcast + storeExchange.all_gather(store, rank, world_size, mc_exported_handle); + } - // Phase 3: Import handle (non-0 ranks only) - if (rank != 0) { + } else { if constexpr (!use_fabric_handle) { + int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1); + if (mc_fd == -1) { + return; + } // Convert back to a handle from the broadcasted POSIX file descriptor. 
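+      // The fd arrives as a process-local duplicate over SCM_RIGHTS; it is
+      // only needed until cuMemImportFromShareableHandle has produced the
+      // driver-side handle, which is why it is closed immediately after the
+      // import below.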
- C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_( + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( &mc_handle, - (void*)(uintptr_t)recv_handle, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), check_all); + (void*)(uintptr_t)mc_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(mc_fd); } else { - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_( - &mc_handle, (void*)&(recv_handle), CU_MEM_HANDLE_TYPE_FABRIC), check_all); + CUmemFabricHandle null_handle{}; + auto mc_handles = + storeExchange.all_gather(store, rank, world_size, null_handle); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC)); } } - // Phase 4: Bind memory // All rank adds their physical allocation to the multicast object - C10_CUDA_DRIVER_CHECK_GOTO( - driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx), check_all); - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMulticastBindMem_( - mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0), check_all); - - success_end = true; - -check_all: - // Whether all ranks have succeeded - bool all_succeed = true; - auto rank_successes = storeExchange.all_gather(store, rank, world_size, success_end); - for (int r = 0; r < world_size; ++r) { - all_succeed &= rank_successes[r]; - } - // Close the file descriptor before exit - if constexpr (!use_fabric_handle) { - close(recv_handle); - } - if (!all_succeed) { - LOG(WARNING) << "Gracefully skipping multicast initialization."; - return; - } + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_( + mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0)); - // Phase 5: Map to virtual memory map_block(&mc_addr, mc_handle, block->block_size, block->device_idx); + storeExchange.barrier(store, rank, world_size); #endif } diff --git a/torch/csrc/distributed/rpc/python_call.cpp b/torch/csrc/distributed/rpc/python_call.cpp index 5973903cfcf10..4770d744215c6 100644 --- a/torch/csrc/distributed/rpc/python_call.cpp +++ b/torch/csrc/distributed/rpc/python_call.cpp @@ -6,6 +6,12 @@ PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution) : serializedPyObj_(std::move(serializedPyObj)), isAsyncExecution_(isAsyncExecution) {} +#if defined(__GNUC__) && __GNUC__ == 14 +/* this warning is falsely triggered with gcc-14 in following function. 
*/ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wfree-nonheap-object" +#endif + c10::intrusive_ptr PythonCall::toMessageImpl() && { std::vector payload; payload.reserve(serializedPyObj_.payload_.length() + 1); @@ -21,6 +27,10 @@ c10::intrusive_ptr PythonCall::toMessageImpl() && { MessageType::PYTHON_CALL); } +#if defined(__GNUC__) && __GNUC__ == 14 +#pragma GCC diagnostic pop +#endif + std::unique_ptr PythonCall::fromMessage(const Message& message) { TORCH_INTERNAL_ASSERT( !message.payload().empty(), diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 790ff9acff3a1..9ed9a465642c3 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -224,6 +224,7 @@ void initDynamoBindings(PyObject* torch) { .def_readonly("code", &CacheEntry::code) .def_readonly("compile_id", &CacheEntry::compile_id) .def_readonly("trace_annotation", &CacheEntry::trace_annotation) + .def_readonly("backend", &CacheEntry::backend) .def_property_readonly("next", &CacheEntry::next) .def( "update_diff_guard_root_manager", diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index 11659cc24eb89..117324796e7f8 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -353,7 +353,7 @@ static PyObject* NodeBase__update_args_kwargs( Py_CLEAR(node->_kwargs); node->_kwargs = map_aggregate(args[1], visit_fn); Py_RETURN_NONE; - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; } } @@ -397,7 +397,7 @@ static PyObject* NodeBase__replace_input_with( PyObject* update_args[2] = {new_args.get(), new_kwargs.get()}; return NodeBase__update_args_kwargs(self, update_args, 2); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; } } @@ -802,7 +802,7 @@ static PyObject* py_map_aggregate( // args[0]: aggregate, args[1]: callable fn return map_aggregate( args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); }); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; // error should already be set } } @@ -824,7 +824,7 @@ static PyObject* py_map_arg( } return Py_NewRef(a); }); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; // error should already be set } } diff --git a/torch/csrc/jit/codegen/onednn/kernel.cpp b/torch/csrc/jit/codegen/onednn/kernel.cpp index c5421643e8c43..85afc5fa8dc7b 100644 --- a/torch/csrc/jit/codegen/onednn/kernel.cpp +++ b/torch/csrc/jit/codegen/onednn/kernel.cpp @@ -28,7 +28,7 @@ LlgaKernel::LlgaKernel(const Node* fusionNode) partition_ = partitions[0]; nPartitionInputs_ = partition_.get_input_ports().size(); #ifdef GRAPH_DEBUG_ENABLED - GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString()); + GRAPH_DEBUG("Initialized ", debugName(), '\n', graph_->toString()); #endif } @@ -243,7 +243,7 @@ compiled_partition LlgaKernel::compile(const partition& partition) { void LlgaKernel::run(Stack& stack) { #ifdef GRAPH_DEBUG_ENABLED - GRAPH_DEBUG("In ", debugName(), "\n"); + GRAPH_DEBUG("In ", debugName(), '\n'); #endif // Grab input values from stack diff --git a/torch/csrc/jit/ir/subgraph_matcher.cpp b/torch/csrc/jit/ir/subgraph_matcher.cpp index 17a82dc4ac6c3..37dd8e3280de9 100644 --- a/torch/csrc/jit/ir/subgraph_matcher.cpp +++ b/torch/csrc/jit/ir/subgraph_matcher.cpp @@ -272,8 +272,8 @@ bool SubgraphMatcher::matchNodes(const Node* n1, Node* n2) { if (!endsWith(real_typename, pattern_typename)) { GRAPH_DEBUG( "Nodes did not match because expected module type is different:\n"); - GRAPH_DEBUG(" actualtype: ", 
real_typename, "\n"); - GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); + GRAPH_DEBUG(" actualtype: ", real_typename, '\n'); + GRAPH_DEBUG(" expected type: ", pattern_typename, '\n'); GRAPH_DEBUG("Nodes:", *n1, *n2); return false; } diff --git a/torch/csrc/jit/jit_log.h b/torch/csrc/jit/jit_log.h index 49851e81082a7..333c3edcdd01f 100644 --- a/torch/csrc/jit/jit_log.h +++ b/torch/csrc/jit/jit_log.h @@ -95,12 +95,12 @@ TORCH_API std::ostream& operator<<( JIT_LOG( \ ::torch::jit::JitLoggingLevels::GRAPH_DUMP, \ MSG, \ - "\n", \ + '\n', \ ::torch::jit::log_function(G)); // use GRAPH_DUMP for dumping graphs after optimization passes #define GRAPH_DUMP(MSG, G) \ JIT_LOG( \ - ::torch::jit::JitLoggingLevels::GRAPH_DUMP, MSG, "\n", (G)->toString()); + ::torch::jit::JitLoggingLevels::GRAPH_DUMP, MSG, '\n', (G)->toString()); // use GRAPH_UPDATE for reporting graph transformations (i.e. node deletion, // constant folding, CSE) #define GRAPH_UPDATE(...) \ diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 2ae32f5fc5082..57dc2552c661c 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1464,9 +1464,15 @@ class ShapePropagator : public PropertyPropBase { "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.generator(Tensor self, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor(Tensor self, Tensor high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_dtype(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_generator_dtype(Tensor self, int low, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like.generator(Tensor self, *, Generator? 
generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", }, [](Node* node) -> type_vec_t { diff --git a/torch/csrc/jit/python/pybind.h b/torch/csrc/jit/python/pybind.h index 066ff7f77f56c..845beb540c9f1 100644 --- a/torch/csrc/jit/python/pybind.h +++ b/torch/csrc/jit/python/pybind.h @@ -117,7 +117,7 @@ struct type_caster { try { value = torch::jit::toTypeInferredIValue(src); return true; - } catch (std::exception& e) { + } catch (std::exception&) { return false; } } @@ -142,7 +142,7 @@ struct type_caster { std::string src_str; try { src_str = py::cast(src); - } catch (std::exception& e) { + } catch (std::exception&) { return false; } value = torch::jit::Symbol::fromQualString(src_str); diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp index 6ec55f998cce0..99dd289fb0964 100644 --- a/torch/csrc/jit/tensorexpr/block_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -351,7 +351,7 @@ void BlockCodeGen::Initialize() { stmt_v->accept(printer_.get()); - GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n"); + GRAPH_DEBUG("Generated Block code: ", oss_.str(), '\n'); } void BlockCodeGen::call(const std::vector& args) { diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index b19a8b8964ad5..04034554e25ed 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -320,7 +320,7 @@ void CodeGen::allocIntermediateBufs() { set_stmt(stmt_new); } - GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), "\n"); + GRAPH_DEBUG("\nMemory Allocation:\n\n", *stmt(), '\n'); } } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index cc15663720383..d696d29bf733e 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -795,7 +795,7 @@ static void parallelizeOuterLoops(LoopNest& l, const Bufs& bufs) { StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { torch::jit::tensorexpr::LoopNest l(std::move(st), bufOutputs_); LoopNest::sanitizeNames(l.root_stmt()); - GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); + GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), '\n'); int64_t random_tr_seed = randomTransformsRequested(); if (random_tr_seed) { if (random_tr_seed == -1) @@ -939,7 +939,7 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) { StmtPtr stmt = l.root_stmt(); // Arithmetic Simplification. stmt = IRSimplifier::simplify(stmt); - GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), "\n"); + GRAPH_DEBUG("Final Stmt:\n", std::to_string(stmt), '\n'); return stmt; } diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 918d82579444f..17db0872eb78f 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -777,7 +777,7 @@ void LLVMCodeGenImpl::emitKernel( PM.run(*module_); asmCode_ = asmStream.str().str(); - GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, "\n"); + GRAPH_DEBUG("\nLLVM generated assembly code\n\n", asmCode_, '\n'); } // TODO: The binary ops are copypaste. 
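Note on the "\n" -> '\n' edits in the tensorexpr/JIT logging calls above: they only change the argument type handed to the variadic logging helpers; the printed output is identical, the single char simply avoids passing a one-character literal through the const char* path. A minimal sketch of a stream-style variadic logger (not the real JIT_LOG/GRAPH_DEBUG machinery) makes the distinction concrete:

    #include <iostream>
    #include <sstream>

    // Minimal stand-in for a stream-style variadic logger; the real
    // JIT_LOG/GRAPH_DEBUG macros are more involved, but accept their
    // arguments the same way.
    template <typename... Args>
    void log_all(Args&&... args) {
      std::ostringstream oss;
      (oss << ... << args); // C++17 fold over operator<<
      std::cerr << oss.str();
    }

    int main() {
      log_all("Generated Block code: ", "<kernel body>", '\n');  // streams a char
      log_all("Generated Block code: ", "<kernel body>", "\n");  // streams a const char*
    }
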
diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 29b2b94af4472..b46e1d19bcd0e 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -112,59 +112,8 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT std::map> opStack; // Uses the underlying TensorImpl object pointer as the key and map to its // unique id. - std::map objectId; - - using weak_storage_ptr = c10::weak_intrusive_ptr; - std::unordered_map data_ptr_to_storage_id; - std::unordered_map - data_ptr_to_weak_storage_ptr; - - ID get_tensor_storage_ID(const c10::Storage& t_storage) { - const std::lock_guard lock(gMutex); - - const void* raw_data_ptr = nullptr; - bool should_track_liveness = false; - // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use - // a nullptr data pointer. Treat both cases as a shared cache key but avoid - // touching the weak-ref table so they can reuse the same ID without - // tripping the liveness check. - if (t_storage.unsafeGetStorageImpl()) { - raw_data_ptr = t_storage.data(); - should_track_liveness = raw_data_ptr != nullptr; - } - - auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr); - if (!should_track_liveness) { - if (id_iter != data_ptr_to_storage_id.end()) { - return id_iter->second; - } - ID id = storage_id_++; - data_ptr_to_storage_id.emplace(raw_data_ptr, id); - return id; - } - - auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); - if (weak_iter == data_ptr_to_weak_storage_ptr.end()) { - ID id = storage_id_++; - data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); - data_ptr_to_weak_storage_ptr.emplace( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } - - if (weak_iter->second.expired()) { - ID id = storage_id_++; - data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); - data_ptr_to_weak_storage_ptr.insert_or_assign( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } - - id_iter = data_ptr_to_storage_id.find(raw_data_ptr); - TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end()); - return id_iter->second; - } + std::map objectId{}; // Observer run state. enum class RunState { uninitialized, disabled, enabled }; @@ -227,8 +176,6 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT // 1 -> root ID // 2 ... -> regular node ID std::atomic id_{2}; - - std::atomic storage_id_{1}; }; // Using a singleton manager here to allow init and delete the observer object. 
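The hunk above drops the dedicated get_tensor_storage_ID/weak-storage tracking, and the hunk below switches the tensor-metadata path back to getObjectID(ob, t_storage.data()), i.e. plain pointer-keyed IDs. A rough sketch of that pattern, with names chosen for illustration only (getObjectID's actual behavior is inferred from the call site below, not copied from the file):

    #include <cstdint>
    #include <map>

    using ID = uint64_t;

    // Illustrative only: a pointer-keyed ID table in the spirit of the
    // objectId map retained above. Each distinct pointer is assigned the
    // next monotonically increasing ID; repeated lookups return the same ID.
    struct IdTable {
      std::map<const void*, ID> ids;
      ID next_id = 1;

      ID get(const void* ptr) {
        auto [it, inserted] = ids.emplace(ptr, next_id);
        if (inserted) {
          ++next_id;
        }
        return it->second;
      }
    };
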
@@ -499,8 +446,8 @@ convertIValue( // symbolic sizes/strides implies t->storage_offset() will fail if (tensor_impl->has_storage() && !tensor_impl->has_symbolic_sizes_strides()) { - const c10::Storage& t_storage = tensor_impl->storage(); - storage_id = ob.get_tensor_storage_ID(t_storage); + auto& t_storage = tensor_impl->storage(); + storage_id = getObjectID(ob, t_storage.data()); offset = tensor_impl->storage_offset(); numel = tensor_impl->numel(); itemsize = tensor_impl->itemsize(); diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index ffbb7bb1235a7..3a437fa78229e 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -145,6 +145,10 @@ static StableIValue from_ivalue( list_pointer_to_list_handle(stableivalue_list.release()), extension_build_version); } + case c10::TypeKind::StringType: { + return torch::stable::detail::_from( + ivalue.toStringRef(), extension_build_version); + } default: { TORCH_CHECK( false, @@ -251,6 +255,10 @@ static c10::IValue to_ivalue( TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); return ivalue_list; } + case c10::TypeKind::StringType: { + return c10::IValue(torch::stable::detail::_to( + stable_ivalue, extension_build_version)); + } default: { TORCH_CHECK( false, @@ -578,3 +586,34 @@ torch_get_mutable_data_ptr(AtenTensorHandle tensor, void** ret_data_ptr) { *ret_data_ptr = t->mutable_data_ptr(); }); } + +AOTI_TORCH_EXPORT AOTITorchError +torch_new_string_handle(const char* data, size_t length, StringHandle* handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = new std::string(data, length); + *handle = reinterpret_cast(str_ptr); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError torch_delete_string(StringHandle handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = reinterpret_cast(handle); + delete str_ptr; + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_length(StringHandle handle, size_t* length) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = reinterpret_cast(handle); + *length = str_ptr->length(); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_c_str(StringHandle handle, const char** data) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto str_ptr = reinterpret_cast(handle); + *data = str_ptr->c_str(); + }); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 99b3b435cf550..202ca3ba40c05 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -103,6 +103,27 @@ AOTI_TORCH_EXPORT AOTITorchError torch_get_const_data_ptr( const void** ret_data_ptr // returns borrowed reference ); +struct StringOpaque; +using StringHandle = StringOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError +torch_new_string_handle(const char* data, size_t length, StringHandle* handle); + +AOTI_TORCH_EXPORT AOTITorchError torch_delete_string(StringHandle handle); + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_length(StringHandle handle, size_t* length); + +AOTI_TORCH_EXPORT AOTITorchError +torch_string_c_str(StringHandle handle, const char** data); + +#ifdef USE_CUDA + +AOTI_TORCH_EXPORT AOTITorchError +torch_get_current_cuda_blas_handle(void** ret_handle); + +#endif // USE_CUDA + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 #ifdef __cplusplus diff --git a/torch/csrc/stable/library.h b/torch/csrc/stable/library.h index dc36c4d182478..ac6d252f757a1 100644 --- a/torch/csrc/stable/library.h +++ b/torch/csrc/stable/library.h @@ -131,6 +131,11 @@ struct UnboxType> { using type = std::vector; }; +template <> +struct 
UnboxType { + using type = std::string; +}; + template using unbox_type_t = typename UnboxType::type; diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index c90db39cb1b98..923cbf398a104 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -326,24 +326,26 @@ inline uint32_t get_num_threads() { return num_threads; } -// We expect this to be the stable version of the empty op that takes in -// device and dtype parameters. The empty op creates a tensor with uninitialized -// values of the specified size, dtype, and device. -// This function is only available in 2.10 because it uses the stableivalue -// conversion for HeaderOnlyArrayRef, which is only available in 2.10. +// We expect this to be the stable version of the empty.memory_format op that +// takes in device and dtype parameters. This function is only available in 2.10 +// because it uses the stableivalue conversion for HeaderOnlyArrayRef, which +// is only available in 2.10. inline torch::stable::Tensor empty( torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt, + std::optional layout = std::nullopt, std::optional device = std::nullopt, - std::optional pin_memory = std::nullopt) { + std::optional pin_memory = std::nullopt, + std::optional memory_format = + std::nullopt) { const auto num_args = 6; std::array stack{ torch::stable::detail::from(size), torch::stable::detail::from(dtype), - torch::stable::detail::from(std::nullopt), + torch::stable::detail::from(layout), torch::stable::detail::from(device), torch::stable::detail::from(pin_memory), - torch::stable::detail::from(std::nullopt)}; + torch::stable::detail::from(memory_format)}; TORCH_ERROR_CODE_CHECK(torch_call_dispatcher( "aten::empty", "memory_format", stack.data(), TORCH_ABI_VERSION)); return torch::stable::detail::to(stack[0]); diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 15ac8e539e76b..c44e656d88e11 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -5,8 +5,11 @@ #include #include #include +#include +#include #include #include +#include #include #include @@ -61,6 +64,9 @@ struct FromImpl { static_assert( !is_std_vector_v, "std::vector requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); + static_assert( + !std::is_same_v, + "std::string requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); @@ -268,6 +274,68 @@ struct FromImpl { // ============================================================================= #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 +// Specialization for c10::Layout => StableIValue +// Note that we call into the shim to translate between the user's +// Layout and libtorch's Layout, which can be different! 
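+// (The extension may be built against a different copy of these enums than
+// the libtorch it eventually runs with, so the enumerator values cannot be
+// passed through as raw integers; mapping through the stable
+// aoti_torch_layout_* values below keeps both sides in agreement.)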
+using c10::Layout; +template <> +struct FromImpl { + static StableIValue call( + Layout val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + switch (val) { + case Layout::Strided: + return from(aoti_torch_layout_strided()); + case Layout::Sparse: + return from(aoti_torch_layout_sparse_coo()); + case Layout::SparseCsr: + return from(aoti_torch_layout_sparse_csr()); + case Layout::SparseCsc: + return from(aoti_torch_layout_sparse_csc()); + case Layout::SparseBsr: + return from(aoti_torch_layout_sparse_bsr()); + case Layout::SparseBsc: + return from(aoti_torch_layout_sparse_bsc()); + case Layout::Mkldnn: + return from(aoti_torch_layout__mkldnn()); + case Layout::Jagged: + return from(aoti_torch_layout_jagged()); + default: + STD_TORCH_CHECK( + false, + "Not yet supported Layout, please file an issue describing your use case."); + } + } +}; + +// Specialization for c10::MemoryFormat => StableIValue +// Note that we call into the shim to translate between the user's +// MemoryFormat and libtorch's MemoryFormat, which can be different! +using c10::MemoryFormat; +template <> +struct FromImpl { + static StableIValue call( + MemoryFormat val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + switch (val) { + case MemoryFormat::Contiguous: + return from(aoti_torch_memory_format_contiguous_format()); + case MemoryFormat::Preserve: + return from(aoti_torch_memory_format_preserve_format()); + case MemoryFormat::ChannelsLast: + return from(aoti_torch_memory_format_channels_last()); + case MemoryFormat::ChannelsLast3d: + return from(aoti_torch_memory_format_channels_last_3d()); + default: + STD_TORCH_CHECK( + false, + "Not yet supported MemoryFormat, please file an issue describing your use case."); + } + } +}; + // Specialization for torch::headeronly::HeaderOnlyArrayRef => StableIValue // Returns a new owning reference of the underlying list. template @@ -285,7 +353,7 @@ struct FromImpl> { torch_list_push_back(new_list_handle, from(elem))); } return from(new_list_handle); - } catch (const std::runtime_error& e) { + } catch (const std::runtime_error&) { if (new_list_handle != nullptr) { // clean up memory if an error was thrown TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle)); @@ -330,6 +398,21 @@ struct FromImpl { } }; +// Specialization for std::string, which should return a new owning reference of +// the string +template <> +struct FromImpl { + static StableIValue call( + const std::string& val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + StringHandle handle; + TORCH_ERROR_CODE_CHECK( + torch_new_string_handle(val.c_str(), val.length(), &handle)) + return from(handle); + } +}; + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 // ============================================================================= @@ -343,7 +426,6 @@ struct ToImpl { StableIValue val, [[maybe_unused]] uint64_t extension_build_version, [[maybe_unused]] bool is_internal) { - static_assert(std::is_trivially_copyable_v); // Ensure 2.10+ types don't accidentally use the base case - provide clear // compile-time errors. static_assert( @@ -355,6 +437,10 @@ struct ToImpl { static_assert( !is_std_vector_v, "std::vector requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); + static_assert( + !std::is_same_v, + "std::string requires TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0"); + static_assert(std::is_trivially_copyable_v); // T may not have a default constructor. 
(For example, it might be // c10::Device.) However, std::memcpy implicitly creates a T at the // destination. So, we can use a union to work around this lack of @@ -529,6 +615,68 @@ struct ToImpl { // ============================================================================= #if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 +// Specialization for StableIValue => c10::Layout +template <> +struct ToImpl { + static Layout call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + int32_t shim_layout = to(val); + if (shim_layout == aoti_torch_layout_strided()) { + return Layout::Strided; + } else if (shim_layout == aoti_torch_layout_sparse_coo()) { + return Layout::Sparse; + } else if (shim_layout == aoti_torch_layout_sparse_csr()) { + return Layout::SparseCsr; + } else if (shim_layout == aoti_torch_layout_sparse_csc()) { + return Layout::SparseCsc; + } else if (shim_layout == aoti_torch_layout_sparse_bsr()) { + return Layout::SparseBsr; + } else if (shim_layout == aoti_torch_layout_sparse_bsc()) { + return Layout::SparseBsc; + } else if (shim_layout == aoti_torch_layout__mkldnn()) { + return Layout::Mkldnn; + } else if (shim_layout == aoti_torch_layout_jagged()) { + return Layout::Jagged; + } else { + STD_TORCH_CHECK( + false, + "Not yet supported Layout ", + std::to_string(shim_layout), + ", please file an issue describing your use case."); + } + } +}; + +// Specialization for StableIValue => c10::MemoryFormat +template <> +struct ToImpl { + static MemoryFormat call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + int32_t shim_memory_format = to(val); + if (shim_memory_format == aoti_torch_memory_format_contiguous_format()) { + return MemoryFormat::Contiguous; + } else if ( + shim_memory_format == aoti_torch_memory_format_preserve_format()) { + return MemoryFormat::Preserve; + } else if (shim_memory_format == aoti_torch_memory_format_channels_last()) { + return MemoryFormat::ChannelsLast; + } else if ( + shim_memory_format == aoti_torch_memory_format_channels_last_3d()) { + return MemoryFormat::ChannelsLast3d; + } else { + STD_TORCH_CHECK( + false, + "Not yet supported MemoryFormat ", + std::to_string(shim_memory_format), + ", please file an issue describing your use case."); + } + } +}; + // Specialization for StableIValue => std::vector // std::vector should be represented as a StableListHandle // filled with StableIValues @@ -553,7 +701,7 @@ struct ToImpl> { } TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); return result; - } catch (const std::runtime_error& e) { + } catch (const std::runtime_error&) { // clean up memory if an exception is thrown, and rethrow TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); throw; @@ -579,6 +727,27 @@ struct ToImpl { } }; +// Specialization for std::string +// Returns a new std::string; the string in val is deleted. 
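+// (Ownership round trip: FromImpl<std::string> above creates the handle via
+// torch_new_string_handle; this side copies the bytes out with
+// torch_string_length/torch_string_c_str and then frees it with
+// torch_delete_string, so each handle crosses the ABI boundary exactly once.)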
+template <> +struct ToImpl { + static std::string call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + StringHandle handle = to(val); + size_t length; + TORCH_ERROR_CODE_CHECK(torch_string_length(handle, &length)); + const char* data; + TORCH_ERROR_CODE_CHECK(torch_string_c_str(handle, &data)); + auto strptr = new std::string(data, length); + + // delete the old string before returning new string + TORCH_ERROR_CODE_CHECK(torch_delete_string(handle)); + return *strptr; + } +}; + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 // ============================================================================= @@ -653,31 +822,11 @@ HIDDEN_NAMESPACE_END(torch, stable, detail) // WARNING! Will be removed. Only exists for BC. See [global from/to deprecation // note] template -[[deprecated("Use torch::stable::detail::from instead.")]] -inline StableIValue from(T val) { - return torch::stable::detail::from(val); -} - -// WARNING! Will be removed. Only exists for BC. See [global from/to deprecation -// note] -template -[[deprecated("Use torch::stable::detail::from instead.")]] -inline StableIValue from(const std::optional& val) { - return torch::stable::detail::from(val); -} - -// WARNING! Will be removed. Only exists for BC. See [global from/to deprecation -// note] -[[deprecated( - "Use torch::stable::detail::from instead.")]] [[maybe_unused]] inline StableIValue -from(const torch::stable::Tensor& val) { - return torch::stable::detail::from(val); -} +C10_DEPRECATED_MESSAGE("Use torch::stable::detail::from instead.") +auto from = &torch::stable::detail::from; // WARNING! Will be removed. Only exists for BC. See [global from/to deprecation // note] template -[[deprecated("Use torch::stable::detail::to instead.")]] -inline T to(StableIValue val) { - return torch::stable::detail::to(val); -} +C10_DEPRECATED_MESSAGE("Use torch::stable::detail::to instead.") +auto to = &torch::stable::detail::to; diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index f7abfece3bc31..706d7940ee785 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>> +// checksum<> // clang-format off #pragma once @@ -10,7 +10,6 @@ #include #include #include -#include #include @@ -191,7 +190,7 @@ inline std::string_view printEnum(const ArgumentKind& e) { case ArgumentKind::POSITIONAL: return "POSITIONAL"; case ArgumentKind::KEYWORD: return "KEYWORD"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -199,7 +198,7 @@ inline void parseEnum(std::string_view s, ArgumentKind& t) { if (s == "UNKNOWN") { t = ArgumentKind::UNKNOWN; return; } if (s == "POSITIONAL") { t = ArgumentKind::POSITIONAL; return; } if (s == "KEYWORD") { t = ArgumentKind::KEYWORD; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } enum class Layout { @@ -224,7 +223,7 @@ inline std::string_view printEnum(const Layout& e) { case Layout::_mkldnn: return "_mkldnn"; case Layout::Strided: return "Strided"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -237,7 +236,7 @@ inline void parseEnum(std::string_view s, Layout& t) { if (s == 
"SparseBsc") { t = Layout::SparseBsc; return; } if (s == "_mkldnn") { t = Layout::_mkldnn; return; } if (s == "Strided") { t = Layout::Strided; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } enum class MemoryFormat { @@ -256,7 +255,7 @@ inline std::string_view printEnum(const MemoryFormat& e) { case MemoryFormat::ChannelsLast3d: return "ChannelsLast3d"; case MemoryFormat::PreserveFormat: return "PreserveFormat"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -266,7 +265,7 @@ inline void parseEnum(std::string_view s, MemoryFormat& t) { if (s == "ChannelsLast") { t = MemoryFormat::ChannelsLast; return; } if (s == "ChannelsLast3d") { t = MemoryFormat::ChannelsLast3d; return; } if (s == "PreserveFormat") { t = MemoryFormat::PreserveFormat; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } enum class ScalarType { @@ -313,7 +312,7 @@ inline std::string_view printEnum(const ScalarType& e) { case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -337,7 +336,7 @@ inline void parseEnum(std::string_view s, ScalarType& t) { if (s == "FLOAT8E5M2") { t = ScalarType::FLOAT8E5M2; return; } if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -454,7 +453,7 @@ inline std::string_view printEnum(const SymExprHint::Tag& e) { case SymExprHint::Tag::AS_BOOL: return "AS_BOOL"; case SymExprHint::Tag::AS_FLOAT: return "AS_FLOAT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -462,7 +461,7 @@ inline void parseEnum(std::string_view s, SymExprHint::Tag& t) { if (s == "AS_INT") { t = SymExprHint::Tag::AS_INT; return; } if (s == "AS_BOOL") { t = SymExprHint::Tag::AS_BOOL; return; } if (s == "AS_FLOAT") { t = SymExprHint::Tag::AS_FLOAT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -560,14 +559,14 @@ inline std::string_view printEnum(const SymInt::Tag& e) { case SymInt::Tag::AS_EXPR: return "AS_EXPR"; case SymInt::Tag::AS_INT: return "AS_INT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymInt::Tag& t) { if (s == "AS_EXPR") { t = SymInt::Tag::AS_EXPR; return; } if (s == "AS_INT") { t = SymInt::Tag::AS_INT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -638,14 +637,14 @@ inline std::string_view printEnum(const SymFloat::Tag& e) { case SymFloat::Tag::AS_EXPR: return "AS_EXPR"; case SymFloat::Tag::AS_FLOAT: return "AS_FLOAT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymFloat::Tag& t) { if (s == "AS_EXPR") { t = SymFloat::Tag::AS_EXPR; return; } if (s == "AS_FLOAT") { t = SymFloat::Tag::AS_FLOAT; return; } - 
TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -716,14 +715,14 @@ inline std::string_view printEnum(const SymBool::Tag& e) { case SymBool::Tag::AS_EXPR: return "AS_EXPR"; case SymBool::Tag::AS_BOOL: return "AS_BOOL"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymBool::Tag& t) { if (s == "AS_EXPR") { t = SymBool::Tag::AS_EXPR; return; } if (s == "AS_BOOL") { t = SymBool::Tag::AS_BOOL; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -866,14 +865,14 @@ inline std::string_view printEnum(const SymIntArgument::Tag& e) { case SymIntArgument::Tag::AS_NAME: return "AS_NAME"; case SymIntArgument::Tag::AS_INT: return "AS_INT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymIntArgument::Tag& t) { if (s == "AS_NAME") { t = SymIntArgument::Tag::AS_NAME; return; } if (s == "AS_INT") { t = SymIntArgument::Tag::AS_INT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -944,14 +943,14 @@ inline std::string_view printEnum(const SymFloatArgument::Tag& e) { case SymFloatArgument::Tag::AS_NAME: return "AS_NAME"; case SymFloatArgument::Tag::AS_FLOAT: return "AS_FLOAT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymFloatArgument::Tag& t) { if (s == "AS_NAME") { t = SymFloatArgument::Tag::AS_NAME; return; } if (s == "AS_FLOAT") { t = SymFloatArgument::Tag::AS_FLOAT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1022,14 +1021,14 @@ inline std::string_view printEnum(const SymBoolArgument::Tag& e) { case SymBoolArgument::Tag::AS_NAME: return "AS_NAME"; case SymBoolArgument::Tag::AS_BOOL: return "AS_BOOL"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, SymBoolArgument::Tag& t) { if (s == "AS_NAME") { t = SymBoolArgument::Tag::AS_NAME; return; } if (s == "AS_BOOL") { t = SymBoolArgument::Tag::AS_BOOL; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1136,14 +1135,14 @@ inline std::string_view printEnum(const OptionalTensorArgument::Tag& e) { case OptionalTensorArgument::Tag::AS_TENSOR: return "AS_TENSOR"; case OptionalTensorArgument::Tag::AS_NONE: return "AS_NONE"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } inline void parseEnum(std::string_view s, OptionalTensorArgument::Tag& t) { if (s == "AS_TENSOR") { t = OptionalTensorArgument::Tag::AS_TENSOR; return; } if (s == "AS_NONE") { t = OptionalTensorArgument::Tag::AS_NONE; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1233,11 +1232,11 @@ class Argument { public: enum class Tag { - AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, 
AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX, AS_INT_LISTS, AS_STRING_TO_ARGUMENT }; private: - std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue> variant_; + std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue, std::vector>, std::unordered_map>> variant_; Tag tag_; public: @@ -1488,6 +1487,24 @@ class Argument { tag_ = Tag::AS_COMPLEX; } + const std::vector>& get_as_int_lists() const { + return std::get<28>(variant_); + } + + void set_as_int_lists(std::vector> def) { + variant_.emplace<28>(std::move(def)); + tag_ = Tag::AS_INT_LISTS; + } + + const std::unordered_map>& get_as_string_to_argument() const { + return std::get<29>(variant_); + } + + void set_as_string_to_argument(std::unordered_map> def) { + variant_.emplace<29>(std::move(def)); + tag_ = Tag::AS_STRING_TO_ARGUMENT; + } + friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { if (nlohmann_json_t.tag_ == Tag::AS_NONE) { @@ -1598,6 +1615,14 @@ class Argument { nlohmann_json_j["as_complex"] = nlohmann_json_t.get_as_complex(); return; } + if (nlohmann_json_t.tag_ == Tag::AS_INT_LISTS) { + nlohmann_json_j["as_int_lists"] = nlohmann_json_t.get_as_int_lists(); + return; + } + if (nlohmann_json_t.tag_ == Tag::AS_STRING_TO_ARGUMENT) { + nlohmann_json_j["as_string_to_argument"] = nlohmann_json_t.get_as_string_to_argument(); + return; + } } friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { @@ -1737,6 +1762,16 @@ class Argument { nlohmann_json_t.tag_ = Tag::AS_COMPLEX; return; } + if (nlohmann_json_j.contains("as_int_lists")) { + nlohmann_json_t.variant_.emplace<28>(nlohmann_json_j.at("as_int_lists").template get>>()); + nlohmann_json_t.tag_ = Tag::AS_INT_LISTS; + return; + } + if (nlohmann_json_j.contains("as_string_to_argument")) { + nlohmann_json_t.variant_.emplace<29>(nlohmann_json_j.at("as_string_to_argument").template get>>()); + nlohmann_json_t.tag_ = Tag::AS_STRING_TO_ARGUMENT; + return; + } } }; @@ -1769,8 +1804,10 @@ inline std::string_view printEnum(const Argument::Tag& e) { case Argument::Tag::AS_SYM_FLOATS: return "AS_SYM_FLOATS"; case Argument::Tag::AS_OPTIONAL_TENSOR: return "AS_OPTIONAL_TENSOR"; case Argument::Tag::AS_COMPLEX: return "AS_COMPLEX"; + case Argument::Tag::AS_INT_LISTS: return "AS_INT_LISTS"; + case Argument::Tag::AS_STRING_TO_ARGUMENT: return "AS_STRING_TO_ARGUMENT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -1802,7 +1839,9 @@ inline void 
parseEnum(std::string_view s, Argument::Tag& t) { if (s == "AS_SYM_FLOATS") { t = Argument::Tag::AS_SYM_FLOATS; return; } if (s == "AS_OPTIONAL_TENSOR") { t = Argument::Tag::AS_OPTIONAL_TENSOR; return; } if (s == "AS_COMPLEX") { t = Argument::Tag::AS_COMPLEX; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + if (s == "AS_INT_LISTS") { t = Argument::Tag::AS_INT_LISTS; return; } + if (s == "AS_STRING_TO_ARGUMENT") { t = Argument::Tag::AS_STRING_TO_ARGUMENT; return; } + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1905,8 +1944,8 @@ class Graph { std::unordered_map sym_int_values; std::unordered_map sym_bool_values; bool is_single_tensor_return = false; - std::unordered_map custom_obj_values; - std::unordered_map sym_float_values; + std::unordered_map custom_obj_values = {}; + std::unordered_map sym_float_values = {}; public: @@ -2128,7 +2167,7 @@ inline std::string_view printEnum(const ConstantValue::Tag& e) { case ConstantValue::Tag::AS_STRING: return "AS_STRING"; case ConstantValue::Tag::AS_BOOL: return "AS_BOOL"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -2138,7 +2177,7 @@ inline void parseEnum(std::string_view s, ConstantValue::Tag& t) { if (s == "AS_FLOAT") { t = ConstantValue::Tag::AS_FLOAT; return; } if (s == "AS_STRING") { t = ConstantValue::Tag::AS_STRING; return; } if (s == "AS_BOOL") { t = ConstantValue::Tag::AS_BOOL; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -2466,7 +2505,7 @@ inline std::string_view printEnum(const InputSpec::Tag& e) { case InputSpec::Tag::TOKEN: return "TOKEN"; case InputSpec::Tag::CONSTANT_INPUT: return "CONSTANT_INPUT"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -2478,7 +2517,7 @@ inline void parseEnum(std::string_view s, InputSpec::Tag& t) { if (s == "CUSTOM_OBJ") { t = InputSpec::Tag::CUSTOM_OBJ; return; } if (s == "TOKEN") { t = InputSpec::Tag::TOKEN; return; } if (s == "CONSTANT_INPUT") { t = InputSpec::Tag::CONSTANT_INPUT; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -2852,7 +2891,7 @@ inline std::string_view printEnum(const OutputSpec::Tag& e) { case OutputSpec::Tag::TOKEN: return "TOKEN"; case OutputSpec::Tag::PARAMETER_MUTATION: return "PARAMETER_MUTATION"; default: - TORCH_CHECK(false, "Unknown enum value"); + throw std::runtime_error("Unknown enum value"); } } @@ -2865,7 +2904,7 @@ inline void parseEnum(std::string_view s, OutputSpec::Tag& t) { if (s == "USER_INPUT_MUTATION") { t = OutputSpec::Tag::USER_INPUT_MUTATION; return; } if (s == "TOKEN") { t = OutputSpec::Tag::TOKEN; return; } if (s == "PARAMETER_MUTATION") { t = OutputSpec::Tag::PARAMETER_MUTATION; return; } - TORCH_CHECK(false, "Unknown enum value: " + std::string{s}); + throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -3027,8 +3066,8 @@ class GraphModule { Graph graph; GraphSignature signature; std::vector module_call_graph; - std::unordered_map metadata; - std::unordered_map treespec_namedtuple_fields; + std::unordered_map metadata = {}; + std::unordered_map treespec_namedtuple_fields = {}; public: @@ -3109,9 +3148,9 @@ class ExportedProgram { std::unordered_map opset_version; std::unordered_map range_constraints; SchemaVersion schema_version; - std::vector verifiers; + 
std::vector verifiers = {}; std::string torch_version = "<=2.4"; - std::vector guards_code; + std::vector guards_code = {}; public: diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index e89f7887320a0..5fa0986cc814d 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -1687,7 +1687,9 @@ bool FunctionSignature::parse( if (max_pos_args == 1 && (params[0].type_ == ParameterType::INT_LIST || params[0].type_ == ParameterType::SYM_INT_LIST)) { - allow_varargs_intlist = true; + int64_t failed_idx = -1; + allow_varargs_intlist = is_int_or_symint_list( + args, params[0].size, &failed_idx, &overloaded_args); } if (static_cast(nargs) > max_pos_args && !allow_varargs_intlist) { diff --git a/torch/csrc/utils/schema_info.cpp b/torch/csrc/utils/schema_info.cpp index fb628bec8c654..e3176854f14d2 100644 --- a/torch/csrc/utils/schema_info.cpp +++ b/torch/csrc/utils/schema_info.cpp @@ -250,12 +250,18 @@ std::vector SchemaInfo::getNonDeterministicOps() { "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor", "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::rand_like.generator(Tensor self, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.generator(Tensor self, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor(Tensor self, Tensor high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_dtype(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randint_like.low_generator_dtype(Tensor self, int low, int high, *, Generator? generator, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor", "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", + "aten::randn_like.generator(Tensor self, *, Generator? generator, int? 
dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"}; std::vector nondeterministic_ops; diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index c186694df94e7..cc4a47f299444 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -49,6 +49,7 @@ import operator import os import sys +import threading from collections import defaultdict from collections.abc import Callable, Generator, Sequence from types import TracebackType @@ -76,7 +77,11 @@ from torch.nested._internal.nested_int import NestedIntNode from torch.utils import _pytree as pytree from torch.utils._mode_utils import no_dispatch -from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode_stack, + return_and_correct_aliasing, + TorchDispatchMode, +) from torch.utils.checkpoint import get_device_states, set_device_states @@ -86,6 +91,89 @@ from . import _c10d +def _is_in_fake_tensor_mode() -> bool: + return any( + isinstance(mode, FakeTensorMode) for mode in _get_current_dispatch_mode_stack() + ) + + +def _reduce_multidim_lists( + lists_to_reduce: list[Any], reduce_func: Callable[[list[Any]], Any] +) -> Any: + """ + Reduces a list of multi-dimensional lists, assuming they all have + the exact same shape. + + Args: + lists_to_reduce (list): A list where each item is a multi-dimensional + list (e.g., [md_list_1, md_list_2, ...]). + All inner md_lists must have the same shape. + reduce_func (callable): A function that takes an iterable (list) of + values and returns a single reduced value. + For example: sum, max, min, or + lambda x: sum(x) / len(x) for mean. + + Returns: + A single multi-dimensional list of the same shape as the inputs, + where each value is the result of the reduce_func. + + Raises: + ValueError: If the input list is empty or if shapes are inconsistent + (which may also raise IndexError or TypeError). + """ + if not lists_to_reduce: + raise ValueError("Input 'lists_to_reduce' cannot be empty.") + + # Get the first list to inspect its structure (shape) + first_list = lists_to_reduce[0] + + # Check if the first element of this list is *also* a list. + # This determines if we are at the base case or need to recurse. + if isinstance(first_list[0], list): + # --- RECURSIVE STEP --- + # The elements are lists, so we need to go one level deeper. + + # We find the number of sub-lists from the first list. + # (e.g., for [[1,2], [3,4]], this is 2) + num_sublists = len(first_list) + + result = [] + # Iterate by the index of the sub-lists (e.g., i = 0, then i = 1) + for i in range(num_sublists): + # Build a new list to pass to the recursive call. + # This list will contain the i-th sublist from *each* of the + # input lists. + # e.g., if lists_to_reduce = [ L1, L2 ] and i = 0, + # this creates [ L1[0], L2[0] ] + sublists_to_reduce = [l[i] for l in lists_to_reduce] + + # Recurse and append the result + result.append(_reduce_multidim_lists(sublists_to_reduce, reduce_func)) + return result + else: + # --- BASE CASE --- + # The elements are values (int, float, etc.), not lists. + # We are at the innermost dimension. + + # Find the number of values in the innermost list. 
+ # (e.g., for [1, 2], this is 2) + num_values = len(first_list) + + result = [] + # Iterate by the index of the values (e.g., i = 0, then i = 1) + for i in range(num_values): + # Get the values at this specific position (i) from *all* + # input lists. + # e.g., if lists_to_reduce = [ [1,2], [10,20] ] and i = 0, + # this creates [ 1, 10 ] + values_at_pos = [l[i] for l in lists_to_reduce] + + # Apply the user-provided reduction function to this list of values + # and append the single result. + result.append(reduce_func(values_at_pos)) + return result + + def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: return ( isinstance(op, OpOverload) @@ -256,21 +344,35 @@ def _for_each_rank_run_func( a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args ] - # NB: Before invoking an op we are collecting rng states from CPU and - # CUDA devices such that we can reset to the same before invoking op - # for each rank. This is not very efficient and will likely be revisited - # to support per rank rng state. - rng_state = _get_rng_state() + lm = enabled_local_tensor_mode() + use_per_rank_rng = lm is not None and len(lm._per_rank_rng_states) > 0 + + global_rng_state = None if use_per_rank_rng else _get_rng_state() + flat_rank_rets = {} default_value: Tensor | None = None for r in sorted(ranks): - _set_rng_state(*rng_state) + if use_per_rank_rng: + assert lm is not None + _set_rng_state(*lm._per_rank_rng_states[r]) + else: + assert global_rng_state is not None + _set_rng_state(*global_rng_state) + rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) - rank_ret = func(*rank_args, **rank_kwargs) + if func is torch.ops.aten.hash_tensor.default and rank_args[0].numel() == 0: + # Special case for empty tensors, hash_tensor returns an empty tensor + rank_ret = torch.empty(0, dtype=torch.uint64, device=rank_args[0].device) + else: + rank_ret = func(*rank_args, **rank_kwargs) flat_rank_rets[r] = rank_ret + if use_per_rank_rng: + assert lm is not None + lm._per_rank_rng_states[r] = _get_rng_state() + if default_value is None and func is torch.ops.aten.split.Tensor: # If split happens over the dimension smaller than the number of chunks # it is possible that some ranks will produce shorter lists of chunks. @@ -365,6 +467,12 @@ def sym_max( } ) + def sym_sum(self, other: Any) -> "LocalIntNode | ConstantIntNode": + t = LocalIntNode(dict.fromkeys(self._local_ints, 0)) + for o in other: + t = t.add(o) + return t + def neg(self) -> "LocalIntNode | ConstantIntNode": return LocalIntNode({r: -self._local_ints[r] for r in self._local_ints}) @@ -437,6 +545,247 @@ def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode": return ConstantIntNode(num) +class _LocalDeviceHandle: + """ + Wrapper around device module (e.g., torch.cuda) with automatic LocalTensor semantics. + + This class wraps device modules and automatically handles per-rank operations in + LocalTensor mode: + - get_rng_state() returns a LocalTensor with per-rank states + - set_rng_state(LocalTensor) sets per-rank states + + When not in LocalTensor mode, it delegates directly to the underlying device handle. + """ + + def __init__(self, device_handle, device_type: str): + """ + Initialize the local device handle wrapper. 
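+
+        In LocalTensor mode, RNG state reads and writes on this wrapper fan
+        out to the per-rank states tracked by the active local tensor mode;
+        outside that mode all calls delegate directly to the wrapped module.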
+ + Args: + device_handle: The underlying device module (e.g., torch.cuda) + device_type: Device type string (e.g., "cuda", "cpu") + """ + self._device_handle = device_handle + self._device_type = device_type + + def get_rng_state(self): + """ + Get RNG state, automatically returning LocalTensor in LocalTensor mode. + + Returns: + LocalTensor in LocalTensor mode, regular Tensor otherwise + """ + lm = enabled_local_tensor_mode() + if not lm: + return self._device_handle.get_rng_state() + + original_state = _get_rng_state() + per_rank_states = {} + + try: + for rank in lm.ranks: + # We need to set-then-get instead of directly copying lm._per_rank_rng_states[rank] + # because they have different structures: + # - lm._per_rank_rng_states[rank] is a tuple: (cpu_state, {device_idx: cuda_state}) + # - self._device_handle.get_rng_state() returns just the device-specific tensor + # So we temporarily restore the full RNG state (CPU + all CUDA devices) for this rank, + # then extract only the specific device's state tensor that we need. + if rank in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[rank]) + + per_rank_states[rank] = self._device_handle.get_rng_state() + finally: + _set_rng_state(*original_state) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(per_rank_states) + + def set_rng_state(self, state): + """ + Set RNG state, automatically handling LocalTensor input. + + Args: + state: Regular Tensor or LocalTensor with per-rank states + """ + if isinstance(state, LocalTensor): + lm = enabled_local_tensor_mode() + assert lm is not None + + # Similar to get_rng_state but in reverse: we need to convert from + # device-specific tensor format to full state tuple format. + # - state._local_tensors[rank] contains just the device-specific RNG state tensor + # - lm._per_rank_rng_states[rank] needs a tuple: (cpu_state, {device_idx: cuda_state}) + # So we set the device's state with the rank-specific tensor, then _get_rng_state() + # captures both CPU and CUDA states into the tuple format that _per_rank_rng_states expects. + for rank, rank_state in state._local_tensors.items(): + self._device_handle.set_rng_state(rank_state.to("cpu")) + lm._per_rank_rng_states[rank] = _get_rng_state() + else: + self._device_handle.set_rng_state(state.to("cpu")) + + def __getattr__(self, name): + """Delegate all other attributes to the underlying device module.""" + return getattr(self._device_handle, name) + + +class _LocalOffsetBasedRNGTracker: + """ + LocalTensor-specific RNG tracker for DTensor random operations. + + This class manages per-rank RNG states when running in LocalTensor mode, + using _LocalPhiloxState to track different offsets for each virtual rank. + It is instantiated and used by OffsetBasedRNGTracker when in LocalTensor mode. 
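A sketch of the two state layouts the wrapper above converts between, following its comments (assumption: CUDA may or may not be present, and the tuple mirrors what `_get_rng_state` captures):

import torch

# Full per-rank snapshot: (cpu_state, {device_index: cuda_state}).
cpu_state = torch.get_rng_state()
cuda_states = (
    {i: torch.cuda.get_rng_state(i) for i in range(torch.cuda.device_count())}
    if torch.cuda.is_available()
    else {}
)
full_snapshot = (cpu_state, cuda_states)
# A device module's get_rng_state(), by contrast, returns only one of these tensors,
# which is why the wrapper restores the full snapshot before reading a single device.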
+ + Much of this is derived from OffsetBasedRNGTracker: + https://github.com/pytorch/pytorch/blob/402c46503002f98ccfc023a733081fb0719223a1/torch/distributed/tensor/_random.py#L182 + """ + + def __init__(self, device_type: str = "cuda"): + """Initialize the LocalTensor RNG tracker.""" + from torch.distributed.device_mesh import _get_device_handle + + self._device_type = device_type + self._device_handle = _LocalDeviceHandle( + _get_device_handle(device_type), device_type + ) + self.distribute_region_enabled = True + self._device_mesh = None + + @property + def _device(self): + return torch.device(self._device_type, torch.cuda.current_device()) + + def _set_pre_op_offset(self, state, spec) -> None: + """Compute and set per-rank offsets before the random operation.""" + from torch.distributed.tensor._ops.utils import prod + from torch.distributed.tensor._utils import ( + _compute_local_shape_and_global_offset, + ) + from torch.distributed.tensor.placement_types import Shard + + lm = enabled_local_tensor_mode() + assert lm is not None + + state._per_rank_offsets = {} + + for rank in lm.ranks: + # compute this rank's coordinate in the mesh + mesh_coords = [] + for mesh_dim_idx in range(spec.mesh.ndim): + mesh_dim_size = spec.mesh.size(mesh_dim_idx) + # calculate rank's coordinate in this mesh dimension + num_chunks_after = 1 + for j in range(mesh_dim_idx + 1, spec.mesh.ndim): + num_chunks_after *= spec.mesh.size(j) + coord = (rank // num_chunks_after) % mesh_dim_size + mesh_coords.append(coord) + + # compute local shape and global offset for this rank + local_shape, global_offset = _compute_local_shape_and_global_offset( + spec.shape, spec.mesh.shape, mesh_coords, spec.placements + ) + + # compute shard offset based on placements + shard_offset = 1 + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + shard_offset *= global_offset[shard_dim] + 1 + + # get current offset for this rank + current_offset = int( + state._per_rank_states[rank][8:].view(dtype=torch.int64).item() + ) + + # compute local size + local_size = prod(local_shape) + + # compute new offset (must be multiple of 4) + shard_linear_idx = shard_offset - 1 + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state._per_rank_offsets[rank] = current_offset + offset_incr + + def _set_post_op_offset(self, state, spec, old_offset) -> None: + """Set per-rank offsets after the random operation.""" + from torch.distributed.tensor._ops.utils import prod + + lm = enabled_local_tensor_mode() + assert lm is not None + + dtensor_shape = spec.shape + numel = prod(dtensor_shape) + # offset must be multiple of 4 + numel = (numel + 3) // 4 * 4 + + if not hasattr(state, "_per_rank_offsets"): + state._per_rank_offsets = {} + + # handle LocalIntNode old_offset (different values per rank) + if isinstance(old_offset, SymInt) and isinstance(old_offset.node, LocalIntNode): + for rank in lm.ranks: + rank_old_offset = old_offset.node._local_ints[rank] + state._per_rank_offsets[rank] = rank_old_offset + numel + else: + # same old_offset for all ranks + old_offset_int = ( + int(old_offset) if isinstance(old_offset, SymInt) else old_offset + ) + for rank in lm.ranks: + state._per_rank_offsets[rank] = old_offset_int + numel + + @contextlib.contextmanager + def _distribute_region(self, spec, generator=None): + """Context manager for LocalTensor mode distribute region.""" + lm = enabled_local_tensor_mode() + assert lm is not None + + # get base state + if generator is not None: + 
base_state_tensor = generator.get_state() + per_rank_states = {rank: base_state_tensor.clone() for rank in lm.ranks} + # pyrefly: ignore [bad-argument-type, bad-argument-count] + base_state_tensor = LocalTensor(per_rank_states) + else: + base_state_tensor = self._device_handle.get_rng_state() + + state = _LocalPhiloxState(base_state_tensor) + + if self.distribute_region_enabled: + # sync to rank 0's state if no explicit generator + if generator is None: + rank_0_state = lm._per_rank_rng_states[0] + rank_0_cpu, rank_0_cuda = rank_0_state + + if self._device.type == "cuda": + assert self._device.index in rank_0_cuda + rank_0_device_state = rank_0_cuda[self._device.index] + else: + rank_0_device_state = rank_0_cpu + + from torch.distributed.tensor._random import _PhiloxState + + rank_0_philox = _PhiloxState(rank_0_device_state) + state.seed = rank_0_philox.seed + state.offset = rank_0_philox.offset + + old_offset = state.offset + self._set_pre_op_offset(state, spec) + state.apply_to_local_tensor_mode(self._device_handle) + + try: + yield + finally: + self._set_post_op_offset(state, spec, old_offset) + state.apply_to_local_tensor_mode(self._device_handle) + else: + yield + + # maybe reset generator to rank 0's state + if generator is not None: + rank_0_state = state._per_rank_states[0] + generator.set_state(rank_0_state) + + _LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_" @@ -597,6 +946,7 @@ def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor": local_tensors_copy = { r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items() } + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors_copy, self.requires_grad) def __repr__(self) -> str: # type: ignore[override] @@ -636,6 +986,7 @@ def __tensor_unflatten__( local_tensors = { _from_local_tensor_attr(a): t for a, t in inner_tensors.items() } + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors) @classmethod @@ -676,9 +1027,7 @@ def __torch_dispatch__( # type: ignore[override] with LocalTensorMode(local_tensor._ranks): return func(*args, **kwargs) - def numpy( - self, *, force: bool = False - ) -> np.ndarray: # pyrefly: ignore # missing-attribute + def numpy(self, *, force: bool = False) -> Any: if HAS_NUMPY: return self.reconcile().numpy(force=force) else: @@ -708,10 +1057,24 @@ def is_contiguous( def tolist(self) -> list[Any]: """ - Reconcile and convert result to list. + Try to reconcile, if successful convert to list, otherwise if dtype is integer, + convert to list of local integers. 
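The `(numel + 3) // 4 * 4` expression used in the offset computations above is simply "round up to a multiple of 4", since Philox offsets advance in multiples of 4; a quick check with an illustrative helper name:

def round_up_to_multiple_of_4(numel: int) -> int:
    return (numel + 3) // 4 * 4

assert [round_up_to_multiple_of_4(n) for n in (1, 4, 5, 10)] == [4, 4, 8, 12]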
""" + equal_obj = self._equal_local_tensors() + if isinstance(equal_obj, torch.Tensor): + return equal_obj.tolist() + if isinstance(equal_obj, torch.Size): + if not self.dtype.is_floating_point and not self.dtype.is_complex: + ranks = sorted(self._ranks) + local_lists = [self._local_tensors[r].tolist() for r in ranks] + return _reduce_multidim_lists( + local_lists, + lambda values: torch.SymInt( + LocalIntNode(dict(zip(ranks, values, strict=True))) + ), + ) - return self.reconcile().tolist() + raise RuntimeError("Cannot convert local tensor to list") def reconcile(self) -> torch.Tensor: """ @@ -725,16 +1088,23 @@ def reconcile(self) -> torch.Tensor: """ # Force all local tensor shards across ranks to be the same - it = iter(self._local_tensors.values()) - t1 = next(it) - for t2 in it: - assert torch.equal(t1, t2), ( - "LocalTensor shards must be the same to reconcile" - ) - cl = t1.clone().detach() + equal_obj = self._equal_local_tensors() + assert isinstance(equal_obj, torch.Tensor), ( + "LocalTensor shards must be the same to reconcile" + ) + cl = equal_obj.clone().detach() cl.requires_grad_(self.requires_grad) return cl + def _equal_local_tensors(self) -> torch.Tensor | torch.Size | None: + it = iter(self._local_tensors.values()) + t1 = next(it) + if all(t2.equal(t1) for t2 in it): + return t1 + if all(t2.shape == t1.shape for t2 in it): + return t1.shape + return None + def _sync_meta(self) -> None: with no_dispatch(): (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( @@ -743,7 +1113,18 @@ def _sync_meta(self) -> None: self._size = shape -_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +_GLOBAL_LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] +# When running under local runner each thread must create its own local tensor mode +# so that they do not interfere with each other. 
+_THREAD_LOCAL_TENSOR_MODE: threading.local = threading.local() + + +def get_local_tensor_mode_list() -> list["LocalTensorMode"]: + if not hasattr(_THREAD_LOCAL_TENSOR_MODE, "value"): + _THREAD_LOCAL_TENSOR_MODE.value = [] + if len(_THREAD_LOCAL_TENSOR_MODE.value) > 0: + return _THREAD_LOCAL_TENSOR_MODE.value + return _GLOBAL_LOCAL_TENSOR_MODE class LocalTensorMode(TorchDispatchMode): @@ -774,11 +1155,27 @@ def __init__(self, ranks: Union[int, frozenset[int]]): self.ranks = ranks self._disable = False self._old_get_coordinate = None + self._old_torch_manual_seed: Any = None + self._old_torch_initial_seed: Any = None + self._per_rank_rng_states: dict[ + int, tuple[torch.Tensor, dict[int, torch.Tensor]] + ] = {} def __enter__(self) -> "LocalTensorMode": self._disable = False self._patch_device_mesh() - _LOCAL_TENSOR_MODE.append(self) + self._patch_random_functions() + get_local_tensor_mode_list().append(self) + + # _distribute_region will compute correct per-shard offsets + # but we want all ranks to start with the same state + if not _is_in_fake_tensor_mode(): + cpu_state, cuda_states = _get_rng_state() + for rank in self.ranks: + self._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) return super().__enter__() @@ -790,7 +1187,8 @@ def __exit__( ) -> None: self._disable = True self._unpatch_device_mesh() - _LOCAL_TENSOR_MODE.pop() + self._unpatch_random_functions() + get_local_tensor_mode_list().pop() super().__exit__(exc_type, exc_val, exc_tb) def __torch_dispatch__( @@ -855,6 +1253,10 @@ def __torch_dispatch__( return _c10d._local_all_gather_(*args, **kwargs) elif func is torch.ops.c10d.allgather_into_tensor_coalesced_.default: return _c10d._local_allgather_into_tensor_coalesced_(*args, **kwargs) + elif func is torch.ops.c10d._allgather_base_.default: + return _c10d._local_allgather_base_(*args, **kwargs) + elif func is torch.ops.c10d._reduce_scatter_base_.default: + return _c10d._local_reduce_scatter_base_(*args, **kwargs) elif func is torch.ops.c10d.gather_.default: return _c10d._local_gather_(*args, **kwargs) elif func is torch.ops.c10d.alltoall_.default: @@ -880,6 +1282,10 @@ def __torch_dispatch__( return _c10d._local_functional_all_gather_into_tensor(*args, **kwargs) elif func is torch.ops._c10d_functional.reduce_scatter_tensor.default: return _c10d._local_functional_reduce_scatter_tensor(*args, **kwargs) + elif func is torch.ops._c10d_functional.all_to_all_single.default: + return _c10d._local_functional_all_to_all_single(*args, **kwargs) + elif func is torch.ops._c10d_functional.wait_tensor.default: + return _c10d._local_functional_wait_tensor(*args, **kwargs) else: with LocalTensorMode(self.ranks): return func._op_dk( @@ -936,6 +1342,7 @@ def tensor_map( m = cb(r, tensor._local_tensors[r]) if m is not None: results[r] = m + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(results) def _patch_device_mesh(self) -> None: @@ -949,6 +1356,87 @@ def _unpatch_device_mesh(self) -> None: # pyrefly: ignore [bad-assignment] self._old_get_coordinate = None + def _patch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is None: + self._old_torch_manual_seed = torch.random.manual_seed + torch.random.manual_seed = _LocalRandom.torch_manual_seed + torch.manual_seed = _LocalRandom.torch_manual_seed + + if self._old_torch_initial_seed is None: + self._old_torch_initial_seed = torch.random.initial_seed 
+ torch.random.initial_seed = _LocalRandom.torch_initial_seed + torch.initial_seed = _LocalRandom.torch_initial_seed + + def _unpatch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is not None: + torch.random.manual_seed = self._old_torch_manual_seed + torch.manual_seed = self._old_torch_manual_seed + self._old_torch_manual_seed = None + + if self._old_torch_initial_seed is not None: + torch.random.initial_seed = self._old_torch_initial_seed + torch.initial_seed = self._old_torch_initial_seed + self._old_torch_initial_seed = None + + +class _LocalRandom: + """ + Holds implementations of random functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def torch_manual_seed(seed) -> torch._C.Generator: + """LocalTensor-aware version of torch.random.manual_seed.""" + if ( + (lm := enabled_local_tensor_mode()) + and isinstance(seed, torch.SymInt) + and isinstance(seed.node, LocalIntNode) + ): + from torch.random import _manual_seed_impl + + for rank in sorted(lm.ranks): + rank_seed = seed.node._local_ints[rank] + _manual_seed_impl(rank_seed, update_local_tensor_states=False) + lm._per_rank_rng_states[rank] = _get_rng_state() + return torch.random.default_generator + from torch.random import _manual_seed_impl + + result = _manual_seed_impl(seed, update_local_tensor_states=False) + + if lm is not None and len(lm._per_rank_rng_states) > 0: + cpu_state, cuda_states = _get_rng_state() + for rank in lm.ranks: + lm._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return result + + @staticmethod + def torch_initial_seed(): + """LocalTensor-aware version of torch.random.initial_seed.""" + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + return torch.random.default_generator.initial_seed() + rank_seeds = {} + + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = torch.random.default_generator.initial_seed() + + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + + return torch.random.default_generator.initial_seed() + class _LocalDeviceMesh: """ @@ -963,7 +1451,7 @@ def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]: # doing this because when submesh is created it is created for a particular # rank (therefore below we are patching get_rank method). We are trying to # limit the invasiveness of local tensor. - lm = local_tensor_mode() + lm = enabled_local_tensor_mode() assert lm is not None, "Unexpectedly not in LocalTensorMode" coords: list[dict[int, int]] = [{} for _ in range(self.ndim)] @@ -1019,8 +1507,25 @@ def local_tensor_mode() -> Optional[LocalTensorMode]: Returns: Optional[LocalTensorMode]: The current LocalTensorMode if active, else None. """ - if len(_LOCAL_TENSOR_MODE) > 0: - return _LOCAL_TENSOR_MODE[-1] + local_tensor_mode_list = get_local_tensor_mode_list() + if len(local_tensor_mode_list) > 0: + return local_tensor_mode_list[-1] + return None + + +def enabled_local_tensor_mode() -> Optional[LocalTensorMode]: + """ + Returns the current active LocalTensorMode only if it's enabled. + + This is a convenience function that combines the common pattern of checking + if local_tensor_mode() is not None and not disabled. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active and enabled, else None. 
+ """ + lm = local_tensor_mode() + if lm is not None and not lm._disable: + return lm return None @@ -1048,8 +1553,7 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] - lm = local_tensor_mode() - if lm is None or lm._disable: + if not (lm := enabled_local_tensor_mode()): return func(*args, **kwargs) ret = None with lm.disable(): @@ -1068,6 +1572,73 @@ def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: return lm.disable() if lm is not None else contextlib.nullcontext() +def maybe_enable_local_tracker( + device_type: str, distribute_region_enabled: bool, spec, generator +): + """ + Returns a context manager for LocalTensor-mode RNG tracking if local tensor mode is enabled. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + distribute_region_enabled: Whether distribute region is enabled + spec: The DTensorSpec + generator: Optional torch.Generator + + Returns: + Context manager from local_tracker._distribute_region if local tensor mode is enabled, + otherwise None. + """ + if enabled_local_tensor_mode(): + local_tracker = _LocalOffsetBasedRNGTracker(device_type) + local_tracker.distribute_region_enabled = distribute_region_enabled + return local_tracker._distribute_region(spec, generator) + + return None + + +def get_generator_seed_for_device_type(device_type: str): + """ + Gets the generator seed for a specific device type, handling LocalTensor mode appropriately. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + + Returns: + If in LocalTensor mode with per-rank RNG states: + - Returns int if all ranks have the same seed + - Returns SymInt(LocalIntNode) if ranks have different seeds + Otherwise: + - Returns int seed from the device's RNG state + """ + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + device_module = torch.get_device_module(device_type) + + original_state = _get_rng_state() + + rank_seeds = {} + try: + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = int( + device_module.get_rng_state()[:8].view(torch.int64).item() + ) + finally: + # restore original state + _set_rng_state(*original_state) + + unique_seeds = set(rank_seeds.values()) + if len(unique_seeds) == 1: + return next(iter(unique_seeds)) + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + else: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + + import threading from queue import Queue @@ -1158,7 +1729,6 @@ def _get_recv_object(self, src: int, dst: int) -> object | None: def _signal_send(self, src: int, dst: int, obj: object) -> None: assert obj is not None, "Cannot signal None" - self._assert_holds_run_lock() # Only a single thread a time executes so it is safe to mutate # read objects queue (executing thread is already holding the lock) self._recv_objects[dst][src].put(obj) @@ -1167,7 +1737,6 @@ def _signal_send(self, src: int, dst: int, obj: object) -> None: self._run_cond.notify_all() def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: - self._assert_holds_run_lock() # Wait for the object to be available while True: obj = self._get_recv_object(src, dst) @@ -1183,3 +1752,114 @@ def current() -> "LocalRunnerMode": 
global _LOCAL_RUNNER_MODE assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" return _LOCAL_RUNNER_MODE + + +class _LocalPhiloxState: + """ + LocalTensor-aware version of _PhiloxState that manages per-rank RNG states. + This class handles the case where the generator state is a LocalTensor, allowing + different offsets and seeds for different virtual ranks. + + Note: This is designed to be used as a drop-in replacement for _PhiloxState + when working with LocalTensors in the DTensor random ops implementation. + """ + + def __init__(self, state: torch.Tensor): + assert isinstance(state, LocalTensor), ( + "_LocalPhiloxState requires a LocalTensor" + ) + self._local_tensor = state + self._per_rank_states = { + rank: local_state.to("cpu") + for rank, local_state in state._local_tensors.items() + } + + @property + def state(self): + return LocalTensor(self._per_rank_states) # type: ignore[name-defined] + + @property + def offset(self) -> Union[int, SymInt]: + from torch.distributed.tensor._random import _PhiloxState + + offsets = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + offsets[rank] = rank_philox.offset + + if len(set(offsets.values())) == 1: + return next(iter(offsets.values())) + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return SymInt(LocalIntNode(offsets)) + + @offset.setter + def offset(self, offset: Union[int, SymInt]) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(offset, SymInt) and isinstance(offset.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_offset = offset.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.offset = rank_offset + else: + offset_int = int(offset) if isinstance(offset, SymInt) else offset + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.offset = offset_int + + @property + def seed(self) -> Union[int, SymInt]: + from torch.distributed.tensor._random import _PhiloxState + + seeds = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + seeds[rank] = rank_philox.seed + + if len(set(seeds.values())) == 1: + return next(iter(seeds.values())) + return SymInt(LocalIntNode(seeds)) + + @seed.setter + def seed(self, seed: Union[int, SymInt]) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(seed, SymInt) and isinstance(seed.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_seed = seed.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.seed = rank_seed + else: + seed_int = int(seed) if isinstance(seed, SymInt) else seed + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.seed = seed_int + + def apply_to_local_tensor_mode(self, device_handle) -> None: + """ + Apply per-rank RNG states to the LocalTensorMode's tracked states. + This updates both the device RNG state and the LocalTensorMode's _per_rank_rng_states. 
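The byte layout assumed by the state slicing in this file (seed in the first 8 bytes of a Philox generator state, offset in the next 8, each readable through a 64-bit integer view) can be checked on a fabricated 16-byte state; this is a sketch under that assumed layout, not an inspection of a real generator:

import torch

state = torch.zeros(16, dtype=torch.uint8)
state[:8] = torch.tensor([1234], dtype=torch.int64).view(torch.uint8)
state[8:] = torch.tensor([40], dtype=torch.int64).view(torch.uint8)

assert int(state[:8].view(torch.int64).item()) == 1234   # seed
assert int(state[8:].view(torch.int64).item()) == 40     # offset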
+ + Args: + device_handle: The device handle to use for setting RNG state (_LocalDeviceHandle) + """ + if not enabled_local_tensor_mode(): + return + + assert hasattr(self, "_per_rank_offsets") + + for rank in sorted(self._per_rank_states.keys()): + offset_value = self._per_rank_offsets[rank] + if isinstance(offset_value, SymInt): + if isinstance(offset_value.node, LocalIntNode): + offset_value = offset_value.node._local_ints[rank] + else: + offset_value = int(offset_value) + + offset_tensor = torch.tensor( + [offset_value], dtype=torch.uint64, device="cpu" + ).view(torch.uint8) + self._per_rank_states[rank][8:] = offset_tensor + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + device_handle.set_rng_state(LocalTensor(self._per_rank_states)) diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 0b63330dfafce..a6a8c41103c9f 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -216,6 +216,69 @@ def _local_functional_shard_dim_alltoall( return output +def _local_functional_all_to_all_single( + tensor: torch.Tensor, + output_split_sizes: list[torch.SymInt], + input_split_sizes: list[torch.SymInt], + group_name: str, +) -> torch.Tensor: + # "all_to_all_single(Tensor input, SymInt[] output_split_sizes, SymInt[] input_split_sizes, str group_name) -> Tensor" + from . import LocalIntNode, LocalTensor + + ranks, group_offsets, offset = _prepare_collective_groups( + _resolve_process_group(group_name) + ) + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + split_local_sizes: dict[int, list[int]] = {} + for input_split_size in input_split_sizes: + if isinstance(input_split_size, torch.SymInt) and isinstance( + input_split_size.node, LocalIntNode + ): + local_ints = dict(input_split_size.node._local_ints.items()) + else: + local_ints = { + rank: int(input_split_size) for rank in tensor._local_tensors.keys() + } + for rank, split_size in local_ints.items(): + if rank not in split_local_sizes: + split_local_sizes[rank] = [] + split_local_sizes[rank].append(split_size) + + split_local_tensors: dict[int, list[torch.Tensor]] = {} + + for rank, split_sizes in split_local_sizes.items(): + split_local_tensors[rank] = list( + torch.split(tensor._local_tensors[rank], split_sizes) + ) + + output_local_tensors: dict[int, torch.Tensor] = {} + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + for i, dst in enumerate(group_ranks): + splits = [] + for j, src in enumerate(group_ranks): + splits.append(split_local_tensors[src][i]) + output_local_tensors[dst] = torch.cat(splits) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + output = LocalTensor(output_local_tensors) + + return output + + +def _local_functional_wait_tensor(tensor: torch.Tensor) -> torch.Tensor: + # "wait_tensor(Tensor input) -> Tensor" + from . 
import LocalTensor + + assert isinstance(tensor, LocalTensor), "Input tensor must be a LocalTensor" + + return tensor + + def _local_broadcast_( tensors: list[torch.Tensor], process_group_so: ScriptObject, @@ -423,6 +486,82 @@ def _local_reduce_scatter_tensor_coalesced_( return work_so +def _local_allgather_base_( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup + # process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + from . import LocalTensor + + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + gathered_tensor = torch.cat(gathered_tensors, dim=0) + + for rank_i in group_ranks: + output_tensor._local_tensors[rank_i].copy_(gathered_tensor) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + +def _local_reduce_scatter_base_( # type: ignore[no-untyped-def] + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + process_group_so: ScriptObject, + reduce_op_so: ScriptObject, + async_op: bool = True, + timeout: int = -1, +) -> tuple[torch.Tensor, ScriptObject]: + # "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, + # __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, + # bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)" + + from . 
import LocalTensor + + reduce_op = reduce_op_so.op() # type: ignore[attr-defined] + ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so) + + assert isinstance(output_tensor, LocalTensor), "Output tensor must be a LocalTensor" + assert isinstance(input_tensor, LocalTensor), "Input tensor must be a LocalTensor" + + for group_offset in group_offsets: + group_ranks = [group_offset + r for r in ranks] + + gathered_tensors = [] + for rank_i in group_ranks: + gathered_tensors.append(input_tensor._local_tensors[rank_i]) + + reduced_tensor = _local_reduce(reduce_op, gathered_tensors) + + scattered_tensor = torch.split( + reduced_tensor, + reduced_tensor.size(0) // len(group_ranks), + dim=0, + ) + + for i, rank_i in enumerate(group_ranks): + output_tensor._local_tensors[rank_i].copy_(scattered_tensor[i].clone()) + + work = FakeWork() + work_so = Work.boxed(work) + return output_tensor, work_so + + def _local_all_gather_( output_tensors: list[list[torch.Tensor]], input_tensors: list[torch.Tensor], diff --git a/torch/distributed/_shard/sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py index 26788f4054bce..9825edd352c1f 100644 --- a/torch/distributed/_shard/sharding_spec/_internals.py +++ b/torch/distributed/_shard/sharding_spec/_internals.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-defs import math +import sys +from bisect import bisect_right, insort from typing import Optional from torch.distributed._shard.metadata import ShardMetadata @@ -27,31 +29,48 @@ def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetad def _find_nd_overlapping_shards( shards: list[ShardMetadata], sharded_dims: list[int] ) -> Optional[tuple[int, int]]: - # Each rank has len(sharded_dims) tuples. Each tuple represent the - # [begin, end] (inclusive) pair of that dimension. - shard_intervals = [ - [ - (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1) - for dim in sharded_dims - ] - for s in shards - ] - - for i in range(len(shards)): - shard_i = shard_intervals[i] - for j in range(i + 1, len(shards)): - shard_j = shard_intervals[j] - # For each dim of each shard, check if one shard resides on the other - # end of second shard with respect to that dim. As an example for a 2D - # shard, we would check if one shard is above or on the left of the - # other shard. 
- overlap = True - for interval_i, interval_j in zip(shard_i, shard_j): - if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]: - overlap = False - break - if overlap: - return (i, j) + """Find overlapping shards using sweep-line algorithm.""" + if len(shards) <= 1: + return None + + dims = len(sharded_dims) + if dims == 0: + return None + + sweep_dim_idx = 0 + if dims > 1: + max_size = 0 + for i, dim in enumerate(sharded_dims): + dim_size = shards[0].shard_offsets[dim] + shards[0].shard_sizes[dim] + if dim_size > max_size: + max_size = dim_size + sweep_dim_idx = i + sweep_dim = sharded_dims[sweep_dim_idx] + + sorted_indices = sorted( + range(len(shards)), + key=lambda idx: ( + shards[idx].shard_offsets[sweep_dim], + *(shards[idx].shard_offsets[d] for d in sharded_dims if d != sweep_dim), + ), + ) + active: list[tuple[int, int]] = [] + + for idx in sorted_indices: + current = shards[idx] + start = current.shard_offsets[sweep_dim] + end = start + current.shard_sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = shards[other_idx] + + if _check_shard_metadata_pair_overlap(current, other): + return (other_idx, idx) + insort(active, (end, idx)) return None @@ -112,10 +131,8 @@ def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): # using a O(nlogn) overlapping interval algorithm. pair = _find_1d_overlapping_shards(shards, sharded_dims[0]) else: - # Shards are partitioned over more than one dimension. Fall back to - # pair-wise check. Even though O(nlogn) algorithms (line sweep) exist - # for 2D overlap, the implementation is not trivial and may not justify - # the time saving in most cases. + # Shards are partitioned over more than one dimension. + # Use sweep-line algorithm for O(n log n) complexity. 
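The sweep is easiest to see in one dimension (a sketch with half-open intervals; the helper name is illustrative): sort by start, keep an active list keyed by end offset, evict everything that ends at or before the new start, and any interval still active overlaps the new one.

import sys
from bisect import bisect_right, insort

def first_overlap(intervals):
    """intervals: list of (start, end) with exclusive end; returns an overlapping index pair or None."""
    order = sorted(range(len(intervals)), key=lambda i: intervals[i][0])
    active: list[tuple[int, int]] = []                            # sorted (end, index)
    for idx in order:
        start, end = intervals[idx]
        del active[: bisect_right(active, (start, sys.maxsize))]  # drop finished intervals
        if active:                                                # survivors start <= start and end > start
            return (active[0][1], idx)
        insort(active, (end, idx))
    return None

assert first_overlap([(0, 4), (4, 8), (2, 6)]) == (0, 2)
assert first_overlap([(0, 4), (4, 8)]) is None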
pair = _find_nd_overlapping_shards(shards, sharded_dims) if pair: diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index b897e51cac9f3..caf399cf6a802 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -1,6 +1,4 @@ # Owner(s): ["module: unknown"] -import math -import os from collections import defaultdict from typing import Any, TYPE_CHECKING from typing_extensions import Self @@ -8,75 +6,22 @@ import torch import torch.utils._pytree as pytree from torch._guards import active_fake_mode -from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps from torch._subclasses.fake_tensor import FakeTensorMode from torch.distributed._tools.mod_tracker import ModTracker from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils.flop_counter import flop_registry +from torch.utils._runtime_estimation import ( + _FLOAT_TYPES, + _IGNORE_OPS, + _VIEW_OPS, + get_compute_time, + get_transfer_time, +) if TYPE_CHECKING: from collections.abc import Callable - -aten = torch.ops.aten - -# This value is hard-coded here: -# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 -_PYTORCH_MIN_ALLOCATE = ( - 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 -) - -# No fall-back kernel needed/exists for view ops -_VIEW_OPS = { - aten.lift_fresh, - aten.t, - aten.transpose, - aten.view, - aten.detach, - aten._unsafe_view, - aten.split, - aten.adjoint, - aten.as_strided, - aten.diagonal, - aten.expand, - aten.expand_as, - aten.movedim, - aten.permute, - aten.select, - aten.squeeze, - aten.mT, - aten.mH, - aten.real, - aten.imag, - aten.view_as, - aten.unflatten, - aten.unfold, - aten.unbind, - aten.unsqueeze, - aten.vsplit, - aten.hsplit, - aten.split_with_sizes, - aten.swapaxes, - aten.swapdims, - aten.chunk, -} -# We can ignore benchmarking tensor create ops -_CREATE_OPS = { - aten.randint, - aten.randn, - aten.rand, - aten.randn_like, - aten.rand_like, - aten.randint_like, - aten.arange, - aten.ones_like, - aten.zeros_like, -} - -_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS - __all__ = ["RuntimeEstimator"] @@ -125,12 +70,6 @@ class RuntimeEstimator(TorchDispatchMode): runtime_estimator.display_modulewise_stats() """ - _float_types: set[torch.dtype] = { - torch.float16, - torch.bfloat16, - torch.float32, - torch.float64, - } _no_fallback_kernel: set[torch._ops._OpNamespace] = set() fake_mode: FakeTensorMode @@ -186,7 +125,7 @@ def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] def to_real_tensor(e): # type: ignore[no-untyped-def] if cls.fake_mode.is_our_fake(e): - if e.dtype in cls._float_types: + if e.dtype in _FLOAT_TYPES: out = torch.rand_like(e, device=e.fake_device) else: out = torch.ones_like(e, device=e.fake_device) @@ -297,79 +236,6 @@ def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: i "Roofline estimation needs to access CUDA capabilities to make estimations" ) - def get_num_bytes(t: torch.Tensor) -> int: - """ - Calculates the memory consumption of a tensor. - - Args: - t (torch.Tensor): The input tensor. - - Returns: - int: The memory consumption of the tensor in bytes. 
- """ - num_bytes = t.untyped_storage().nbytes() - mem_consumed = ( - math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE - ) - return mem_consumed - - def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] - """ - Estimates the compute time of an aten operator. - - Args: - func_packet: The operator overload packet. - args: The arguments to the operator. - kwargs: The keyword arguments to the operator. - out: The output of the operator. - out_dtypes: The output data types. - - Returns: - float: The estimated compute time in nanoseconds. - """ - if func_packet in flop_registry: - assert len(out_dtypes) == 1, ( - f"Only support single out dtype got {out_dtypes} for {func_packet}" - ) - dtype = out_dtypes.pop() - # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s - peak_gpu_flops = get_device_tflops(dtype) * 1e15 - # We can expect to achieve 75% of theoretical peak flops - factor = 0.75 - peak_empirical_flops = factor * peak_gpu_flops - flop_count_func = flop_registry[func_packet] - # We divide by a factor of 2 to get the MACs (multiply and accumulate) - flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 - # We multiply by 1e9 to get the time in nano seconds - compute_time = (flop_count / peak_empirical_flops) * 1e9 - return compute_time - return 0.0 - - def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] - """ - Estimates the memory transfer time of input and output tensors. - - Args: - flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. - flat_outs (List[torch.Tensor]): The flat list of outputs. - - Returns: - float: The estimated memory transfer time in nanoseconds. - """ - gpu_memory_bandwidth = get_gpu_dram_gbps() - read_bytes = sum( - get_num_bytes(t) - for t in flat_args_kwargs - if isinstance(t, torch.Tensor) - ) - write_bytes = sum( - get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) - ) - counted_bytes = read_bytes + write_bytes - # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds - transfer_time = counted_bytes / gpu_memory_bandwidth - return transfer_time - # Roofline Cost Model Explanation # The roofline cost model estimates the execution time of an operator based on @@ -406,7 +272,7 @@ def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no- out_dtypes = { t.dtype for t in flat_outs - if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + if isinstance(t, torch.Tensor) and t.dtype in _FLOAT_TYPES } args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) diff --git a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py index 3ce067f6cddc0..eae76e8cc72af 100644 --- a/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py +++ b/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py @@ -103,9 +103,6 @@ def _pre_load_state_dict_hook( class OffloadWrapper(ActivationWrapper): - def __init__(self, mod): - super().__init__(mod) - def forward(self, *args, **kwargs): with save_on_cpu(pin_memory=True): return self._checkpoint_wrapped_module(*args, **kwargs) diff --git a/torch/distributed/debug/__init__.py b/torch/distributed/debug/__init__.py new file mode 100644 index 0000000000000..46267a686e86d --- /dev/null +++ b/torch/distributed/debug/__init__.py @@ -0,0 +1,82 @@ +import logging +import multiprocessing +import socket + +# import 
for registration side effect +import torch.distributed.debug._handlers # noqa: F401 +from torch._C._distributed_c10d import _WorkerServer +from torch.distributed.debug._store import get_rank, tcpstore_client + + +__all__ = [ + "start_debug_server", + "stop_debug_server", +] + +logger: logging.Logger = logging.getLogger(__name__) + +_WORKER_SERVER: _WorkerServer | None = None +_DEBUG_SERVER_PROC: multiprocessing.Process | None = None + + +def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: + """ + Start the debug server stack on all workers. The frontend debug server is + only started on rank0 while the per rank worker servers are started on all + ranks. + + This server provides an HTTP frontend that allows for debugging slow and + deadlocked distributed jobs across all ranks simultaneously. This collects + data such as stack traces, FlightRecorder events, and performance profiles. + + WARNING: This is intended to only be used in trusted network environments. + The debug server is not designed to be secure and should not be exposed to + the public internet. See SECURITY.md for more details. + + WARNING: This is an experimental feature and may change at any time. + + Args: + port (int): The port to start the frontend debug server on. + worker_port (int): The port to start the worker server on. Defaults to 0, which + will cause the worker server to bind to an ephemeral port. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _WORKER_SERVER is None, "debug server already started" + assert _DEBUG_SERVER_PROC is None, "debug server already started" + + logger.info("Starting debug server on port %d", port) + + store = tcpstore_client() + + _WORKER_SERVER = _WorkerServer("::", worker_port) + + RANK = get_rank() + store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}") + + from torch.distributed.debug._frontend import main + + if RANK == 0: + _DEBUG_SERVER_PROC = multiprocessing.Process( + target=main, args=(port,), daemon=True + ) + _DEBUG_SERVER_PROC.start() + + +def stop_debug_server() -> None: + """ + Shutdown the debug server and stop the frontend debug server process. 
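A hypothetical single-process smoke test for the servers above (assumes `requests` and `Jinja2` are installed, as the frontend requires, and relies on the env-var based store described in `_store.py`):

import os
import torch.distributed as dist
from torch.distributed.debug import start_debug_server, stop_debug_server

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group("gloo", rank=0, world_size=1)  # rank 0 hosts the TCPStore
start_debug_server(port=25999)   # frontend at http://localhost:25999, worker on an ephemeral port
# ... run the workload; /stacks, /fr_trace and /profile are served while it runs ...
stop_debug_server()
dist.destroy_process_group()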
+ """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _DEBUG_SERVER_PROC is not None + assert _WORKER_SERVER is not None + + logger.info("Stopping debug server") + + _DEBUG_SERVER_PROC.terminate() + _WORKER_SERVER.shutdown() + _DEBUG_SERVER_PROC.join() + + _WORKER_SERVER = None + _DEBUG_SERVER_PROC = None diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py new file mode 100644 index 0000000000000..10dae4c2802cd --- /dev/null +++ b/torch/distributed/debug/_frontend.py @@ -0,0 +1,361 @@ +import json +import logging +import socket +import threading +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import parse_qs, urlparse + +import requests +from jinja2 import DictLoader, Environment + +from torch.distributed.debug._store import get_world_size, tcpstore_client + + +logger: logging.Logger = logging.getLogger(__name__) + + +def fetch_all( + endpoint: str, args: str = "" +) -> tuple[list[str], Iterator[requests.Response]]: + store = tcpstore_client() + keys = [f"rank{r}" for r in range(get_world_size())] + addrs = store.multi_get(keys) + addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] + + with ThreadPoolExecutor(max_workers=10) as executor: + resps = executor.map(requests.post, addrs) + + return addrs, resps + + +def format_json(blob: str): + parsed = json.loads(blob) + return json.dumps(parsed, indent=2) + + +templates = { + "base.html": """ + + + {% block title %}{% endblock %} - PyTorch Distributed + + + + + + + +
+    {% block header %}{% endblock %}
+    {% block content %}{% endblock %}
+    """,
+    "index.html": """
+{% extends "base.html" %}
+{% block header %}
+    {% block title %}Index{% endblock %}
+{% endblock %}
+{% block content %}
+Hi
+{% endblock %}
+    """,
+    "raw_resp.html": """
+{% extends "base.html" %}
+{% block header %}
+    {% block title %}{{title}}{% endblock %}
+{% endblock %}
+{% block content %}
+    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
+        Rank {{ i }}: {{ addr }}
+        {% if resp.status_code != 200 %}
+            Failed to fetch: status={{ resp.status_code }}
+            {{ resp.text }}
+        {% else %}
+            {{ resp.text }}
+        {% endif %}
+    {% endfor %}
+{% endblock %}
+    """,
+    "json_resp.html": """
+{% extends "base.html" %}
+{% block header %}
+    {% block title %}{{ title }}{% endblock %}
+{% endblock %}
+{% block content %}
+    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
+        Rank {{ i }}: {{ addr }}
+        {% if resp.status_code != 200 %}
+            Failed to fetch: status={{ resp.status_code }}
+            {{ resp.text }}
+        {% else %}
+            {{ format_json(resp.text) }}
+        {% endif %}
+    {% endfor %}
+{% endblock %}
+    """,
+    "profile.html": """
+{% extends "base.html" %}
+{% block header %}
+    {% block title %}torch.profiler{% endblock %}
+{% endblock %}
+
+{% block content %}
+    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
+        Rank {{ i }}: {{ addr }}
+        {% if resp.status_code != 200 %}
+            Failed to fetch: status={{ resp.status_code }}
+            {{ resp.text }}
+ {% else %} + + + + {% endif %} + {% endfor %} +{% endblock %} + """, +} + + +class _IPv6HTTPServer(ThreadingHTTPServer): + address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore + request_queue_size: int = 1024 + + +class HTTPRequestHandler(BaseHTTPRequestHandler): + frontend: "FrontendServer" + + def do_GET(self): + self.frontend._handle_request(self) + + def get_path(self) -> str: + return urlparse(self.path).path + + def get_query(self) -> dict[str, list[str]]: + return parse_qs(urlparse(self.path).query) + + def get_query_arg( + self, name: str, default: object = None, type: type = str + ) -> object: + query = self.get_query() + if name not in query: + return default + return type(query[name][0]) + + +class FrontendServer: + def __init__(self, port: int): + # Setup templates + loader = DictLoader(templates) + self._jinja_env = Environment(loader=loader, enable_async=True) + self._jinja_env.globals.update( + zip=zip, + format_json=format_json, + enumerate=enumerate, + ) + + # Create routes + self._routes = { + "/": self._handle_index, + "/stacks": self._handle_stacks, + "/fr_trace": self._handle_fr_trace, + "/fr_trace_nccl": self._handle_fr_trace_nccl, + "/profile": self._handle_profiler, + "/wait_counters": self._handle_wait_counters, + } + + # Create HTTP server + RequestHandlerClass = type( + "HTTPRequestHandler", + (HTTPRequestHandler,), + {"frontend": self}, + ) + + server_address = ("", port) + self._server = _IPv6HTTPServer(server_address, RequestHandlerClass) + + self._thread = threading.Thread( + target=self._serve, + args=(), + daemon=True, + ) + self._thread.start() + + def _serve(self) -> None: + try: + self._server.serve_forever() + except Exception: + logger.exception("got exception in checkpoint server") + + def join(self) -> None: + self._thread.join() + + def _handle_request(self, req: HTTPRequestHandler) -> None: + path = req.get_path() + if path not in self._routes: + req.send_error(404, f"Handler not found: {path}") + return + + handler = self._routes[path] + try: + resp = handler(req) + except Exception as e: + logger.exception( + "Exception in checkpoint server when handling %s", + path, + ) + req.send_error(500, str(e)) + return + + req.send_response(200) + req.send_header("Content-type", "text/html") + req.end_headers() + req.wfile.write(resp) + + def _render_template(self, template: str, **kwargs: object) -> bytes: + return self._jinja_env.get_template(template).render(**kwargs).encode() + + def _handle_index(self, req: HTTPRequestHandler) -> bytes: + return self._render_template("index.html") + + def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_traceback") + return self._render_template( + "raw_resp.html", title="Stacks", addrs=addrs, resps=resps + ) + + def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_template( + "json_resp.html", + title="FlightRecorder", + addrs=addrs, + resps=resps, + ) + + def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_template( + "json_resp.html", + title="FlightRecorder NCCL", + addrs=addrs, + resps=resps, + ) + + def _handle_profiler(self, req: HTTPRequestHandler) -> bytes: + duration = req.get_query_arg("duration", default=1.0, type=float) + + addrs, resps = fetch_all("torch_profile", f"duration={duration}") + + return self._render_template("profile.html", addrs=addrs, resps=resps) + + 
def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("wait_counter_values") + return self._render_template( + "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps + ) + + +def main(port: int) -> None: + server = FrontendServer(port=port) + logger.info("Frontend server started on port %d", server._server.server_port) + server.join() diff --git a/torch/distributed/debug/_handlers.py b/torch/distributed/debug/_handlers.py new file mode 100644 index 0000000000000..ba951b7bda075 --- /dev/null +++ b/torch/distributed/debug/_handlers.py @@ -0,0 +1,22 @@ +import tempfile +import time + +from torch._C._distributed_c10d import _register_handler, _Request, _Response +from torch.profiler import _ExperimentalConfig, profile + + +def _torch_profile(req: _Request, resp: _Response) -> None: + experimental_config = _ExperimentalConfig( + profile_all_threads=True, + ) + duration = float(req.get_param("duration")) + with profile(record_shapes=True, experimental_config=experimental_config) as prof: + time.sleep(duration) + + with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: + prof.export_chrome_trace(f.name) + resp.set_content(open(f.name, "rb").read(), "application/json") + resp.set_status(200) + + +_register_handler("torch_profile", _torch_profile) diff --git a/torch/distributed/debug/_store.py b/torch/distributed/debug/_store.py new file mode 100644 index 0000000000000..70c6cd0f3dde1 --- /dev/null +++ b/torch/distributed/debug/_store.py @@ -0,0 +1,24 @@ +import os + +import torch.distributed as dist + + +def get_rank() -> int: + return int(os.environ["RANK"]) + + +def get_world_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def tcpstore_client() -> dist.Store: + MASTER_ADDR = os.environ["MASTER_ADDR"] + MASTER_PORT = int(os.environ["MASTER_PORT"]) + + store = dist.TCPStore( + host_name=MASTER_ADDR, + port=MASTER_PORT, + is_master=False, + ) + store = dist.PrefixStore("debug_server", store) + return store diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 794b755b1f64d..2bd7d24cd7d3f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -547,8 +547,12 @@ def foreach_reduce( op=reduce_scatter_op, ) else: - # For single GPU, just copy the input to output (no actual reduce-scatter needed) - reduce_output.copy_(reduce_scatter_input) + # For single GPU, just copy the input to output (no actual reduce-scatter needed), and + # account for a possible gradient_divide_factor. + if gradient_divide_factor is not None: + reduce_output.copy_(reduce_scatter_input / gradient_divide_factor) + else: + reduce_output.copy_(reduce_scatter_input) reduce_scatter_event = reduce_scatter_stream.record_event() post_reduce_stream = reduce_scatter_stream if all_reduce_group is not None: # HSDP or DDP/replicate @@ -721,20 +725,21 @@ def _get_gradient_divide_factors( if all_reduce_group is not None: data_parallel_size *= all_reduce_group.size() - if factor is None: - factor = float(data_parallel_size) - if not overflow_risk and not force_sum_reduction_for_comms: - if factor == data_parallel_size: + if factor is None: # Warning: NCCL ReduceOp.AVG may produce incorrect results with # world size 1. 
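The op selection above relies on averaging being equivalent to a pre-scaled sum when the divide factor equals the group size; a quick numeric check with made-up values:

import torch

world_size = 4
grads = [torch.full((2,), float(r + 1)) for r in range(world_size)]
avg = sum(grads) / world_size                              # ReduceOp.AVG semantics
premul_sum = sum(g * (1.0 / world_size) for g in grads)    # pre-multiplied sum semantics
assert torch.allclose(avg, premul_sum)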
if data_parallel_size == 1: return None, None, ReduceOp.SUM, ReduceOp.SUM return None, None, ReduceOp.AVG, ReduceOp.AVG + if reduce_scatter_group is not None and factor == reduce_scatter_group.size(): + reduce_scatter_op = ReduceOp.AVG else: reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) - return None, None, reduce_scatter_op, ReduceOp.SUM + return None, None, reduce_scatter_op, ReduceOp.SUM + if factor is None: + factor = float(data_parallel_size) pre_factor: Optional[float] if overflow_risk: # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index abc007a8166db..7bdf3c65e4e8f 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -2272,9 +2272,7 @@ def _perform_action(action: _Action) -> None: time_step, action, ) - # TODO(whc) what is the best practice for printing a multiline log? - # logger will split it into multiple log lines, but this makes it hard to read (too wide) - print( + logger.error( _format_pipeline_order( self.pipeline_order_with_comms, # type: ignore[arg-type] error_step_number=time_step, diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index f21ef72533658..dabf9f6f194ce 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -349,6 +349,9 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] # We just need to have an implementation here; the __torch_dispatch__ machinery # calls into a specific C++ fast path that doesn't call here. + # See #167051 for details + # python_arg_parser.cpp: dispatch_on_subclass() + # -> python_variable.cpp: dispatchDTensorOp() raise NotImplementedError( "DTensor.__torch_dispatch__ should not actually get called" ) @@ -818,6 +821,11 @@ def distribute_tensor( local_tensor = Replicate._make_replicate_tensor( local_tensor, device_mesh, idx, src_data_rank ) + elif isinstance(placement, Partial): + local_tensor = Replicate._make_replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + local_tensor = placement._partition_value(local_tensor, device_mesh, idx) else: raise RuntimeError( f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" @@ -1063,7 +1071,7 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] # get local tensor shape local_shape, _ = compute_local_shape_and_global_offset( - size, device_mesh, placements + size, device_mesh, placements, skip_offset=True ) # initialize the local tensor diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index dff426a6d5e5a..90f32efafd395 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -227,6 +227,7 @@ def check_tensor_meta( return None +# TODO: autoparallel depends on this function, we will keep it until we update autoparallel redistribute_cost def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: assert spec.tensor_meta is not None, "spec should have tensor meta defined!" 
return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) @@ -338,39 +339,61 @@ def redistribute_cost( mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) cost = 0.0 - comm_bytes_gb = ( - spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 - ) # Transformation that considered for redistribute cost: # 1. allgather 2. alltoall # 3. allreduce 4. reduce_scatter - for i, (current, target) in enumerate( - zip(current_spec.placements, target_spec.placements) - ): + from torch.distributed._functional_collectives import _are_we_tracing + from torch.distributed.tensor._redistribute import ( + _gen_transform_infos, + _gen_transform_infos_non_cached, + ) + + # No redistribution needed when placements are already identical. + # This also prevents potential failures in _gen_transform_infos for certain configurations + # (e.g., sub-meshes) where finding a transform path between identical states may error out. + # TODO(zpcore): test placements with _StridedShard. + if current_spec.placements == target_spec.placements: + return cost + if _are_we_tracing(): + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) + for transform_info in transform_infos: + assert current_spec.tensor_meta is not None, ( + "spec should have tensor meta defined!" + ) + comm_bytes_gb = ( + current_spec.tensor_meta.dtype.itemsize + * math.prod(transform_info.logical_shape) + / 1024 + / 1024 + / 1024 + ) + current = transform_info.src_dst_placements[0] + target = transform_info.src_dst_placements[1] if current == target: continue - - num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] + mesh_dim = transform_info.mesh_dim + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] if current.is_shard() and target.is_replicate(): - # allgather gives larger comm bytes - comm_bytes_gb *= num_devices_on_mesh_dim # add up allgather comm cost - cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) elif current.is_shard() and target.is_shard(): - # should be alltoall comm, since we haven't implement it yet, add penalty + # should be alltoall comm, since we haven't implement it yet, add 1.0 as penalty # to favor allgather instead - cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0 + # TODO: add alltoall_cost + comm_bytes_gb /= num_devices_on_mesh_dim + cost += allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) + 1.0 elif current.is_partial() and target.is_replicate(): # add up allreduce comm cost - cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) + cost += allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim) elif current.is_partial() and target.is_shard(): # add up reduce_scatter comm cost - cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim) # after reduce_scatter the comm bytes for further collectives halved. 
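For scale, the per-transform byte accounting above works out as follows (shape and dtype are made-up examples):

import math
import torch

logical_shape = (8192, 8192)
itemsize = torch.float32.itemsize
comm_bytes_gb = itemsize * math.prod(logical_shape) / 1024 / 1024 / 1024
# 8192 * 8192 * 4 bytes = 256 MiB = 0.25 GiB
assert abs(comm_bytes_gb - 0.25) < 1e-9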
comm_bytes_gb /= num_devices_on_mesh_dim elif current.is_shard() and target.is_partial(): # ban shard -> partial as it does not make sense to perform # this redistribute return float("inf") - return cost diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index cbd817a8bde37..aaa5d25c79ba7 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -10,6 +10,7 @@ import torch.distributed.tensor._api as dtensor import torch.distributed.tensor._random as random from torch._library.utils import fill_defaults +from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import ( @@ -135,7 +136,9 @@ def __init__(self) -> None: self._random_ops = { aten.native_dropout.default, aten.normal_.default, + aten.rand.default, aten.rand_like.default, + aten.randn.default, aten.randn_like.default, aten.randint_like.default, aten.randint_like.low_dtype, @@ -152,6 +155,17 @@ def __init__(self) -> None: aten.as_strided.default: as_strided_handler, } + # ******************************************************************************************** + # def dispatch(...) + # + # NOTE: this class no longer contains the top-level dispatch entrypoint! + # See #167051 for details + # + # The entrypoint has been moved to C++, and it handles common cases and then calls back into + # OpDispatcher python to handle corner cases. + # See dispatchDTensorOp() defined in python_variable.cpp and called from python_arg_parser.cpp + # ******************************************************************************************** + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) # as implicitly replicated or we throw error to user. # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave @@ -164,17 +178,39 @@ def _allow_implicit_replication(self) -> bool: def _allow_implicit_replication(self, value: bool) -> None: return torch._C._set_dtensor_allow_implicit_replication(value) - def _propagate_op_sharding_non_cached_dispatch_slow_path( + def _propagate_op_sharding_dispatch_slow_path( self, op_call: torch._ops.OpOverload, args: tuple[object, ...], kwargs: dict[str, object], op_info: OpInfo, + # The logic here is a bit messy. There are several reasons why the + # C++ fastpath may have bailed out. If we just cache missed, we will + # come here because we need to actually calculate the real thing. + # There's no need to have a SECOND Python cache lookup; the C++ native + # cache completely subsumes it. But sometimes, we will have failed + # to compute the cache key in C++ entirely. In this case, we DO need + # to do a cache lookup in Python, as the missing cache key in C++ + # means we don't have access to it all. Furthermore, without duping + # this function, we need to do the try_cache test inside of the + # try-except block so that either case hits the inference mode / + # exception rewrapping case. + # + # This should be cleaned up. First, ensuring the C++ codepath can + # always compute a key will be a big help. Second, we should properly + # fastpath inference mode composite implicit autograd so that you + # don't have to throw an exception even in "fastpath". 
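The comment above boils down to a single branch, which the try_cache parameter below controls: consult the Python-side sharding cache only when the C++ fast path failed to compute a cache key and we are not tracing (SymInts are unhashable under tracing, as noted later in this patch). A distilled, hedged sketch with placeholder callables standing in for the real propagator methods:

from collections.abc import Callable

def choose_propagation(
    try_cache: bool,
    tracing: bool,
    cached: Callable[[], object],
    non_cached: Callable[[], object],
) -> object:
    if try_cache and not tracing:
        # C++ bailed without a cache key, so the Python cache may still hit
        return cached()
    # plain cache miss (the C++ cache subsumes Python's) or tracing: recompute
    return non_cached()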
+ try_cache: bool, ) -> object: try: - return self.sharding_propagator.propagate_op_sharding_non_cached( - op_info.schema - ) + # We have basically inlined propagate() here, but WITHOUT the + # output_sharding assignment + if try_cache and not _are_we_tracing(): + return self.sharding_propagator.propagate_op_sharding(op_info.schema) + else: + return self.sharding_propagator.propagate_op_sharding_non_cached( + op_info.schema + ) except NotImplementedError: if torch._C._dispatch_has_kernel_for_dispatch_key( op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 95e9509cdbcd6..283eaf4a06db8 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -361,6 +361,16 @@ def args_strategy(self) -> tuple[OpStrategy, ...]: ) return tuple(item for item in args if isinstance(item, OpStrategy)) + @property + def kwargs_strategy(self) -> tuple[OpStrategy, ...]: + # returns OpStrategy items from kwargs_schema. + kwargs_vals = ( + tree_leaves(self.kwargs_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.kwargs_schema.values() + ) + return tuple(item for item in kwargs_vals if isinstance(item, OpStrategy)) + def __repr__(self) -> str: args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) return ( diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 1e7ff648f7fbd..2d4a311b4bedd 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -168,7 +168,10 @@ def merge_sharding(dim: str, a: int, b: int) -> int: assert input_spec.tensor_meta is not None global_shape = input_spec.tensor_meta.shape local_shape, _ = compute_local_shape_and_global_offset( - global_shape, input_spec.mesh, input_spec.placements + global_shape, + input_spec.mesh, + input_spec.placements, + skip_offset=True, ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) # pyrefly: ignore [bad-argument-type] diff --git a/torch/distributed/tensor/_ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py index df9b81ac5df6e..1f456d505c127 100644 --- a/torch/distributed/tensor/_ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -4,7 +4,7 @@ import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._op_schema import OpSchema, OutputSharding -from torch.distributed.tensor._ops.utils import register_prop_rule +from torch.distributed.tensor._ops.registration import register_prop_rule aten = torch.ops.aten diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 41272c0f31a92..b7c4abf353be5 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -10,10 +10,8 @@ PlacementList, StrategyType, ) -from torch.distributed.tensor._ops.utils import ( - expand_to_full_mesh_op_strategy, - register_op_strategy, -) +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy from torch.distributed.tensor.placement_types import ( MaskPartial, Partial, diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 545895c83b6eb..ac0180f07d05e 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ 
b/torch/distributed/tensor/_ops/_math_ops.py @@ -17,6 +17,7 @@ RuntimeSchemaInfo, TupleStrategy, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( as_list, expand_to_full_mesh_op_strategy, @@ -25,7 +26,6 @@ is_tensor_evenly_shardable_on_dim, normalize_dim, normalize_dims, - register_op_strategy, ) from torch.distributed.tensor._utils import normalize_to_torch_size from torch.distributed.tensor.placement_types import ( @@ -163,6 +163,16 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return 1 + hash(self.norm_type) + def __repr__(self) -> str: + """ + machine readable representation of the _NormPartial placement + """ + return f"_NormPartial(reduce_op={self.reduce_op}, norm_type={self.norm_type})" + + def __str__(self) -> str: + """human readable representation of the _NormPartial placement""" + return f"_NormP({self.reduce_op}, {self.norm_type})" + def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]: if dims_arg is None: @@ -405,10 +415,15 @@ def cumsum_strategy(op_schema: OpSchema) -> OpStrategy: @register_op_strategy( - [aten.var.correction, aten.var.correction_out], + [ + aten.std.correction, + aten.std.correction_out, + aten.var.correction, + aten.var.correction_out, + ], schema_info=RuntimeSchemaInfo(1, ["keepdim"]), ) -def var_reduction_strategy(op_schema: OpSchema) -> OpStrategy: +def std_var_reduction_strategy(op_schema: OpSchema) -> OpStrategy: args_schema = op_schema.args_schema input_strategy = args_schema[0] if not isinstance(input_strategy, OpStrategy): diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 49152a1bee13a..ecd7938d75e2e 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -15,6 +15,7 @@ RuntimeSchemaInfo, ) from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -22,7 +23,6 @@ is_tensor_shardable, map_placements_after_broadcast, prod, - register_op_strategy, ) from torch.distributed.tensor._utils import ( compute_local_shape_and_global_offset, @@ -1090,7 +1090,7 @@ def local_meta(spec: OpSpec, placements: tuple[Placement, ...]) -> TensorMeta: meta: TensorMeta = spec.output_specs.tensor_meta local_stride = compute_local_stride(meta.stride, mesh, placements) local_shape, _ = compute_local_shape_and_global_offset( - meta.shape, mesh, placements + meta.shape, mesh, placements, skip_offset=True ) return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 53b759e993c0d..011a1ec667fb4 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -12,12 +12,12 @@ StrategyType, TupleStrategy, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, infer_broadcast_dims_map, map_placements_after_broadcast, normalize_dim, - register_op_strategy, ) from torch.distributed.tensor.placement_types import ( Partial, diff --git a/torch/distributed/tensor/_ops/_random_ops.py b/torch/distributed/tensor/_ops/_random_ops.py index 
9db9b85e58d2d..dd4cf8fec226a 100644 --- a/torch/distributed/tensor/_ops/_random_ops.py +++ b/torch/distributed/tensor/_ops/_random_ops.py @@ -6,7 +6,8 @@ OpStrategy, StrategyType, ) -from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy +from torch.distributed.tensor._ops.registration import register_op_strategy +from torch.distributed.tensor._ops.utils import is_tensor_partial aten = torch.ops.aten diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index fe20e41f59285..cb336486785af 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -18,6 +18,10 @@ ) from torch.distributed.tensor._ops._common_rules import pointwise_rule from torch.distributed.tensor._ops._embedding_ops import MaskPartial +from torch.distributed.tensor._ops.registration import ( + register_op_strategy, + register_prop_rule, +) from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, generate_redistribute_costs, @@ -25,8 +29,6 @@ is_tensor_evenly_shardable, is_tensor_partial, normalize_dim, - register_op_strategy, - register_prop_rule, shift_shard_dims_after_insert, shift_shard_dims_after_remove, ) diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index 2d9e33402c607..6c8954729b976 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -15,12 +15,12 @@ RuntimeSchemaInfo, StrategyType, ) +from torch.distributed.tensor._ops.registration import register_op_strategy from torch.distributed.tensor._ops.utils import ( generate_redistribute_costs, normalize_dim, normalize_dims, prod, - register_op_strategy, ) from torch.distributed.tensor.placement_types import ( _StridedShard, diff --git a/torch/distributed/tensor/_ops/registration.py b/torch/distributed/tensor/_ops/registration.py new file mode 100644 index 0000000000000..3864d8971069e --- /dev/null +++ b/torch/distributed/tensor/_ops/registration.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Callable +from typing import Optional, TypeAlias, TypeVar, Union + +import torch +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._op_schema import ( + OpSchema, + OutputSharding, + RuntimeSchemaInfo, + StrategyType, +) + + +# convenient wrapper to register sharding propagation rules +def register_prop_rule( + op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> Callable[ + [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] +]: + def wrapper( + impl: Callable[[OpSchema], OutputSharding], + ) -> Callable[[OpSchema], OutputSharding]: + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( + overload, impl, schema_info + ) + return impl + + return wrapper + + +# Note: +# using TypeVar here allows the registration decorator to preserve the specific type info of the wrapped strategy, +# while hardcoding the typing on the wrapper (e.g. Callable[[OpSchema], StrategyType]) would mean mypy would treat +# the return value of the wrapped strategy as always being a `StrategyType` even if it were a derived class like +# MyStrategyType(StrategyType). 
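To make the note above concrete, here is a small self-contained illustration (not part of the patch) of why the TypeVar-bound signature keeps the derived strategy type for a type checker, while a hardcoded Callable[[OpSchema], StrategyType] signature would erase it:

from collections.abc import Callable
from typing import TypeVar

class StrategyType: ...
class MyStrategyType(StrategyType): ...

_T = TypeVar("_T", bound=StrategyType)

def keeps_type(fn: Callable[[], _T]) -> Callable[[], _T]:
    return fn  # mypy still sees the wrapped function as returning MyStrategyType

def erases_type(fn: Callable[[], StrategyType]) -> Callable[[], StrategyType]:
    return fn  # mypy now only knows the return value is a StrategyType

@keeps_type
def make_mine() -> MyStrategyType:
    return MyStrategyType()

strategy = make_mine()  # still inferred as MyStrategyType, so subclass-only attributes type-check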
+_OpSchemaT = TypeVar("_OpSchemaT", bound=OpSchema) +_StrategyTypeT = TypeVar("_StrategyTypeT", bound=StrategyType) +_ShardingStrategyFunc: TypeAlias = Callable[[_OpSchemaT], _StrategyTypeT] + + +def register_op_strategy( + op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> Callable[[_ShardingStrategyFunc], _ShardingStrategyFunc]: + # For every ATen op that accepts any args in this list, + # the arg itself can impact the strides (and potentially the sharding strategy) + # of the output tensor. + # thus, we will detect ATen schemas with any of these args and ensure + # that they get specialized here. + arg_names_that_require_specializing_cache_strategy = [ + "memory_format", + ] + + def wrapper(impl: _ShardingStrategyFunc) -> _ShardingStrategyFunc: + if isinstance(op, list): + overloads = op + else: + overloads = [op] + + for overload in overloads: + curr_schema_info = None + if schema_info is None: + specialized_args = [ + a.name + for a in overload._schema.arguments + if a.name in arg_names_that_require_specializing_cache_strategy + ] + if any(specialized_args): + curr_schema_info = RuntimeSchemaInfo( + static_kwargkey=specialized_args + ) + else: + curr_schema_info = schema_info + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, impl, curr_schema_info + ) + return impl + + return wrapper diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 9a4ce12ed82fa..f09a888734807 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -4,21 +4,17 @@ import itertools import operator from collections.abc import Callable, Iterable, Sequence -from typing import cast, Optional, TypeVar, Union -from typing_extensions import ParamSpec +from typing import cast, Optional, Union import torch from torch._prims_common import DimsSequenceType, DimsType -from torch.distributed.tensor._api import DTensor from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import ( OpSchema, OpSpec, OpStrategy, - OutputSharding, PlacementList, - RuntimeSchemaInfo, StrategyType, ) from torch.distributed.tensor.device_mesh import DeviceMesh @@ -30,79 +26,14 @@ ) -_T = TypeVar("_T") -_P = ParamSpec("_P") - - -# convenient wrapper to register sharding propagation rules -def register_prop_rule( - op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], - schema_info: Optional[RuntimeSchemaInfo] = None, -) -> Callable[ - [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] -]: - def wrapper( - impl: Callable[[OpSchema], OutputSharding], - ) -> Callable[[OpSchema], OutputSharding]: - overloads = op if isinstance(op, list) else [op] - for overload in overloads: - DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( - overload, impl, schema_info - ) - return impl - - return wrapper - - -def register_op_strategy( - op, schema_info=None -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - # pyre-fixme[2]: Parameter must be annotated. - - # For every ATen op that accepts any args in this list, - # the arg itself can impact the strides (and potentially the sharding strategy) - # of the output tensor. - # thus, we will detect ATen schemas with any of these args and ensure - # that they get specialized here. 
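As the relocated register_op_strategy above spells out (and the deleted copy here repeats), an overload registered without explicit schema_info gets RuntimeSchemaInfo(static_kwargkey=["memory_format"]) attached automatically whenever its schema carries a memory_format argument. A quick hedged way to see which overloads qualify is to inspect the schema the same way the wrapper does (the printed overload is only an example):

import torch

aten = torch.ops.aten

# aten.clone.default accepts an optional memory_format kwarg, so a strategy
# registered for it without schema_info would get that kwarg specialized
print([a.name for a in aten.clone.default._schema.arguments])
# expected to include 'memory_format' alongside 'self'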
- arg_names_that_require_specializing_cache_strategy = [ - "memory_format", - ] - - def wrapper(impl): - if isinstance(op, list): - overloads = op - else: - overloads = [op] - - for overload in overloads: - curr_schema_info = None - if schema_info is None: - specialized_args = [ - a.name - for a in overload._schema.arguments - if a.name in arg_names_that_require_specializing_cache_strategy - ] - if any(specialized_args): - curr_schema_info = RuntimeSchemaInfo( - static_kwargkey=specialized_args - ) - else: - curr_schema_info = schema_info - DTensor._op_dispatcher.sharding_propagator.register_op_strategy( - overload, impl, curr_schema_info - ) - return impl - - return wrapper - - def replicate_op_strategy(op_schema: OpSchema) -> StrategyType: """ Fallback strategy all use Replication() """ - inputs_strategy = op_schema.args_strategy - # TODO(zpcore): handle kwarg_inputs_strategy - # kwarg_inputs_strategy = op_schema.kwargs_schema + args_strategy = op_schema.args_strategy + kwargs_strategy = op_schema.kwargs_strategy + inputs_strategy = args_strategy + kwargs_strategy + output_type = [str(ret.type) for ret in op_schema.op._schema.returns] output_len = output_type.count("Tensor") # TODO(zpcore): Confirm if view op can be handle properly or not. Prevent @@ -159,7 +90,10 @@ def prod(xs: Iterable[int]) -> int: def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: - """Check if the shape is shardable according to the spec.""" + """Check if the spec matches these criteria: + * any Shard placements in spec refer to valid tensor dims + * no empty local tensors (uneven sharding OK, as long as last rank has >0 size) + """ # number of shards in each tensor dimension shards_map = [1] * len(shape) for i, placement in enumerate(spec.placements): @@ -225,6 +159,9 @@ def infer_broadcast_dims_map( ) -> list[int]: # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim # this is aligned with the broadcast semantics + # e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4], + # broadcast_dims_map will be [-1, 0, 1, 2] + # meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input common_ndim = len(common_shape) input_ndim = len(input_shape) broadcast_dims_map = [-1] * common_ndim @@ -345,8 +282,15 @@ def expand_to_full_mesh_op_strategy( s for s in spec_list[input_index:] if isinstance(s, DTensorSpec) ] - input_args_strategy = op_schema.args_strategy - assert len(input_specs) == len(input_args_strategy) + args_strategy = op_schema.args_strategy + kwargs_strategy = op_schema.kwargs_strategy + input_args_strategy = args_strategy + kwargs_strategy + + if len(input_specs) != len(input_args_strategy): + raise AssertionError( + f"input_specs({len(input_specs)}) != strategies({len(input_args_strategy)}: " + f"{len(args_strategy)} args + {len(kwargs_strategy)} kwargs)" + ) self_spec = input_args_strategy[0].strategies[0].output_spec if inplace_op and self_spec.placements != input_specs[0].placements: diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index f8325c83d55e4..d117df2d67e2e 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -101,6 +101,9 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing # as manual seed on torch. 
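Referring back to the broadcast example added to infer_broadcast_dims_map above (common_shape [1, 2, 3, 4] against input_shape [2, 3, 4] mapping to [-1, 0, 1, 2]), here is a standalone sketch of that mapping, independent of the DTensor helper itself:

def broadcast_dims_map_sketch(common_shape, input_shape):
    # right-align the input dims against the common (broadcast) shape;
    # output dims with no corresponding input dim map to -1
    common_ndim, input_ndim = len(common_shape), len(input_shape)
    dims_map = [-1] * common_ndim
    for i in range(input_ndim):
        dims_map[common_ndim - input_ndim + i] = i
    return dims_map

assert broadcast_dims_map_sketch([1, 2, 3, 4], [2, 3, 4]) == [-1, 0, 1, 2]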
+ # + # torch.manual_seed will handle LocalTensor mode correctly by + # iterating through all ranks if seed is a LocalIntNode. torch.manual_seed(seed) @@ -132,7 +135,7 @@ def offset(self, offset: int) -> None: @property def seed(self) -> int: - return int(self._state[:8].view(dtype=torch.int64).item()) + return int(self._state[:8].view(dtype=torch.uint64).item()) @seed.setter def seed(self, seed: int) -> None: @@ -239,6 +242,16 @@ def _set_device_state(self, state: torch.Tensor): def _distribute_region( self, spec: DTensorSpec, generator: Optional[torch.Generator] = None ): + from torch.distributed._local_tensor import maybe_enable_local_tracker + + if local_tracker_context := maybe_enable_local_tracker( + self._device.type, self.distribute_region_enabled, spec, generator + ): + with local_tracker_context: + yield + return + + # regular (non-LocalTensor) mode if generator is not None: # This is a little hacky, but for any user-passed generator, we store its state under a unique key, # not because we need to keep a copy of it but because its the easiest way to make it work with the diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index a407ba6ca91df..84e58c4df169c 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -32,6 +32,72 @@ logger = logging.getLogger(__name__) +# Global configuration flag to control the redistribution planning strategy. +# When True, forces the graph-based algorithm using Dijkstra's shortest path. +# When False, prefers the greedy algorithm for faster planning. Uses the graph-based algorithm +# only when necessary to support strided-shard redistribution +_FORCE_MIN_COST_REDISTRIBUTION_PLAN: Optional[bool] = None + + +@contextlib.contextmanager +def use_min_cost_redistribution_plan(enabled: bool = True): + """ + Context manager to control the redistribution planning strategy for DTensor operations. + + This context manager allows you to choose between two algorithms for computing the + sequence of collective operations needed to redistribute a DTensor from one placement + to another: + + - **Graph-based**: Uses Dijkstra's algorithm to find the minimum-cost path + through all possible placement transformations. This approach considers the global + cost of all collective operations and finds the optimal sequence. Best for complex + redistribution patterns where reducing communication cost and memory overhead is critical. + + - **Greedy**: Uses a heuristic approach that makes locally optimal choices + at each step. This is faster to compute but may not produce the globally optimal + transformation sequence. Best for simple redistribution patterns or when planning + speed is more important than optimal communication. + + **Default Behavior (without this context manager):** + + When this context manager is NOT used, the algorithm selection follows this priority: + + 1. **Non-default shard orders** + → Always use graph-based algorithm (required for correctness) + + 2. **Explicit `use_graph_based_transform` parameter** to `_gen_transform_infos_non_cached` + → Use the specified algorithm (True = graph-based, False = greedy) + + 3. 
**No explicit parameter** (default case) + → Use greedy algorithm for faster planning + + **Behavior with this context manager:** + + This context manager overrides the default selection by setting the global flag + `_FORCE_MIN_COST_REDISTRIBUTION_PLAN`, which takes precedence over the explicit + `use_graph_based_transform` parameter (but not over non-default shard order requirements). + + **Cache Considerations:** + + The redistribution planner caches transform info for performance via the `@cache` + decorator on `_gen_transform_infos`. If you need to change the algorithm selection + for the same input specs, clear the cache using `_gen_transform_infos.cache_clear()` + to ensure the new setting takes effect and doesn't reuse cached results from a + previous run. + + Args: + enabled (bool): If True, forces the use of the graph-based algorithm. + If False, forces the use of the greedy algorithm. + Default: True + """ + global _FORCE_MIN_COST_REDISTRIBUTION_PLAN + old_value = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = enabled + try: + yield + finally: + _FORCE_MIN_COST_REDISTRIBUTION_PLAN = old_value + class _TransformInfo(NamedTuple): mesh_dim: int @@ -648,22 +714,29 @@ def _gen_transform_infos_non_cached( dst_spec: DTensorSpec, use_graph_based_transform: Optional[bool] = None, ) -> list[_TransformInfo]: - transform_infos: list[_TransformInfo] = [] device_mesh = src_spec.device_mesh src_shard_order = src_spec.shard_order dst_shard_order = dst_spec.shard_order # DTensorSpec should automatically generate shard_order, and it can be () if # no shard. assert src_shard_order is not None and dst_shard_order is not None - if use_graph_based_transform is None: - if all( - DTensorSpec.is_default_device_order(order) - for order in (src_shard_order, dst_shard_order) - ): - use_graph_based_transform = False - else: - # switch to graph search algorithm if the device order is not the default - use_graph_based_transform = True + + # Determine which transform strategy to use: + # 1. Non-standard device order → always use graph-based + # 2. Global flag or explicit parameter True → use graph-based + # 3. Otherwise → use greedy + has_non_default_order = not all( + DTensorSpec.is_default_device_order(order) + for order in (src_shard_order, dst_shard_order) + ) + + if has_non_default_order is True: + use_graph_based_transform = True + elif _FORCE_MIN_COST_REDISTRIBUTION_PLAN is not None: + use_graph_based_transform = _FORCE_MIN_COST_REDISTRIBUTION_PLAN + elif use_graph_based_transform is None: + use_graph_based_transform = False + drp = get_redistribute_planner(device_mesh, len(src_spec.shape)) if use_graph_based_transform: transform_infos = drp.generate_graph_based_transform_infos( diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index ede7515efd102..2db44f387e4eb 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -345,6 +345,10 @@ def spec_to_strategy(spec: object) -> object: ) def propagate(self, op_info: OpInfo) -> None: + # NB: The logic here is duplicated in _propagate_op_sharding_dispatch_slow_path. + # Ideally, this function would be deleted, but there are a handful of + # one off call sites here that aren't cleaned up. + # We cannot use an lru cache if we know that inputs will have dynamic shapes, # because SymInts are not hashable. 
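Returning to the use_min_cost_redistribution_plan context manager documented above, a hedged usage sketch; the import path is the private module the patch adds it to, and dt, mesh, and new_placements are assumed to already exist:

from torch.distributed.tensor._redistribute import (
    _gen_transform_infos,
    use_min_cost_redistribution_plan,
)

# clear cached transform plans first, as the docstring recommends, so the new
# setting is not shadowed by results cached under the previous one
_gen_transform_infos.cache_clear()

with use_min_cost_redistribution_plan(True):
    # redistributions issued here are planned with the graph-based (Dijkstra) algorithm
    out = dt.redistribute(device_mesh=mesh, placements=new_placements)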
# This is generally ok because this only happens during tracing in torch.compile, @@ -656,7 +660,7 @@ def _adjust_shape_and_stride_args( # adjust shape to be the same as that of the _local_tensor # of the DTensor input arg at index 0, which is inferred expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( - out_tensor_meta.shape, spec.mesh, spec.placements + out_tensor_meta.shape, spec.mesh, spec.placements, skip_offset=True ) # adjust the stride arg for aten.new_empty_strided.default diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 74ad2aaa80434..d7ee355500528 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,5 +1,4 @@ import threading -from collections import defaultdict from collections.abc import Sequence from typing import cast, Optional @@ -7,6 +6,7 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._api as dtensor from torch._prims_common import ShapeType +from torch.distributed._local_tensor import maybe_run_for_local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._collective_utils import redistribute_cost from torch.distributed.tensor._dtensor_spec import DTensorSpec @@ -17,7 +17,6 @@ Replicate, Shard, ) -from torch.utils._typing_utils import not_none class ExplicitRedistributionContext: @@ -56,61 +55,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): ExplicitRedistributionContext._local._active = self._prev -def _explicit_order_placements( - mesh_shape: ShapeType, placements: Sequence[Placement] -) -> Sequence[tuple[int, Placement]]: - """ - Replace Strided Shards with regular shards in an adjusted order. - - Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. - - ex. - [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> - [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] - - """ - if not len(placements) == len(mesh_shape): - raise RuntimeError( - "Expected one placement per mesh dim, " - f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." - ) - ordered = [] - deferred_strided_placements = defaultdict(list) - strided_part_ended_for_dim = set() - for mesh_dim, p in enumerate(placements): - if isinstance(p, _StridedShard): - # validate the stride is the correct multiple of the meshdim and the earlier shard - deferred_strided_placements[p.dim].append((mesh_dim, p)) - - else: - ordered.append((mesh_dim, p)) - if isinstance(p, Shard): - if p.dim in strided_part_ended_for_dim: - raise NotImplementedError( - f"Strided sharding does not allow Shard() to appear after " - f"the strided part has ended. {p} at mesh dim {mesh_dim} in " - f"{placements} violates this assumption." 
- ) - - if p.dim in deferred_strided_placements: - strided_part_ended_for_dim.add(p.dim) - strided_placements = deferred_strided_placements.pop(p.dim) - aggregate_size = mesh_shape[mesh_dim] - while len(strided_placements) > 0: - strided_mesh_dim, strided = strided_placements.pop() - if not strided.split_factor == aggregate_size: - raise RuntimeError( - f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" - f" == aggregate mesh size ({aggregate_size})" - ) - aggregate_size *= mesh_shape[strided_mesh_dim] - ordered.append((strided_mesh_dim, Shard(p.dim))) - - return ordered - - def compute_local_shape_and_global_offset( - global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Compute the local tensor shape and the global offsets into the original tensor @@ -143,24 +92,55 @@ def compute_local_shape_and_global_offset( global_shape (ShapeType): The global shape of the DTensor. mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. placements (Sequence[:class:`Placement`]]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. Return: local_shape: the shape of the DTensor's _local_tensor on the current rank. global_offset: a tuple of offsets for each dimension of the global tensor shape, - identifying how this shard fits into the global tensor in each dimension. + identifying how this shard fits into the global tensor in each dimension. If + skip_offset is True, this will be an empty tuple. """ return _compute_local_shape_and_global_offset( - global_shape, mesh.shape, mesh.get_coordinate(), placements + global_shape, mesh.shape, mesh.get_coordinate(), placements, skip_offset ) +@maybe_run_for_local_tensor +def _compute_offsets( + placement, + shard_offsets: int, + shard_size: int, + zero_global_offset: int, + previous_offsets, +) -> torch.Tensor: + if shard_size == 0: + return torch.arange(zero_global_offset, zero_global_offset + 1) + if isinstance(placement, Shard) and not isinstance(placement, _StridedShard): + index = torch.arange(shard_offsets, shard_offsets + shard_size) + else: + assert isinstance(shard_offsets, list) + index = torch.tensor(shard_offsets) + if previous_offsets is None: + return index + else: + return previous_offsets[index] + + +@maybe_run_for_local_tensor +def _get_first_offset(offsets: torch.Tensor) -> int: + return int(offsets[0]) + + # accept 'plain data types' to enable simpler unit testing without creating device mesh def _compute_local_shape_and_global_offset( global_shape: ShapeType, mesh_shape: ShapeType, my_coordinate: Optional[list[int]], placements: Sequence[Placement], + skip_offset: bool = False, ) -> tuple[tuple[int, ...], tuple[int, ...]]: """ Suppose you have a full tensor with size global_shape, and you have sharded @@ -176,85 +156,72 @@ def _compute_local_shape_and_global_offset( This function is fairly simple if your tensor is evenly sharded; the complication is around uneven splits. There is also some complication for handling StridedShard, which changes the order you should apply sharding. + + Args: + global_shape (ShapeType): The global shape of the tensor. + mesh_shape (ShapeType): The shape of the device mesh. 
+ my_coordinate (Optional[list[int]]): The coordinate of the current rank in the device mesh. + placements (Sequence[Placement]): The placements of the DTensor. + skip_offset (bool): If True, skip computing the global offsets and return an empty + tuple for global_offset. This can improve performance when only the local shape + is needed. Defaults to False. + + Returns: + tuple: A tuple containing: + - local_shape (tuple[int, ...]): The shape of the local shard on the current rank. + - global_offset (tuple[int, ...]): The offsets for each dimension identifying where + this shard begins in the global tensor. If skip_offset is True, this will be an + empty tuple. """ + empty_offset = () if my_coordinate is None: # if rank not in the mesh, return empty offset - return ((0,), ()) - - # StridedShard implies a non-standard order to apply shards; get the - # correct order to start applying splits - ordered_placements = _explicit_order_placements(mesh_shape, placements) + return ((0,), empty_offset) local_shape = list(global_shape) - # We'll compute the data for where the shard begins on a per-dim basis. - # However, a single dim can be sharded multiple times, so we will end up - # doing a Sum(size*stride) like computation to determine the location of our - # shard for each of the shardings on that dim. - global_offset = [0] * len(global_shape) - - for mesh_dim, placement in ordered_placements: + # Perform shard from left to right. For example, + # global tensor: [0, 1, 2, 3, 4, 5, 6, 7] + # placements: S(0), SS(0, split_factor=2) + # mesh_shape: (2, 2) + # After S(0), shard_dim_to_global_offsets are + # {0: [0, 1, 2, 3]} on my_coordinate [0, 0] [0, 1] + # {0: [4, 5, 6, 7]} on my_coordinate [1, 0] [1, 1] + # After SS(0, split_factor=2), shard_dim_to_global_offsets are + # {0: [0, 2]} on my_coordinate [0, 0] + # {0: [1, 3]} on my_coordinate [0, 1] + # {0: [4, 6]} on my_coordinate [1, 0] + # {0: [5, 7]} on my_coordinate [1, 1] + shard_dim_to_global_offsets = {} + for mesh_dim, placement in enumerate(placements): mesh_dim_size = mesh_shape[mesh_dim] - if isinstance(placement, Shard): - shard_dim = placement.dim - assert shard_dim < len(local_shape), ( - f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - ) - shard_size, shard_offset = placement._local_shard_size_and_offset( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[mesh_dim], - ) - - local_shape[shard_dim] = shard_size - - shard_global_offset = global_offset[shard_dim] + not_none(shard_offset) - - zero_global_offset = global_shape[shard_dim] - if isinstance(shard_global_offset, torch.SymInt) and not isinstance( - zero_global_offset, torch.SymInt - ): - zero_global_offset = torch.SymInt(zero_global_offset) - - global_offset[shard_dim] = torch.sym_ite( - shard_size == 0, - # Special case to fill in a standardized non-garbage value for - # the global_offset of zero-sized shards. This value is out - # of bounds of the tensor, so it won't conflict with any real - # offsets. DCP may rely on this value to de-duplicate shards. - # Note that you can end up with zero-size shards that are - # still otherwise in bounds for the tensor (TODO: give an - # example). - zero_global_offset, - # As we successively shard the same dimension, we keep - # advancing our pointer beyond our original offset until we - # get to the final chunk start. - shard_global_offset, - ) - - # NOTE: the offset compute relies on the local shard index and it has no - # problem when strided sharding is not present. 
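The worked example in the new comment above (global tensor [0..7], placements S(0) then SS(0, split_factor=2) on a (2, 2) mesh) can be reproduced with plain tensors. This sketch only mirrors the comment's arithmetic and is not the DTensor implementation:

import torch

def strided_columns(indices: torch.Tensor, num_chunks: int, split_factor: int):
    # view the chunk as (split_factor, num_chunks) and hand each rank one column,
    # which reproduces the SS(0, split_factor=2) step from the comment above
    grid = indices.view(split_factor, num_chunks)
    return [grid[:, r].reshape(-1) for r in range(num_chunks)]

full = torch.arange(8)
row0, row1 = full.chunk(2)                                   # S(0): [0..3] and [4..7]
print(strided_columns(row0, num_chunks=2, split_factor=2))   # [tensor([0, 2]), tensor([1, 3])]
print(strided_columns(row1, num_chunks=2, split_factor=2))   # [tensor([4, 6]), tensor([5, 7])]

Per the _get_first_offset helper above, each coordinate's dim-0 global offset is then simply the first entry of its index list: 0, 1, 4, and 5.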
To correctly compute, we assume - # that the ``_StridedShard.split_factor`` field encodes how many partitions - # each local tensor will be further split into when sharding on higher mesh - # dimensions. However, this number is only correct if the DTensor is not - # sharded after the strided sharding completes. For example, - # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements - # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on - # device mesh dim-2, and last on mesh dim-1. We define the - # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding - # part because strided sharding happens on mesh dim-1 and it was caused by - # the fact that sharding on dim-2 occurred ahead. In this case, there's no - # further sharding after this strided sharding part and ``split_factor`` - # correctly encodes the number. Another example is - # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's - # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh - # dim-2. This violates our assumption that no further sharding shall occur - # after the strided sharding part and ``split_factor`` won't correctly - # encode the number of further split. So far, the only case where _StridedShard - # placement would appear is FSDP2 + TP on 2D mesh and the above case could only - # happen on mesh of 3 or more dimensions. - # TODO: change this function to correctly address this. - # TODO: this logic can be applied to contiguous sharding as well + if not isinstance(placement, (Shard, _StridedShard)): + continue + shard_dim = placement.dim + zero_global_offset = global_shape[shard_dim] + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + shard_size, shard_offsets = placement._local_shard_size_and_offset( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[mesh_dim], + ) + local_shape[shard_dim] = shard_size + if skip_offset: + continue + shard_dim_to_global_offsets[shard_dim] = _compute_offsets( + placement, + shard_offsets, + shard_size, + zero_global_offset, + shard_dim_to_global_offsets.get(shard_dim), + ) + if skip_offset: + return tuple(local_shape), empty_offset + global_offset = [0] * len(global_shape) + for shard_dim, global_offsets in shard_dim_to_global_offsets.items(): + global_offset[shard_dim] = _get_first_offset(global_offsets) return tuple(local_shape), tuple(global_offset) diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 2cf6e572dcdf7..65da0a7b1823b 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -684,12 +684,13 @@ def _to_replicate_tensor( def _local_shard_size(sharded_indices: list[torch.Tensor], rank: int) -> int: return len(sharded_indices[rank]) - def _local_shard_size_and_offset( + # delete pyre-ignore once separating _StridedShard from Shard + def _local_shard_size_and_offset( # pyre-ignore[bad-override] self, curr_local_size: int, num_chunks: int, rank: int, - ) -> tuple[int, Optional[int]]: + ) -> tuple[int, list[int]]: # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] @@ -707,9 +708,9 @@ def _local_shard_size_and_offset( sharded_indices = [shard.view(-1) for shard in sharded_indices] local_shard_size = _StridedShard._local_shard_size(sharded_indices, rank) + offsets = 
sharded_indices[rank].tolist() - # offsets from _StridedShard is never used - return local_shard_size, None + return local_shard_size, offsets class Replicate(torch._C._distributed.Replicate): @@ -816,14 +817,18 @@ def _partition_value( # Partial placement contract #3: # _partition_value: partition the value of a replicated tensor on the mesh dimension - # _partition_value is the conjugate operation of _reduce_value - # - i.e. _partition_value on a sum reduce op is just a division operation - # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation - # TODO: if the reduce_op is min/max, etc. the _partition_value should be a - # different operation - assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" + # _partition_value is the conjugate operation of _reduce_value, e.g. + # - _partition_value on a sum reduce op is just a division operation + # - _reduce_value on a sum reduce op would just be a sum(allreduce) operation num_chunks = mesh.size(mesh_dim=mesh_dim) - return tensor / num_chunks + if self.reduce_op == "sum": + return tensor / num_chunks + elif self.reduce_op in ("avg", "min", "max"): + return tensor + else: + raise ValueError( + f"Replicate to Partial({self.reduce_op}) conversion is not supported." + ) def __hash__(self) -> int: return 1 + hash(self.reduce_op) @@ -838,7 +843,7 @@ def __str__(self) -> str: """ human readable representation of the Partial placement """ - return "P" + return f"P({self.reduce_op})" # We keep the old _Partial name for a while for BC reason @@ -982,10 +987,10 @@ def __repr__(self) -> str: """ machine readable representation of the MaskPartial placement """ - return f"MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" + return f"MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" def __str__(self) -> str: """ human readable representation of the MaskPartial placement """ - return "MaskP" + return f"MaskP({self.reduce_op}, {self.offset_shape}, {self.offset_dim})" diff --git a/torch/export/_trace.py b/torch/export/_trace.py index b38986ab070f7..856f23f68b19e 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -851,7 +851,9 @@ def use_legacy_dynamo_graph_capture() -> bool: f, constraints=constraints, dynamic_shapes=dynamic_shapes ) else: - dynamo_graph_capture = dynamo_graph_capture_for_export(f) + dynamo_graph_capture = torch._dynamo.config.patch( + replay_side_effects=False + )(dynamo_graph_capture_for_export(f)) # We can't serialize entire fake mode yet, so this is to make sure # things like copy.deepcopy(ep.graph_module) not crash. 
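Stepping back to the Partial._partition_value change above: for "sum" the replicated value is divided by the mesh-dimension size so that the subsequent reduction recovers it, while for "avg", "min", and "max" each rank keeps the value unchanged because those reductions over identical copies return the same value. A small numeric check of that reasoning with plain tensors (no mesh involved):

import torch

num_chunks = 4
replicated = torch.tensor([8.0, -2.0])

# Partial("sum"): partition by dividing, so an all-reduce (sum) gives back the original
sum_shards = [replicated / num_chunks for _ in range(num_chunks)]
assert torch.equal(sum(sum_shards), replicated)

# Partial("max") (and "min"/"avg" analogously): keep the value as-is, since
# reducing identical copies with max returns the same value
max_shards = [replicated.clone() for _ in range(num_chunks)]
assert torch.equal(torch.stack(max_shards).amax(dim=0), replicated)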
# see test_export.py::test_custom_tag_metadata_re_export diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 16b975f6b069a..96b44b0aebd4d 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -959,7 +959,7 @@ def _bitwise_xor(a, b): reflectable_magic_methods = { - "add": _optimized_add, + "add": operator.add, "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, @@ -1398,7 +1398,7 @@ def binary_magic_impl(self, other): out = PythonMod(self.expr, other.expr) elif method == "add": # see Note [optimized_summation] - (optimized_summation, out) = func( + (optimized_summation, out) = _optimized_add( self.expr, other.expr, self._optimized_summation, diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 598ca377f794b..7e64efbb8b73c 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -179,6 +179,21 @@ toString << toUnderlying +# torch/headeronly/core/Layout.h +Layout +kStrided +kSparse +kSparseCsr +kSparseCsc +kSparseBsr +kSparseBsc +kMkldnn +kJagged + +# torch/headeronly/core/MemoryFormat.h +MemoryFormat +get_contiguous_memory_format + # torch/headeronly/core/Dispatch_v2.h THO_DISPATCH_V2_TMPL THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL @@ -204,3 +219,9 @@ HeaderOnlyGenericPackedTensorAccessor # HeaderOnlyTensorAccessorBase and # HeaderOnlyGenericPackedTensorAccessorBase are tested through # HeaderOnlyTensorAccessor and HeaderOnlyGenericPackedTensorAccessor + +# torch/headeronly/util/Deprecated.h +# C10_DEPRECATED, C10_DEPRECATED_MESSAGE, and +# C10_DEFINE_DEPRECATED_USING functionalities are expressed at compile +# time that have no effect to runtime. Therefore, these macros are not +# tested under test/. diff --git a/torch/headeronly/core/Layout.h b/torch/headeronly/core/Layout.h new file mode 100644 index 0000000000000..62e34ff67b457 --- /dev/null +++ b/torch/headeronly/core/Layout.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include + +#include +#include + +namespace c10 { + +enum class Layout : int8_t { + Strided, + Sparse, + SparseCsr, + Mkldnn, + SparseCsc, + SparseBsr, + SparseBsc, + Jagged, + NumOptions +}; + +constexpr auto kStrided = Layout::Strided; +constexpr auto kSparse = Layout::Sparse; +constexpr auto kSparseCsr = Layout::SparseCsr; +constexpr auto kMkldnn = Layout::Mkldnn; +constexpr auto kSparseCsc = Layout::SparseCsc; +constexpr auto kSparseBsr = Layout::SparseBsr; +constexpr auto kSparseBsc = Layout::SparseBsc; +constexpr auto kJagged = Layout::Jagged; + +} // namespace c10 + +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) +using c10::kJagged; +using c10::kMkldnn; +using c10::kSparse; +using c10::kSparseBsc; +using c10::kSparseBsr; +using c10::kSparseCsc; +using c10::kSparseCsr; +using c10::kStrided; +using c10::Layout; +HIDDEN_NAMESPACE_END(torch, headeronly) diff --git a/torch/headeronly/core/MemoryFormat.h b/torch/headeronly/core/MemoryFormat.h new file mode 100644 index 0000000000000..ad02a901e0169 --- /dev/null +++ b/torch/headeronly/core/MemoryFormat.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include +#include + +// Memory format is not the property of a Tensor. It is the way to tell an +// operator how the result should be organized in memory and nothing more. That +// means memory format should never be used as return value for any tensor state +// interrogation functions (internally and externally). 
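The new Layout.h and MemoryFormat.h headers above mirror enums that already surface in Python as torch.layout and torch.memory_format values; a quick hedged illustration of the Contiguous and ChannelsLast options the header comment describes (the comment continues just below):

import torch

x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))   # True: NHWC-style strides
print(x.layout)                                             # torch.strided, i.e. kStrided

y = x.contiguous(memory_format=torch.contiguous_format)
print(y.is_contiguous())                                    # True: plain contiguous layout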
+// +// Possible options are: +// Preserve: +// If any of the input tensors is in channels_last format, operator output +// should be in channels_last format +// +// Contiguous: +// Regardless of input tensors format, the output should be contiguous +// Tensor. +// +// ChannelsLast: +// Regardless of input tensors format, the output should be in channels_last +// format. + +namespace c10 { + +enum class MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d, + NumOptions +}; + +inline MemoryFormat get_contiguous_memory_format() { + return MemoryFormat::Contiguous; +} + +} // namespace c10 + +HIDDEN_NAMESPACE_BEGIN(torch, headeronly) +using c10::get_contiguous_memory_format; +using c10::MemoryFormat; +HIDDEN_NAMESPACE_END(torch, headeronly) diff --git a/torch/headeronly/cpu/vec/intrinsics.h b/torch/headeronly/cpu/vec/intrinsics.h index 4342005e30f49..3cf427dae64bc 100644 --- a/torch/headeronly/cpu/vec/intrinsics.h +++ b/torch/headeronly/cpu/vec/intrinsics.h @@ -29,11 +29,6 @@ /* GCC-compatible compiler, targeting ARM with SVE */ #include #endif -#if defined(MISSING_ARM_VLD1) -#include -#elif defined(MISSING_ARM_VST1) -#include -#endif #elif defined(__GNUC__) && defined(__IWMMXT__) /* GCC-compatible compiler, targeting ARM with WMMX */ #include diff --git a/torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h b/torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h deleted file mode 100644 index b78841ead92e9..0000000000000 --- a/torch/headeronly/cpu/vec/vec256/missing_vld1_neon.h +++ /dev/null @@ -1,396 +0,0 @@ -/* Workaround for missing vld1_*_x2 and vst1_*_x2 intrinsics in gcc-7. */ - -__extension__ extern __inline uint8x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u8_x2(const uint8_t* __a) { - uint8x8x2_t ret; - asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int8x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s8_x2(const int8_t* __a) { - int8x8x2_t ret; - asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u16_x2(const uint16_t* __a) { - uint16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s16_x2(const int16_t* __a) { - int16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint32x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u32_x2(const uint32_t* __a) { - uint32x2x2_t ret; - asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int32x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_s32_x2(const int32_t* __a) { - int32x2x2_t ret; - asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_u64_x2(const uint64_t* __a) { - uint64x1x2_t ret; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - 
vld1_s64_x2(const int64_t* __a) { - int64x1x2_t ret; - __builtin_aarch64_simd_oi __o; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_f16_x2(const float16_t* __a) { - float16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float32x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_f32_x2(const float32_t* __a) { - float32x2x2_t ret; - asm volatile("ld1 {%S0.2s - %T0.2s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_f64_x2(const float64_t* __a) { - float64x1x2_t ret; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly8x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_p8_x2(const poly8_t* __a) { - poly8x8x2_t ret; - asm volatile("ld1 {%S0.8b - %T0.8b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly16x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_p16_x2(const poly16_t* __a) { - poly16x4x2_t ret; - asm volatile("ld1 {%S0.4h - %T0.4h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly64x1x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1_p64_x2(const poly64_t* __a) { - poly64x1x2_t ret; - asm volatile("ld1 {%S0.1d - %T0.1d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint8x16x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u8_x2(const uint8_t* __a) { - uint8x16x2_t ret; - asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int8x16x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s8_x2(const int8_t* __a) { - int8x16x2_t ret; - asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u16_x2(const uint16_t* __a) { - uint16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s16_x2(const int16_t* __a) { - int16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint32x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u32_x2(const uint32_t* __a) { - uint32x4x2_t ret; - asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline int32x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s32_x2(const int32_t* __a) { - int32x4x2_t ret; - asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline uint64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_u64_x2(const uint64_t* __a) { - uint64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern 
__inline int64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_s64_x2(const int64_t* __a) { - int64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_f16_x2(const float16_t* __a) { - float16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float32x4x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_f32_x2(const float32_t* __a) { - float32x4x2_t ret; - asm volatile("ld1 {%S0.4s - %T0.4s}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline float64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_f64_x2(const float64_t* __a) { - float64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly8x16x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_p8_x2(const poly8_t* __a) { - poly8x16x2_t ret; - asm volatile("ld1 {%S0.16b - %T0.16b}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly16x8x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_p16_x2(const poly16_t* __a) { - poly16x8x2_t ret; - asm volatile("ld1 {%S0.8h - %T0.8h}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -__extension__ extern __inline poly64x2x2_t - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vld1q_p64_x2(const poly64_t* __a) { - poly64x2x2_t ret; - asm volatile("ld1 {%S0.2d - %T0.2d}, %1" : "=w"(ret) : "Q"(*__a)); - return ret; -} - -/* vst1x2 */ - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s64_x2(int64_t* __a, int64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u64_x2(uint64_t* __a, uint64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_f64_x2(float64_t* __a, float64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s8_x2(int8_t* __a, int8x8x2_t val) { - asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_p8_x2(poly8_t* __a, poly8x8x2_t val) { - asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s16_x2(int16_t* __a, int16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_p16_x2(poly16_t* __a, poly16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_s32_x2(int32_t* __a, int32x2x2_t val) { - 
asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u8_x2(uint8_t* __a, uint8x8x2_t val) { - asm volatile("st1 {%S1.8b - %T1.8b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u16_x2(uint16_t* __a, uint16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_u32_x2(uint32_t* __a, uint32x2x2_t val) { - asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_f16_x2(float16_t* __a, float16x4x2_t val) { - asm volatile("st1 {%S1.4h - %T1.4h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_f32_x2(float32_t* __a, float32x2x2_t val) { - asm volatile("st1 {%S1.2s - %T1.2s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1_p64_x2(poly64_t* __a, poly64x1x2_t val) { - asm volatile("st1 {%S1.1d - %T1.1d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s8_x2(int8_t* __a, int8x16x2_t val) { - asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_p8_x2(poly8_t* __a, poly8x16x2_t val) { - asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s16_x2(int16_t* __a, int16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_p16_x2(poly16_t* __a, poly16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s32_x2(int32_t* __a, int32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_s64_x2(int64_t* __a, int64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u8_x2(uint8_t* __a, uint8x16x2_t val) { - asm volatile("st1 {%S1.16b - %T1.16b}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u16_x2(uint16_t* __a, uint16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u32_x2(uint32_t* __a, uint32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - 
__attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_u64_x2(uint64_t* __a, uint64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f16_x2(float16_t* __a, float16x8x2_t val) { - asm volatile("st1 {%S1.8h - %T1.8h}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f64_x2(float64_t* __a, float64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_p64_x2(poly64_t* __a, poly64x2x2_t val) { - asm volatile("st1 {%S1.2d - %T1.2d}, %0" : "=Q"(*__a) : "w"(val)); -} diff --git a/torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h b/torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h deleted file mode 100644 index 93f1110d808c6..0000000000000 --- a/torch/headeronly/cpu/vec/vec256/missing_vst1_neon.h +++ /dev/null @@ -1,7 +0,0 @@ -/* Workaround for missing vst1q_f32_x2 in gcc-8. */ - -__extension__ extern __inline void - __attribute__((__always_inline__, __gnu_inline__, __artificial__)) - vst1q_f32_x2(float32_t* __a, float32x4x2_t val) { - asm volatile("st1 {%S1.4s - %T1.4s}, %0" : "=Q"(*__a) : "w"(val)); -} diff --git a/torch/headeronly/util/Deprecated.h b/torch/headeronly/util/Deprecated.h new file mode 100644 index 0000000000000..88440a0242eb4 --- /dev/null +++ b/torch/headeronly/util/Deprecated.h @@ -0,0 +1,102 @@ +#pragma once + +/** + * This file provides portable macros for marking declarations + * as deprecated. You should generally use C10_DEPRECATED, + * except when marking 'using' declarations as deprecated, + * in which case you should use C10_DEFINE_DEPRECATED_USING + * (due to portability concerns). + */ + +// Sample usage: +// +// C10_DEPRECATED void bad_func(); +// struct C10_DEPRECATED BadStruct { +// ... +// }; + +// NB: __cplusplus doesn't work for MSVC, so for now MSVC always uses +// the "__declspec(deprecated)" implementation and not the C++14 +// "[[deprecated]]" attribute. We tried enabling "[[deprecated]]" for C++14 on +// MSVC, but ran into issues with some older MSVC versions. +#if (defined(__cplusplus) && __cplusplus >= 201402L) +#define C10_DEPRECATED [[deprecated]] +#define C10_DEPRECATED_MESSAGE(message) [[deprecated(message)]] +#elif defined(__GNUC__) +#define C10_DEPRECATED __attribute__((deprecated)) +// TODO Is there some way to implement this? +#define C10_DEPRECATED_MESSAGE(message) __attribute__((deprecated)) + +#elif defined(_MSC_VER) +#define C10_DEPRECATED __declspec(deprecated) +#define C10_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) +#else +#warning "You need to implement C10_DEPRECATED for this compiler" +#define C10_DEPRECATED +#endif + +// Sample usage: +// +// C10_DEFINE_DEPRECATED_USING(BadType, int) +// +// which is the portable version of +// +// using BadType [[deprecated]] = int; + +// technically [[deprecated]] syntax is from c++14 standard, but it works in +// many compilers. 
+#if defined(__has_cpp_attribute) +#if __has_cpp_attribute(deprecated) && !defined(__CUDACC__) +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; +#endif +#endif + +#if defined(_MSC_VER) +#if defined(__CUDACC__) +// neither [[deprecated]] nor __declspec(deprecated) work on nvcc on Windows; +// you get the error: +// +// error: attribute does not apply to any entity +// +// So we just turn the macro off in this case. +#if defined(C10_DEFINE_DEPRECATED_USING) +#undef C10_DEFINE_DEPRECATED_USING +#endif +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; +#else +// [[deprecated]] does work in windows without nvcc, though msc doesn't support +// `__has_cpp_attribute` when c++14 is supported, otherwise +// __declspec(deprecated) is used as the alternative. +#ifndef C10_DEFINE_DEPRECATED_USING +#if defined(_MSVC_LANG) && _MSVC_LANG >= 201402L +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName [[deprecated]] = TypeThingy; +#else +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = __declspec(deprecated) TypeThingy; +#endif +#endif +#endif +#endif + +#if !defined(C10_DEFINE_DEPRECATED_USING) && defined(__GNUC__) +// nvcc has a bug where it doesn't understand __attribute__((deprecated)) +// declarations even when the host compiler supports it. We'll only use this gcc +// attribute when not cuda, and when using a GCC compiler that doesn't support +// the c++14 syntax we checked for above (available in __GNUC__ >= 5) +#if !defined(__CUDACC__) +#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName __attribute__((deprecated)) = TypeThingy; +#else +// using cuda + gcc < 5, neither deprecated syntax is available so turning off. 
+#define C10_DEFINE_DEPRECATED_USING(TypeName, TypeThingy) \ + using TypeName = TypeThingy; +#endif +#endif + +#if !defined(C10_DEFINE_DEPRECATED_USING) +#warning "You need to implement C10_DEFINE_DEPRECATED_USING for this compiler" +#define C10_DEFINE_DEPRECATED_USING +#endif diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index 0f348590ea397..e5ddc1e443a29 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -85,9 +85,6 @@ def get_qualified_name(func): class JitTypeTraceStoreLogger(CallTraceStoreLogger): """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" - def __init__(self, store: CallTraceStore) -> None: - super().__init__(store) - def log(self, trace: CallTrace) -> None: # pyrefly: ignore [missing-attribute] self.traces.append(trace) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 75355cbd4b8e0..ec4bbd125119d 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -152,8 +152,7 @@ def _get_valid_constant(attr, v, owner_type): class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): - def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None: - super().__init__(source, filename, file_lineno, leading_whitespace_len) + pass def get_annotations(obj): diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp index 47d082f44332f..00acf1782c2d8 100644 --- a/torch/nativert/graph/Graph.cpp +++ b/torch/nativert/graph/Graph.cpp @@ -1034,6 +1034,15 @@ std::ostream& operator<<(std::ostream& out, const Constant& constant) { out << kDevicePrefix << '{' << arg << '}'; } else if constexpr (is_same_v>) { out << fmt::format("[{}]", fmt::join(arg, ",")); + } else if constexpr (is_same_v>>) { + out << '['; + for (const auto& [idx, inner_list] : c10::enumerate(arg)) { + if (idx > 0) { + out << ", "; + } + out << fmt::format("{}", fmt::streamed(inner_list)); + } + out << ']'; } else if constexpr (is_same_v>) { out << fmt::format(""); VLOG(0) << "Subgraph pretty print is not implemented"; diff --git a/torch/nativert/graph/Graph.h b/torch/nativert/graph/Graph.h index bbd87a8e2014b..c713df9401884 100644 --- a/torch/nativert/graph/Graph.h +++ b/torch/nativert/graph/Graph.h @@ -97,6 +97,7 @@ using Constant = std::variant< bool, std::vector, std::vector, + std::vector>, std::unique_ptr>; c10::IValue constantToIValue(const Constant& constant); diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp index 4c45edd1f5751..532e73a40bd4a 100644 --- a/torch/nativert/graph/Serialization.cpp +++ b/torch/nativert/graph/Serialization.cpp @@ -101,6 +101,11 @@ Value* symbolicToValue( case torch::_export::Argument::Tag::AS_SYM_FLOAT: { return graph.getValue(arg.get_as_sym_float().get_as_name()); } + case torch::_export::Argument::Tag::AS_STRING_TO_ARGUMENT: { + TORCH_CHECK( + false, + "String to argument mapping is not yet supported in symbolic context"); + } default: TORCH_CHECK( false, @@ -453,6 +458,7 @@ bool isSymbolic(const torch::_export::Argument& arg) { case torch::_export::Argument::Tag::AS_SYM_FLOAT: case torch::_export::Argument::Tag::AS_SYM_FLOATS: case torch::_export::Argument::Tag::AS_CUSTOM_OBJ: + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSOR: return true; default: return false; @@ -532,6 +538,23 @@ Constant constantToValue( case torch::_export::Argument::Tag::AS_SYM_FLOATS: { TORCH_CHECK(false, "SymFloats is not yet implemented"); } + case torch::_export::Argument::Tag::AS_OPTIONAL_TENSOR: + 
TORCH_CHECK(false, "Optional tensor is symbolic, not constant"); + case torch::_export::Argument::Tag::AS_COMPLEX: + TORCH_CHECK(false, "Complex values are not yet supported as constants"); + case torch::_export::Argument::Tag::AS_INT_LISTS: { + std::vector> ret; + for (const auto& inner_list : jsonArg.get_as_int_lists()) { + std::vector inner_ret; + for (const auto& val : inner_list) { + inner_ret.push_back(val); + } + ret.push_back(inner_ret); + } + return ret; + } + case torch::_export::Argument::Tag::AS_STRING_TO_ARGUMENT: + return None(); default: TORCH_CHECK(false, "Got unknown json argument"); } diff --git a/torch/nativert/kernels/KernelHandlerRegistry.cpp b/torch/nativert/kernels/KernelHandlerRegistry.cpp index 3ac176a81bc3a..69655067a79c7 100644 --- a/torch/nativert/kernels/KernelHandlerRegistry.cpp +++ b/torch/nativert/kernels/KernelHandlerRegistry.cpp @@ -23,9 +23,10 @@ std::string maybeRevisedStaticDispatchTarget(const Node& node) { auto overloadName = selectScalarOverloadName(node); if (!overloadName.empty() && !c10::ends_with(node.target(), overloadName)) { - const std::string& newTarget = + const std::string newTarget = std::string(node.target()) - .replace(node.target().rfind('.'), std::string::npos, overloadName); + .replace( + node.target().rfind('.') + 1, std::string::npos, overloadName); LOG(INFO) << fmt::format( "Converting Tensor to {} for node: {} -> {}", overloadName, @@ -36,6 +37,11 @@ std::string maybeRevisedStaticDispatchTarget(const Node& node) { return std::string(node.target()); } +void updateNodeTargetIfNeeded(Node& node) { + auto newTarget = maybeRevisedStaticDispatchTarget(node); + node.setTarget(newTarget); +} + std::unique_ptr make_proxy_executor( const std::string& filename, bool is_cpu, @@ -69,6 +75,8 @@ void register_kernel_handlers() { const torch::nativert::ExecutorConfig& executorConfig, caffe2::serialize::PyTorchStreamReader* packageReader) -> std::pair { + updateNodeTargetIfNeeded(const_cast(node)); + return { torch::nativert::StaticallyDispatchedCPUKernelRegistry() ->Create(maybeRevisedStaticDispatchTarget(node), &node), diff --git a/torch/nativert/kernels/TritonKernel.cpp b/torch/nativert/kernels/TritonKernel.cpp index 081c81f7c646b..11dd671f8fbe6 100644 --- a/torch/nativert/kernels/TritonKernel.cpp +++ b/torch/nativert/kernels/TritonKernel.cpp @@ -167,8 +167,8 @@ void TritonKernel::computeInternal(ExecutionFrame& executionFrame) const { // todo: check if this is redundant auto out_t = out.toTensorList(); - for (const auto& i : output_indices_) { - out_t[i] = input(i, executionFrame).toTensor(); + for (const auto i : c10::irange(output_indices_.size())) { + out_t[i] = input(output_indices_[i], executionFrame).toTensor(); } } diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index f3746bcea1264..ad922227ccff8 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -7,10 +7,11 @@ import itertools import math import operator +import typing import warnings from collections.abc import Callable from enum import Enum -from typing import Any, NamedTuple +from typing import Any, Literal, NamedTuple, TypeAlias import torch from torch import Tensor @@ -82,6 +83,7 @@ def _warn_once( _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor] _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] +_Backend: TypeAlias = Literal["AUTO", "TRITON", "FLASH", "TRITON_DECODE"] # pyrefly: ignore [invalid-inheritance] @@ -219,12 +221,18 @@ class 
FlexKernelOptions(TypedDict, total=False): """ROCm-specific waves per execution unit.""" # pyrefly: ignore [invalid-annotation] - force_flash: NotRequired[bool] - """ If True, forces use of the cute-dsl flash attention kernel. - - Raises an error if flash attention cannot be used instead of falling back - to the default implementation. Useful for ensuring flash attention is used - when expected. + BACKEND: NotRequired[_Backend] + """Selects a specific kernel backend. + + Options: + - "AUTO": Use current heuristics (typically Triton-based kernels with + automatic selection between flex_attention and flex_decoding) + - "TRITON": Standard Triton flex_attention kernel + - "TRITON_DECODE": Triton flex_decoding kernel, only available for short sequence lengths with specific configurations + - "FLASH": Experimental: Flash Attention kernel (cute-dsl), user needs to have flash installed + + This option cannot be combined with legacy knobs such as ``FORCE_USE_FLEX_ATTENTION``. + Raises an error if the requested backend cannot be used. Default: "AUTO" """ @@ -1242,6 +1250,25 @@ def _apply_kernel_options( ): kernel_options = {} if kernel_options is None else dict(kernel_options) + if "BACKEND" in kernel_options and kernel_options.get( + "FORCE_USE_FLEX_ATTENTION", False + ): + # TODO: remove FORCE_USE_FLEX_ATTENTION once BACKEND is fully adopted. + raise RuntimeError( + "BACKEND cannot be combined with legacy FORCE_USE_FLEX_ATTENTION. " + "BACKEND supersedes the legacy knob; please drop FORCE_USE_FLEX_ATTENTION " + "and only specify the desired BACKEND." + ) + + if "BACKEND" in kernel_options: + valid_backends = typing.get_args(_Backend) + if kernel_options["BACKEND"] not in valid_backends: + raise ValueError( + f"Invalid BACKEND value '{kernel_options['BACKEND']}'. " + f"Must be one of {valid_backends}" + ) + + kernel_options.setdefault("BACKEND", "AUTO") kernel_options.setdefault("PRESCALE_QK", False) kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 07fec131d618a..d31b99a59de21 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2613,7 +2613,9 @@ def embedding_bag( :attr:`offsets`, if those are not None. include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. - The last element is the size of the input, or the ending index position of the last bag (sequence). + The last element is the size of the input, or the ending index position + of the last bag (sequence). This matches the CSR format. Ignored when + input is 2D. Default ``False``. 
padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated @@ -2724,7 +2726,7 @@ def embedding_bag( offsets = torch.arange( 0, input.numel(), input.size(1), dtype=input.dtype, device=input.device ) - + include_last_offset = False input = input.reshape(-1) if per_sample_weights is not None: per_sample_weights = per_sample_weights.reshape(-1) @@ -6640,6 +6642,52 @@ def multi_head_attention_forward( return attn_output, None +def grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + *, + offs: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + r""" + grouped_mm(mat_a, mat_b, *, offs=None, bias=None, out_dtype=None) + + Computes a grouped matrix multiply that shares weight shapes across experts but + allows jagged token counts per expert, which is common in Mixture-of-Experts + (MoE) layers. Both ``mat_a`` and ``mat_b`` must be 2D or 3D tensors that already + satisfy the physical layout restrictions of grouped GEMM kernels (e.g., row-major + ``mat_a`` and column-major ``mat_b`` for FP8 inputs). Inputs are currently + expected to be ``torch.bfloat16`` values on CUDA devices with :math:`SM \ge 80`. + + Args: + mat_a: Left operand. When 2D, its leading dimension is sliced into groups + according to ``offs``. When 3D, its first dimension enumerates the groups + directly and ``offs`` must be ``None``. + mat_b: Right operand. When both operands are 2D (e.g., MoE weight-gradient + updates), the trailing dimension of ``mat_a`` and the leading dimension of + ``mat_b`` are partitioned according to the same ``offs`` tensor. For the + common forward pass (``out = input @ weight.T``) ``mat_b`` is 3D with + shape ``(num_groups, N, K)``. + offs: Optional 1D tensor of monotonically increasing ``int32`` offsets that + delimit the jagged dimension of any 2D operand. ``offs[i]`` marks the end + of group ``i`` and ``offs[-1]`` must be strictly less than the total + length of that operand's sliced dimension; elements beyond ``offs[-1]`` + are ignored. + bias: Optional tensor that is added to the grouped outputs. Bias is not + jagged and must be broadcastable to the result shape of each group. + out_dtype: Optional dtype that controls the accumulation/output dtype. + Passing ``torch.float32`` accumulates BF16 inputs in FP32 while keeping + the grouped GEMM API non-differentiable. + + Returns: + A tensor containing the concatenated results of each per-group GEMM with + shape inferred from the operands and ``offs``. + """ + + return torch._grouped_mm(mat_a, mat_b, offs=offs, bias=bias, out_dtype=out_dtype) + + def scaled_mm( mat_a: Tensor, mat_b: Tensor, diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index 5a3e24b115df7..58e2d65a81175 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -733,6 +733,17 @@ def scaled_mm( __all__ += ["scaled_mm"] +def grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + *, + offs: Tensor | None = None, + bias: Tensor | None = None, + out_dtype: _dtype | None = None, +) -> Tensor: ... 
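To make the grouping semantics of the new ``grouped_mm`` docstring above concrete, here is a minimal per-group reference written with plain matmuls. It is only a sketch of the documented behavior under hypothetical shapes and a hypothetical helper name (``grouped_mm_reference``): it mirrors the per-group ``input @ weight.T`` forward-pass case described in the docstring and ignores the dtype, layout, and device restrictions of the real fused grouped-GEMM path.

import torch


def grouped_mm_reference(mat_a, mat_b, offs):
    # mat_a: (total_rows, K), mat_b: (num_groups, N, K),
    # offs: int32 offsets where offs[i] marks the end of group i along mat_a's rows.
    outs, start = [], 0
    for g, end in enumerate(offs.tolist()):
        # Per-group "input @ weight.T", the forward-pass case described in the docstring.
        outs.append(mat_a[start:end] @ mat_b[g].transpose(-2, -1))
        start = end
    return torch.cat(outs, dim=0)


# Hypothetical sizes: three groups with 4, 2, and 5 rows sharing (N=16, K=8) weights.
mat_a = torch.randn(11, 8)
mat_b = torch.randn(3, 16, 8)
offs = torch.tensor([4, 6, 11], dtype=torch.int32)
out = grouped_mm_reference(mat_a, mat_b, offs)  # shape (11, 16)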
+ +__all__ += ["grouped_mm"] + class SwizzleType(Enum): NO_SWIZZLE = 0 SWIZZLE_32_4_4 = 1 diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 60bd561bfd0e4..4a7302d5cae33 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -301,7 +301,9 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if num_channels % num_groups != 0: - raise ValueError("num_channels must be divisible by num_groups") + raise ValueError( + f"num_channels ({num_channels}) must be divisible by num_groups ({num_groups})" + ) self.num_groups = num_groups self.num_channels = num_channels diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py index e3b8fafa6a274..83a8d6ef334bb 100644 --- a/torch/nn/modules/sparse.py +++ b/torch/nn/modules/sparse.py @@ -304,8 +304,10 @@ class EmbeddingBag(Module): sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See Notes for more details regarding sparse gradients. Note: this option is not supported when ``mode="max"``. - include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element - is equivalent to the size of `indices`. This matches the CSR format. + include_last_offset (bool, optional): if ``True``, the size of offsets is equal to the number of bags + 1. + The last element is the size of the input, or the ending index position + of the last bag (sequence). This matches the CSR format. Ignored when + input is 2D. Default ``False``. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". For a newly constructed diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index f1f1ac6c67e40..77e2e3049fb31 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -9,6 +9,7 @@ import logging import operator import pathlib +import sys import textwrap import traceback import typing @@ -182,6 +183,9 @@ def _get_cbytes(self): ).from_address(tensor.data_ptr()) def tobytes(self) -> bytes: + # On big-endian machines, call the super's tobytes() which returns a little-endian result. + if sys.byteorder == "big": + return super().tobytes() # Implement tobytes to support native PyTorch types so we can use types like bloat16 # Reading from memory directly is also more efficient because # it avoids copying to a NumPy array @@ -189,6 +193,9 @@ def tobytes(self) -> bytes: return bytes(data) def tofile(self, file) -> None: + # On big-endian machines, call the super's tofile() which returns a little-endian result. 
+ if sys.byteorder == "big": + return super().tofile(file) _, data = self._get_cbytes() return file.write(data) diff --git a/torch/overrides.py b/torch/overrides.py index dea75f69ea49b..22dfb67b825cc 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -252,6 +252,7 @@ def get_ignored_functions() -> set[Callable]: torch.nn.functional.has_torch_function_unary, torch.nn.functional.has_torch_function_variadic, torch.nn.functional.handle_torch_function, + torch.nn.functional.grouped_mm, torch.nn.functional.scaled_grouped_mm, torch.nn.functional.scaled_mm, torch.nn.functional.sigmoid, diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index a66c14adfe86f..a4d8e7f752505 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs # pyrefly: ignore [missing-module-attribute] +import sys from pickle import ( # type: ignore[attr-defined] _compat_pickle, _extension_registry, @@ -64,7 +65,19 @@ def save_global(self, obj, name=None): raise PicklingError(f"Can't pickle {obj}: {str(err)}") from err module = self.importer.import_module(module_name) - _, parent = _getattribute(module, name) + if sys.version_info >= (3, 14): + # pickle._getattribute signature changes in 3.14 + # to take iterable and return just the object (not tuple) + # We need to get the parent object that contains the attribute + name_parts = name.split(".") + if "" in name_parts: + raise PicklingError(f"Can't pickle local object {obj!r}") + if len(name_parts) == 1: + parent = module + else: + parent = _getattribute(module, name_parts[:-1]) + else: + _, parent = _getattribute(module, name) # END CHANGED if self.proto >= 2: # type: ignore[attr-defined] diff --git a/torch/package/importer.py b/torch/package/importer.py index fc0e735890634..83a896c69a629 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -181,6 +181,12 @@ def import_module(self, module_name: str): return importlib.import_module(module_name) def whichmodule(self, obj: Any, name: str) -> str: + # In Python 3.14+, pickle.whichmodule tries to import the module, + # which fails for mangled package names like ''. + # Check __module__ first before calling pickle.whichmodule. + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return module_name return _pickle_whichmodule(obj, name) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f3400e438a2d3..c52bd0f9ce2bb 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -9,7 +9,7 @@ from enum import Enum from functools import partial from typing import Any, Optional -from typing_extensions import Self +from typing_extensions import deprecated, Self from warnings import warn import torch @@ -38,6 +38,14 @@ ] PROFILER_STEP_NAME = "ProfilerStep" +_WARNINGS_SHOWN = set() + + +def _warn_once(msg, category=UserWarning, stacklevel=2): + if msg not in _WARNINGS_SHOWN: + _WARNINGS_SHOWN.add(msg) + warn(msg, category=category, stacklevel=stacklevel) + class _NumpyEncoder(json.JSONEncoder): """ @@ -205,6 +213,12 @@ def prepare_trace(self) -> None: acc_events=self.acc_events, custom_trace_id_callback=self.custom_trace_id_callback, ) + if (self.profiler is not None) and (not self.acc_events): + _warn_once( + "Warning: Profiler clears events at the end of each cycle." + "Only events from the current cycle will be reported." + "To keep events across cycles, set acc_events=True." 
+ ) self.profiler._prepare_trace() def start_trace(self) -> None: @@ -408,6 +422,11 @@ def _memory_profile(self) -> MemoryProfile: ) return MemoryProfile(self.profiler.kineto_results) + @deprecated( + "`export_memory_timeline` is deprecated and will be removed in a future version. " + "Please use `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._export_memory_snapshot` instead.", + category=FutureWarning, + ) def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: """Export memory event information from the profiler collected tree for a given device, and export a timeline plot. There are 3 @@ -429,6 +448,11 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non ``torch.profiler._memory_profiler.Category``. Output: Memory timeline written as gzipped JSON, JSON, or HTML. + + .. deprecated:: + ``export_memory_timeline`` is deprecated and will be removed in a future version. + Please use ``torch.cuda.memory._record_memory_history`` and + ``torch.cuda.memory._export_memory_snapshot`` instead. """ # Default to device 0, if unset. Fallback on cpu. if device is None: diff --git a/torch/random.py b/torch/random.py index cf23e52db320e..f86d7349019dc 100644 --- a/torch/random.py +++ b/torch/random.py @@ -39,6 +39,10 @@ def manual_seed(seed) -> torch._C.Generator: is raised. Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`. """ + return _manual_seed_impl(seed, update_local_tensor_states=True) + + +def _manual_seed_impl(seed, update_local_tensor_states) -> torch._C.Generator: seed = int(seed) import torch.cuda diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 4b731e7f7c37f..5481cd0a53ee7 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -89,7 +89,7 @@ def evaluate_platform_supports_green_context(): driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run) if driver_version is None: return False - return int(driver_version.split('.')[0]) >= 550 + return int(driver_version.split('.')[0]) >= 570 PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention()) PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention()) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6724ab2ae739a..0cf0f50c23ef5 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6185,6 +6185,9 @@ def sample_inputs_std_var(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(tensor_nd(), dim=(1,), correction=S // 2) yield SampleInput(tensor_nd(), dim=None, correction=0, keepdim=True) yield SampleInput(tensor_nd(), dim=None, correction=None) + yield SampleInput(tensor_nd(), dim=None, correction=-1) + yield SampleInput(tensor_nd(), dim=None, correction=-5) + yield SampleInput(tensor_nd(), correction=0.5, keepdim=True) yield SampleInput(tensor_nd(), correction=0, keepdim=True) yield SampleInput(make_tensor(3, 4, 5, device=device, dtype=dtype, requires_grad=requires_grad), dim=-3) @@ -11701,11 +11704,11 @@ def reference_mse_loss(input, target, reduction="mean"): return se -def reference_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, bias=None, eps=1e-5): +def reference_layer_norm(inp: npt.NDArray, 
normalized_shape: tuple[int, ...], weight=None, bias=None, eps=1e-5): return reference_native_layer_norm(inp, normalized_shape, weight, bias, eps)[0] -def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight, bias, eps): +def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight, bias, eps): feature_size = np.prod(normalized_shape) inp_view = inp.reshape(-1, feature_size) # type: ignore[call-overload] mean = inp_view.mean(axis=-1, keepdims=True) @@ -11722,7 +11725,7 @@ def reference_native_layer_norm(inp: npt.NDArray, normalized_shape: tuple[int], return Y.reshape(*inp.shape), mean.reshape(stat_shape), (1.0 / np.sqrt(var + eps)).reshape(stat_shape) -def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int], weight=None, eps=None): +def reference_rms_norm(inp: npt.NDArray, normalized_shape: tuple[int, ...], weight=None, eps=None): if eps is None: eps = torch.finfo(numpy_to_torch_dtype(inp.dtype)).eps feature_size = np.prod(normalized_shape) diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 9571cc1209ed6..83fca0b973856 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -1769,6 +1769,32 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, ] +def module_error_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs): + """ + Error inputs for GroupNorm that test error messages include actual values. + """ + return [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(3, 10), # num_groups=3, num_channels=10 + forward_input=FunctionInput(), # Not needed for construction error + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=ValueError, + error_regex=r"num_channels \(10\) must be divisible by num_groups \(3\)" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(5, 13), # num_groups=5, num_channels=13 + forward_input=FunctionInput(), # Not needed for construction error + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=ValueError, + error_regex=r"num_channels \(13\) must be divisible by num_groups \(5\)" + ), + ] + + def module_inputs_torch_nn_Hardshrink(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3958,6 +3984,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.GroupNorm, module_inputs_func=module_inputs_torch_nn_GroupNorm, + module_error_inputs_func=module_error_inputs_torch_nn_GroupNorm, dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=True), skips=( # Tracking at https://github.com/pytorch/pytorch/issues/98089 diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index d5afc413daed8..815cc8859080f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1370,8 +1370,6 @@ class XMLTestResultVerbose(_XMLTestResult): This works with unittest_xml_reporting<=3.2.0,>=2.0.0 (3.2.0 is latest at the moment) """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) def addSkip(self, test, reason): super().addSkip(test, reason) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 6ce7d4b2ca507..1f6c4aece1e80 100644 --- 
a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -386,7 +386,7 @@ def device_type(self) -> str: @property def backend(self) -> str: - backend = dist.get_default_backend_for_device(DEVICE_TYPE) + backend = dist.get_default_backend_for_device(self.device_type) return backend def init_manual_seed_for_rank(self) -> None: @@ -724,6 +724,9 @@ def setUp(self) -> None: torch.autograd._enable_record_function(False) def tearDown(self) -> None: + from torch.distributed.tensor import _random as random + + random._rng_tracker = None super().tearDown() torch.autograd._enable_record_function(True) diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 6bd34c812d641..cda1908a3a340 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -34,7 +34,7 @@ ) from torch.fx.experimental.proxy_tensor import make_fx from torch.utils._helion import has_helion -from torch.utils._pallas import has_pallas +from torch.utils._pallas import has_pallas, has_tpu_pallas from torch.utils._triton import has_triton from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( @@ -105,6 +105,8 @@ def test_cpu(): getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases ) +HAS_TPU = has_tpu_pallas() + def _check_has_dynamic_shape( self: TestCase, diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index ae5a468ddd6ae..da75f82815507 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -382,9 +382,6 @@ def sample_inputs_linalg_norm( elif is_matrix_norm: dims_to_check = { None: (0,), - np.inf: (0,), - 2: (0, 1), - 1: (1,), -1: (1,), -2: (0, 1), -np.inf: (0,), @@ -395,6 +392,18 @@ def sample_inputs_linalg_norm( # have non-zero size. continue + no_grad_dims_to_check = { + np.inf: (0,), + 2: (0, 1), + 1: (1,), + }.get(ord, ()) + + if ( + any(test_size[d] == 0 for d in no_grad_dims_to_check) + and requires_grad + ): + continue + if variant == "subgradient_at_zero": yield SampleInput( torch.zeros( diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 745b05d1904d7..0b853997261a9 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -924,7 +924,9 @@ def dispatch_hook(func, types, args, kwargs, result): @staticmethod @contextlib.contextmanager def log_tensor_hashes( - hash_fn: Union[Callable, str, list[str]] = "norm", hash_inputs: bool = False + hash_fn: Union[Callable, str, list[str]] = "norm", + hash_inputs: bool = False, + wait_on_collectives: bool = True, ): """ Installs hook for tensor hash logging. @@ -936,6 +938,7 @@ def log_tensor_hashes( - "hash_tensor": uses torch.hash_tensor (XOR sum reduction) - List of strings: returns tuple of hashes from above options hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash". + wait_on_collectives: if True (default), waits on async collective Work handles before hashing. NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes. 
""" @@ -966,6 +969,12 @@ def _dispatch_hash_hook(func, types, args, kwargs, result): if "empty" in str(func) or "profiler" in str(func): return None + # Wait on async collective Work handles before hashing + if wait_on_collectives and isinstance(result, (tuple, list)): + for item in result: + if isinstance(item, torch.ScriptObject) and hasattr(item, "wait"): + item.wait() + out = {} out["hash"] = _tree_hash(result) if hash_inputs: diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py index 2d93e7f32c58e..63ef22be49cf4 100644 --- a/torch/utils/_pallas.py +++ b/torch/utils/_pallas.py @@ -72,6 +72,18 @@ def has_jax_tpu_backend() -> bool: return False +@functools.cache +def has_tpu_pallas() -> bool: + """Checks for a full Pallas-on-TPU environment.""" + return has_pallas_package() and has_jax_tpu_backend() + + +@functools.cache +def has_cuda_pallas() -> bool: + """Checks for a full Pallas-on-CUDA environment.""" + return has_pallas_package() and torch.cuda.is_available() and has_jax_cuda_backend() + + @functools.cache def has_pallas() -> bool: """ @@ -82,20 +94,4 @@ def has_pallas() -> bool: - Pallas (jax.experimental.pallas) available - A compatible backend (CUDA or TPU) is available in both PyTorch and JAX. """ - if not has_pallas_package(): - return False - - # Check for is CUDA is available or if JAX has GPU/CUDA backend - has_cuda = torch.cuda.is_available() and has_jax_cuda_backend() - - # Check for TPU backend - has_tpu_torch = False - try: - import torch_xla.core.xla_model as xm - - has_tpu_torch = xm.xla_device_count() > 0 - except ImportError: - pass - has_tpu = has_tpu_torch and has_jax_tpu_backend() - - return has_cuda or has_tpu + return has_cuda_pallas() or has_tpu_pallas() diff --git a/torch/utils/_runtime_estimation.py b/torch/utils/_runtime_estimation.py new file mode 100644 index 0000000000000..fcda7cceaee48 --- /dev/null +++ b/torch/utils/_runtime_estimation.py @@ -0,0 +1,151 @@ +import math +import os + +import torch +from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps +from torch.utils._ordered_set import OrderedSet + +from .flop_counter import flop_registry + + +aten = torch.ops.aten + +_FLOAT_TYPES = OrderedSet( + [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + ] +) + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + +# No fall-back kernel needed/exists for view ops +_VIEW_OPS = OrderedSet( + [ + aten.lift_fresh, + aten.t, + aten.transpose, + aten.view, + aten.detach, + aten._unsafe_view, + aten.split, + aten.adjoint, + aten.as_strided, + aten.diagonal, + aten.expand, + aten.expand_as, + aten.movedim, + aten.permute, + aten.select, + aten.squeeze, + aten.mT, + aten.mH, + aten.real, + aten.imag, + aten.view_as, + aten.unflatten, + aten.unfold, + aten.unbind, + aten.unsqueeze, + aten.vsplit, + aten.hsplit, + aten.split_with_sizes, + aten.swapaxes, + aten.swapdims, + aten.chunk, + ] +) +# We can ignore benchmarking tensor create ops +_CREATE_OPS = OrderedSet( + [ + aten.randint, + aten.randn, + aten.rand, + aten.randn_like, + aten.rand_like, + aten.randint_like, + aten.arange, + aten.ones_like, + aten.zeros_like, + ] +) + +_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS + + +def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an 
aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert len(out_dtypes) == 1, ( + f"Only support single out dtype got {out_dtypes} for {func_packet}" + ) + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + +def get_num_bytes(t: torch.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (torch.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem_consumed + + +def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[torch.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. 
+ """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) for t in flat_args_kwargs if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 5d0f252d95b30..b8eb98be15a11 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -9239,6 +9239,8 @@ API_PYTORCH, ), ), + ("cuda::CUDAEvent", ("hip::HIPEventMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAEvent", ("HIPEventMasqueradingAsCUDA", API_PYTORCH)), ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), ( @@ -9293,6 +9295,14 @@ "c10/cuda/CUDACachingAllocator.h", ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH), ), + ( + "ATen/cuda/CUDAEvent.h", # To keep BC, we have to keep this mapping + ("ATen/hip/HIPEvent.h", API_PYTORCH), + ), + ( + "c10/cuda/CUDAEvent.h", + ("ATen/hip/impl/HIPEventMasqueradingAsCUDA.h", API_PYTORCH), + ), ( "c10/cuda/CUDAStream.h", ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH), @@ -9433,6 +9443,7 @@ ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), + ("c10/cuda/CUDAEvent.h", ("c10/hip/HIPEvent.h", API_C10)), ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), ("c10/cuda/CUDAAllocatorConfig.h", ("c10/hip/HIPAllocatorConfig.h", API_C10)), diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 8abb547d500f8..df4bf34db2114 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -249,6 +249,8 @@ def format_sequence(obj): if len(filename) > FRAME_FILENAME_LIMIT: filename = "..." 
+ filename[-(FRAME_FILENAME_LIMIT - 3):] return f"frame\n{filename}:{obj.f_lineno}" + elif is_cuda_tensor(obj): + return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})" else: return f"object\n{type(obj).__module__}.{type(obj).__name__}" diff --git a/torch/version.py.tpl b/torch/version.py.tpl index 1b7eab07ac949..ee37a91b7ffdc 100644 --- a/torch/version.py.tpl +++ b/torch/version.py.tpl @@ -4,6 +4,7 @@ cuda = '{{CUDA_VERSION}}' # TODO: use workspace status to stamp the correct version git_version = "" hip = None +rocm = None # This is a gross monkey-patch hack that depends on the order of imports # in torch/__init__.py diff --git a/torchgen/_autoheuristic/ah_tree.py b/torchgen/_autoheuristic/ah_tree.py index c2ec2b8d94788..0afc8751e6b82 100644 --- a/torchgen/_autoheuristic/ah_tree.py +++ b/torchgen/_autoheuristic/ah_tree.py @@ -7,8 +7,8 @@ class DecisionTreeNode: def __init__( self, - feature: Optional[str] = None, - threshold: Optional[float] = None, + feature: str | None = None, + threshold: float | None = None, left: Optional["DecisionTreeNode"] = None, right: Optional["DecisionTreeNode"] = None, class_probs: Any = None, diff --git a/torchgen/context.py b/torchgen/context.py index e3725d66b9643..a99d7119c656f 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -2,7 +2,7 @@ import contextlib import functools -from typing import Any, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, TYPE_CHECKING, TypeVar import torchgen.local as local from torchgen.model import ( @@ -26,15 +26,15 @@ NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, - Union[NativeFunction, NativeFunctionsGroup], - Union[NativeFunction, NativeFunctionsViewGroup], + NativeFunction | NativeFunctionsGroup, + NativeFunction | NativeFunctionsViewGroup, ) F2 = TypeVar( "F2", NativeFunction, NativeFunctionsGroup, - Optional[NativeFunction], + NativeFunction | None, bool, str, ) diff --git a/torchgen/gen.py b/torchgen/gen.py index ae0e4b52a0fc8..2bc9ed6996705 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -97,7 +97,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - from typing import Optional T = TypeVar("T") @@ -2218,7 +2217,7 @@ def gen_source_files( per_operator_headers: bool, skip_dispatcher_op_registration: bool, update_aoti_c_shim: bool, - aoti_backends: set[Optional[DispatchKey]], + aoti_backends: set[DispatchKey | None], extend_aoti_c_shim: bool, ) -> None: extra_cuda_headers = """\ diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index ead2a2a1cf4cc..e0724f6c3959b 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -31,7 +31,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - from typing import Optional base_type_to_c_type = { @@ -393,7 +392,7 @@ def gen_static_dispatch_backend_call_signature( def gen_static_dispatch_backend_call( f: NativeFunction, - backend_index: Optional[BackendIndex] = None, + backend_index: BackendIndex | None = None, ) -> str: sig = DispatcherSignature.from_schema(f.func) cpp_sig = gen_static_dispatch_backend_call_signature(sig, f) @@ -421,7 +420,7 @@ def gen_static_dispatch_backend_call( def get_backend_index_for_aoti( func: NativeFunction, func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], extend_aoti_c_shim: bool, ) -> BackendIndex | None: @@ -463,7 +462,7 @@ def get_backend_index_for_aoti( def get_header_for_aoti( 
func: NativeFunction, func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], extend_aoti_c_shim: bool, ) -> str | None: @@ -490,7 +489,7 @@ def gen_c_shim( func: NativeFunction, version_info: dict[str, list[str]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], header: bool, extend_aoti_c_shim: bool, @@ -528,7 +527,7 @@ def gen_c_shim( class ShimGenerator: inductor_fallback_ops: dict[str, dict[str, list[str]]] func_group_mapping: dict[OperatorName, NativeFunctionsGroup] - dispatch_key: Optional[DispatchKey] + dispatch_key: DispatchKey | None backend_indices: dict[DispatchKey, BackendIndex] header: bool # True to generate .h and False to generate .cpp extend_aoti_c_shim: bool @@ -555,7 +554,7 @@ def gen_aoti_c_shim( native_functions: Sequence[NativeFunction], inductor_fallback_ops: dict[str, dict[str, list[str]]], func_group_mapping: dict[OperatorName, NativeFunctionsGroup], - dispatch_key: Optional[DispatchKey], + dispatch_key: DispatchKey | None, backend_indices: dict[DispatchKey, BackendIndex], header: bool, extend_aoti_c_shim: bool, @@ -646,7 +645,7 @@ def gen_aoti_c_shim( def gen_aoti_c_shim_files( aoti_fm: FileManager, - aoti_backends: set[Optional[DispatchKey]], + aoti_backends: set[DispatchKey | None], native_functions: Sequence[NativeFunction], backend_indices: dict[DispatchKey, BackendIndex], structured_native_functions: Sequence[NativeFunctionsGroup], diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 1cb681ba19d34..0ef91332df9ff 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from torchgen.api import cpp, dispatcher, functionalization from torchgen.api.translate import translate @@ -928,7 +928,7 @@ def new(self, out_index: str = "0") -> str: def map( g: NativeFunctionsViewGroup, run: Callable[[ViewMetaSpecialization], list[str]] ) -> list[str]: - def maybe_run(f: Optional[NativeFunction]) -> list[str]: + def maybe_run(f: NativeFunction | None) -> list[str]: if f is None: return [] with native_function_manager(f): diff --git a/torchgen/gen_schema_utils.py b/torchgen/gen_schema_utils.py index b81c91527baa1..1238a5a5a3933 100644 --- a/torchgen/gen_schema_utils.py +++ b/torchgen/gen_schema_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any from torchgen.model import ( Annotation, @@ -29,7 +29,7 @@ class TypeGen: } @staticmethod - def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]: + def from_example(obj: Any) -> BaseType | ListType | CustomClassType: import torch if isinstance(obj, torch.fx.GraphModule): @@ -61,7 +61,7 @@ def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]: class ReturnGen: @staticmethod def from_example( - name: Optional[str], obj: Any, annotation: Optional[Annotation] + name: str | None, obj: Any, annotation: Annotation | None ) -> Return: return Return(name, TypeGen.from_example(obj), annotation) @@ -69,7 +69,7 @@ def from_example( class ArgumentGen: @staticmethod def from_example( - name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation] + 
name: str, obj: Any, default: str | None, annotation: Annotation | None ) -> Argument: return Argument( name, TypeGen.from_example(obj), default=default, annotation=annotation diff --git a/torchgen/model.py b/torchgen/model.py index 906b61e2f19cc..7971b893e7585 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -5,7 +5,7 @@ import re from dataclasses import dataclass from enum import auto, Enum -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from typing_extensions import assert_never from torchgen.utils import NamespaceHelper, OrderedSet @@ -2563,7 +2563,7 @@ class BaseOperatorName: # as part of the base operator name, for __str__() to consume. # The canonical input (from the rest of the infra) will not contain namespace, but # we have a usecase in ExecuTorch where we want to support BaseOperatorName with namespace. - namespace: Optional[str] = None + namespace: str | None = None @staticmethod def parse(op: str) -> BaseOperatorName: diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py index e35221c3f50eb..d6909bc4d7f67 100644 --- a/torchgen/static_runtime/gen_static_runtime_ops.py +++ b/torchgen/static_runtime/gen_static_runtime_ops.py @@ -3,7 +3,7 @@ import argparse import itertools import os -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar from libfb.py.log import set_simple_logging # type: ignore[import] @@ -23,7 +23,7 @@ NativeGroupT = TypeVar( "NativeGroupT", - bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup], + bound=NativeFunctionsGroup | NativeFunctionsViewGroup, )
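The torchgen edits above replace ``typing.Optional``/``typing.Union`` with PEP 604 ``X | Y`` unions. The sketch below, with hypothetical helper names, shows the equivalence under the assumption that the runtime is Python 3.10+, which is needed when the union is evaluated outside annotations (for example in the rewritten ``TypeVar`` bound); with ``from __future__ import annotations``, the annotation-only uses also parse on older interpreters.

from __future__ import annotations

from typing import Optional, TypeVar, Union


def old_style(x: Optional[int]) -> Union[int, str]:
    # Pre-PEP 604 spelling, as removed above.
    return "none" if x is None else x


def new_style(x: int | None) -> int | str:
    # PEP 604 spelling, as introduced above.
    return "none" if x is None else x


# Outside annotations the union is evaluated eagerly, so this line needs
# Python 3.10+, mirroring the rewritten TypeVar bound.
T = TypeVar("T", bound=int | str)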