diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index ce7803cf9acd..3bc3fd8badc6 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8 ENV LANG en_US.UTF-8 ENV LANGUAGE en_US.UTF-8 -ARG DEVTOOLSET_VERSION=11 +ARG DEVTOOLSET_VERSION=13 RUN yum -y update RUN yum -y install epel-release # install glibc-langpack-en make sure en_US.UTF-8 locale is available RUN yum -y install glibc-langpack-en -RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain +RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb # Just add everything as a safe.directory for git since these will be used in multiple places with git RUN git config --global --add safe.directory '*' ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH @@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh # Install CUDA FROM base as cuda ARG CUDA_VERSION=12.6 +ARG DEVTOOLSET_VERSION=13 RUN rm -rf /usr/local/cuda-* ADD ./common/install_cuda.sh install_cuda.sh COPY ./common/install_nccl.sh install_nccl.sh @@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION} # Preserve CUDA_VERSION for the builds ENV CUDA_VERSION=${CUDA_VERSION} # Make things in our path by default -ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH +ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH + FROM cuda as cuda12.6 RUN bash ./install_cuda.sh 12.6 @@ -68,8 +70,22 @@ FROM cuda as cuda13.0 RUN bash ./install_cuda.sh 13.0 ENV DESIRED_CUDA=13.0 -FROM ${ROCM_IMAGE} as rocm +FROM ${ROCM_IMAGE} as rocm_base +ARG DEVTOOLSET_VERSION=13 +ENV LC_ALL en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 +# Install devtoolset on ROCm base image +RUN yum -y update && \ + yum -y install epel-release && \ + yum -y install glibc-langpack-en && \ + yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb +RUN git config --global --add safe.directory '*' +ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH + +FROM rocm_base as rocm ARG PYTORCH_ROCM_ARCH +ARG DEVTOOLSET_VERSION=13 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} ADD ./common/install_mkl.sh install_mkl.sh RUN bash ./install_mkl.sh && rm install_mkl.sh @@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0 # Final step FROM ${BASE_TARGET} as final +ARG DEVTOOLSET_VERSION=13 COPY --from=openssl /opt/openssl /opt/openssl COPY --from=patchelf /patchelf /usr/local/bin/patchelf COPY --from=conda /opt/conda /opt/conda diff --git a/.ci/docker/almalinux/build.sh b/.ci/docker/almalinux/build.sh index ad234ce1ffb9..885c4440e0e6 100755 --- a/.ci/docker/almalinux/build.sh +++ b/.ci/docker/almalinux/build.sh @@ -63,7 +63,7 @@ docker build \ --target final \ --progress plain \ --build-arg "BASE_TARGET=${BASE_TARGET}" \ - --build-arg "DEVTOOLSET_VERSION=11" \ + --build-arg "DEVTOOLSET_VERSION=13" \ ${EXTRA_BUILD_ARGS} \ -t ${tmp_tag} \ $@ \ diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index d0500b89780c..f0b9a788758c 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -261,9 +261,9 @@ case "$tag" in PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 ;; - pytorch-linux-jammy-aarch64-py3.10-gcc11) + pytorch-linux-jammy-aarch64-py3.10-gcc13) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + GCC_VERSION=13 ACL=yes VISION=yes OPENBLAS=yes @@ -271,9 +271,19 @@ case "$tag" in # from pytorch/llvm:9.0.1 is x86 specific SKIP_LLVM_SRC_BUILD_INSTALL=yes ;; - pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks) + pytorch-linux-jammy-aarch64-py3.10-clang21) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + CLANG_VERSION=21 + ACL=yes + VISION=yes + OPENBLAS=yes + # snadampal: skipping llvm src build install because the current version + # from pytorch/llvm:9.0.1 is x86 specific + SKIP_LLVM_SRC_BUILD_INSTALL=yes + ;; + pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks) + ANACONDA_PYTHON_VERSION=3.10 + GCC_VERSION=13 ACL=yes VISION=yes OPENBLAS=yes diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index d893bdd32ab3..2bc3043f3008 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1,5 @@ +<<<<<<< HEAD ac80c4190aa0321f761a08af97e1e1eee41f01d9 +======= +bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 +>>>>>>> upstream/main diff --git a/.ci/docker/common/install_clang.sh b/.ci/docker/common/install_clang.sh index 1cb216edf1b3..93daeee919b3 100755 --- a/.ci/docker/common/install_clang.sh +++ b/.ci/docker/common/install_clang.sh @@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then # work around ubuntu apt-get conflicts sudo apt-get -y -f install wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - if [[ $CLANG_VERSION == 18 ]]; then - apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main" + if [[ $CLANG_VERSION -ge 18 ]]; then + apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main" fi fi diff --git a/.ci/docker/common/install_gcc.sh b/.ci/docker/common/install_gcc.sh index 3b96bf6e0ed2..df1c059bc386 100644 --- a/.ci/docker/common/install_gcc.sh +++ b/.ci/docker/common/install_gcc.sh @@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then # Need the official toolchain repo to get alternate packages add-apt-repository ppa:ubuntu-toolchain-r/test apt-get update - apt-get install -y g++-$GCC_VERSION + apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 - + update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50 # Cleanup package manager apt-get autoclean && apt-get clean diff --git a/.ci/docker/common/install_openblas.sh b/.ci/docker/common/install_openblas.sh index 2f386c6bd523..5a2806878124 100755 --- a/.ci/docker/common/install_openblas.sh +++ b/.ci/docker/common/install_openblas.sh @@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" - OPENBLAS_CHECKOUT_DIR="OpenBLAS" OPENBLAS_BUILD_FLAGS=" +CC=gcc NUM_THREADS=128 USE_OPENMP=1 NO_SHARED=0 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 1545d966571d..d5c0c9914289 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.0 +3.5.1 diff --git a/.ci/magma-rocm/build_magma.sh b/.ci/magma-rocm/build_magma.sh index 7d95fed873dc..c7c7780227ea 100755 --- a/.ci/magma-rocm/build_magma.sh +++ b/.ci/magma-rocm/build_magma.sh @@ -6,8 +6,8 @@ set -eou pipefail # The script expects DESIRED_CUDA and PACKAGE_NAME to be set ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -# post merge of https://github.com/icl-utk-edu/magma/pull/65 -MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f +# https://github.com/icl-utk-edu/magma/pull/65 +MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec # Folders for the build PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata @@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE # Fetch magma sources and verify checksum pushd ${PACKAGE_DIR} -git clone https://github.com/icl-utk-edu/magma +git clone https://github.com/jeffdaily/magma pushd magma git checkout ${MAGMA_VERSION} popd diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9ae257875893..26996b5a32d5 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.github/ISSUE_TEMPLATE/release-feature-request.yml b/.github/ISSUE_TEMPLATE/release-feature-request.yml index 80f10807ae56..090a41d1942f 100644 --- a/.github/ISSUE_TEMPLATE/release-feature-request.yml +++ b/.github/ISSUE_TEMPLATE/release-feature-request.yml @@ -1,11 +1,11 @@ -name: πŸš€ Release highlight for proposed Feature +name: πŸš€ New Feature for Release description: Submit a Release highlight for proposed Feature labels: ["release-feature-request"] body: - type: textarea attributes: - label: Release highlight for proposed Feature + label: New Feature for Release description: > Example: β€œA torch.special module, analogous to SciPy's special module.” - type: input diff --git a/.github/actions/pytest-cache-download/action.yml b/.github/actions/pytest-cache-download/action.yml index 1406f962c4ca..3f51f6a5525b 100644 --- a/.github/actions/pytest-cache-download/action.yml +++ b/.github/actions/pytest-cache-download/action.yml @@ -38,9 +38,9 @@ runs: run: | python3 .github/scripts/pytest_cache.py \ --download \ - --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \ - --pr_identifier $GITHUB_REF \ - --job_identifier $JOB_IDENTIFIER \ - --temp_dir $RUNNER_TEMP \ - --repo $REPO \ - --bucket $BUCKET \ + --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ + --pr_identifier "$GITHUB_REF" \ + --job_identifier "$JOB_IDENTIFIER" \ + --temp_dir "$RUNNER_TEMP" \ + --repo "$REPO" \ + --bucket "$BUCKET" \ diff --git a/.github/actions/pytest-cache-upload/action.yml b/.github/actions/pytest-cache-upload/action.yml index 2652d019075f..9fbb63a760f2 100644 --- a/.github/actions/pytest-cache-upload/action.yml +++ b/.github/actions/pytest-cache-upload/action.yml @@ -47,11 +47,11 @@ runs: run: | python3 .github/scripts/pytest_cache.py \ --upload \ - --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \ - --pr_identifier $GITHUB_REF \ - --job_identifier $JOB_IDENTIFIER \ - --sha $SHA \ - --test_config $TEST_CONFIG \ - --shard $SHARD \ - --repo $REPO \ - --temp_dir $RUNNER_TEMP \ + --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ + --pr_identifier "$GITHUB_REF" \ + --job_identifier "$JOB_IDENTIFIER" \ + --sha "$SHA" \ + --test_config "$TEST_CONFIG" \ + --shard "$SHARD" \ + --repo "$REPO" \ + --temp_dir "$RUNNER_TEMP" \ diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 966f6bcfc0d9..14144f3c11e2 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2 +ad5816f0eee1c873df1b7d371c69f1f811a89387 diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 000000000000..06c3f32abd5e --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,125 @@ +# PyTorch Copilot Instructions + +This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively. + +## Architecture Overview + +### Core Components + +- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality +- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd + - `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse) + - `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry +- **torch/** - Python bindings and public API + - `torch/csrc/` - C++ Python bindings (hand-written and generated) + - `torch/csrc/autograd/` - Reverse-mode automatic differentiation + - `torch/csrc/jit/` - TorchScript JIT compiler +- **torchgen/** - Code generation tooling that reads `native_functions.yaml` +- **tools/** - Build scripts, autograd derivatives, code generation + +### The Code Generation Workflow + +**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file: +1. Declares operator signatures, variants (function/method), and dispatch behavior +2. Gets processed by `torchgen/` to generate C++/Python bindings +3. Produces headers in `build/aten/src/ATen/` during compilation + +Example entry structure: +```yaml +- func: my_op(Tensor self, Scalar alpha=1) -> Tensor + variants: function, method + dispatch: + CPU: my_op_cpu + CUDA: my_op_cuda +``` + +After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`). + +## Development Workflows + +### Building from Source + +**Never run `setup.py` directly** - use pip with editable install: +```bash +python -m pip install --no-build-isolation -v -e . +``` + +Speed up builds: +- `DEBUG=1` - Debug symbols with `-g -O0` +- `USE_CUDA=0` - Skip CUDA compilation +- `BUILD_TEST=0` - Skip C++ test binaries +- Install `ninja` (`pip install ninja`) for faster builds +- Use `ccache` for incremental compilation caching + +Rebuild specific targets: `(cd build && ninja )` + +### Testing + +**Critical**: DO NOT run entire test suites. Run specific tests only: +```bash +python test/test_torch.py TestTorch.test_specific_case +``` + +**Test structure**: All tests use `torch.testing._internal.common_utils`: +```python +from torch.testing._internal.common_utils import run_tests, TestCase + +class TestFeature(TestCase): + def test_something(self): + # Use self.assertEqual for tensor comparisons + pass + +if __name__ == "__main__": + run_tests() +``` + +**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file. + +### Linting + +Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes) + +## Project-Specific Conventions + +### Memory and Storage +- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs) +- CUDA device info lives in storage objects + +### Python-C++ Integration (`torch/csrc/`) +- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors +- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr` +- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion + +### Dispatch System +- PyTorch uses operator dispatch to route calls to backend-specific kernels +- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops +- See `aten/src/ATen/native/README.md` for dispatch keyword guidance + +## Git Workflow (AI Agent Specific) + +When preparing PRs from this environment: +```bash +git stash -u +git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch +git stash pop +# Resolve conflicts if necessary +``` + +## Common Gotchas + +1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml` +2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI +3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux) +4. **No internet access** - DO NOT attempt to install dependencies during development + +## Key Files Reference + +- `AGENTS.md` - Instructions specific to AI coding agents +- `CONTRIBUTING.md` - Comprehensive human contributor guide +- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript) +- `aten/src/ATen/native/README.md` - Operator implementation guide +- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd + +## Performance Debugging + +Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation. diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 43ed76a63cc6..608aeba53e6d 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -97,8 +97,8 @@ jobs: shell: bash run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') - if [[ $ngpu -lt 4 ]]; then - echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs" + if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus. + echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs" exit 1 fi diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index e68bc6ead3a2..d27325b8a63d 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -344,5 +344,21 @@ jobs: if-no-files-found: ignore path: ./**/core.[1-9]* + - name: Authenticate with AWS + uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results + # The max duration enforced by the server side + role-duration-seconds: 18000 + aws-region: us-east-1 + + - name: Upload the benchmark results + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: test/test-reports + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Teardown XPU uses: ./.github/actions/teardown-xpu diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 6fbe2e846d40..941a045649f3 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -77,9 +77,11 @@ jobs: pytorch-linux-noble-riscv64-py3.12-gcc14 ] include: - - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11 + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13 runner: linux.arm64.m7g.4xlarge - - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21 + runner: linux.arm64.m7g.4xlarge + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 # Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358 diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index e16c8be79130..46a1966570c6 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -72,7 +72,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: linux.arm64.m7g.4xlarge build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks test-matrix: | { include: [ { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" }, diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 6ab276a57fc4..3ce917567aec 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -115,10 +115,10 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, + { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, ]} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 2616141c0dc2..8a913c3b36a1 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -84,13 +84,13 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, + { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} build-additional-packages: "vision audio torchao" diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index 2b840a39a5c2..e6690b104300 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -33,7 +33,7 @@ jobs: with: runner_prefix: ${{ needs.get-label-type.outputs.label-type }} build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 runner: linux.arm64.m7g.4xlarge test-matrix: | { include: [ diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 40fb3b8d0c85..758147f5fe18 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -60,7 +60,7 @@ jobs: with: build-environment: linux-jammy-aarch64-py3.10 runner: linux.arm64.m7g.4xlarge - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 test-matrix: | { include: [ { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, diff --git a/.gitignore b/.gitignore index 3b4323051073..d1b3b17445da 100644 --- a/.gitignore +++ b/.gitignore @@ -127,7 +127,6 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py -torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/CMakeLists.txt b/CMakeLists.txt index ca1e4164be9b..113eae7cc583 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,7 +234,17 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_LSAN "Use Leak Sanitizer" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) + +# Track whether USE_CUDA was explicitly set by the user (before option() is called) +# If USE_CUDA is already defined in cache, it means user explicitly set it +if(DEFINED CACHE{USE_CUDA}) + set(_USE_CUDA_EXPLICITLY_SET TRUE) +else() + set(_USE_CUDA_EXPLICITLY_SET FALSE) +endif() + option(USE_CUDA "Use CUDA" ON) + option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9df55ca6acd5..bc0b0fc9bb00 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -18,7 +18,7 @@ aspects of contributing to PyTorch. - [Python Unit Testing](#python-unit-testing) - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest) - [Local linting](#local-linting) - - [Running `mypy`](#running-mypy) + - [Running `pyrefly`](#running-pyrefly) - [C++ Unit Testing](#c-unit-testing) - [Run Specific CI Jobs](#run-specific-ci-jobs) - [Merging your Change](#merging-your-change) @@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory. **Prerequisites**: The following packages should be installed with `pip`: - `expecttest` and `hypothesis` - required to run tests -- `mypy` - recommended for linting +- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/) - `pytest` - recommended to run tests more selectively Running ``` @@ -350,15 +350,32 @@ make lint Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner) -#### Running `mypy` +#### Running `pyrefly` -`mypy` is an optional static type checker for Python. We have multiple `mypy` -configs for the PyTorch codebase that are automatically validated against whenever the linter is run. +[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback. + +PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository. + +**Getting Started with Pyrefly:** + +To run type checking on the PyTorch codebase: +```bash +pyrefly check +``` + +For more detailed error information with summaries: +```bash +pyrefly check --summarize-errors +``` + +**Learn More:** +- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options +- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking +- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations See [Guide for adding type annotations to PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch) -for more information on how to set up `mypy` and tackle type annotation -tasks. +for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase. ### C++ Unit Testing diff --git a/SECURITY.md b/SECURITY.md index ed8228af3672..375f94547941 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,7 +1,7 @@ # Security Policy - [**Reporting a Vulnerability**](#reporting-a-vulnerability) - - [**Using Pytorch Securely**](#using-pytorch-securely) + - [**Using PyTorch Securely**](#using-pytorch-securely) - [Untrusted models](#untrusted-models) - [TorchScript models](#torchscript-models) - [Untrusted inputs](#untrusted-inputs) @@ -10,28 +10,28 @@ - [**CI/CD security principles**](#cicd-security-principles) ## Reporting Security Issues -Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch. +Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch. However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new -All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. +All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported: https://www.facebook.com/whitehat -## Using Pytorch Securely -**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). +## Using PyTorch Securely +**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). ### Untrusted models Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources]. **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). -**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. +**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs. @@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de ### TorchScript models -TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. +TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. ### Untrusted inputs during training and prediction @@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some ### Data privacy -**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: -- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment) -- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits). +**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: +- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment) +- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits). ### Using distributed features diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 8b283c417b74..ae762e1def3e 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI) if(USE_CUDA) # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build. # If you want to integrate a kernel from FBGEMM into torch, you have to add it here. - set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*") + set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*") file(GLOB_RECURSE fbgemm_genai_native_cuda_cu "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu" "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu") diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index a354b4191240..6bc321887502 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP() #endif namespace at { -namespace { - /* These const variables defined the fp32 precisions for different backend We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 @@ -41,16 +39,6 @@ namespace { ->rnn */ - C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){ - TORCH_WARN_ONCE( - "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' " - "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, " - "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see " - "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices" - ); - } -} // namespace - Float32Backend str2backend(const std::string& name) { if (name == "generic") return Float32Backend::GENERIC; @@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional op) const { } else { return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32; } - warn_deprecated_fp32_precision_api(); return allow_tf32_cudnn; } @@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) { setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE); setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE); allow_tf32_cudnn = b; - warn_deprecated_fp32_precision_api(); } void Context::setSDPPriorityOrder(const std::vector& order) { @@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const { "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ", "We suggest only using the new API to set the TF32 flag. See also: ", "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); - warn_deprecated_fp32_precision_api(); return allow_tf32_new; } @@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const { "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ", "We suggest only using the new API for matmul precision. See also: ", "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); - warn_deprecated_fp32_precision_api(); return float32_matmul_precision; } @@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) void Context::setFloat32MatmulPrecision(const std::string &s) { auto match = [this](const std::string & s_) { - warn_deprecated_fp32_precision_api(); // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention if (s_ == "highest") { float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; diff --git a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h index 9e0b189bdac8..757ef839f965 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h +++ b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h @@ -191,7 +191,7 @@ class Vectorized { auto vals = svreinterpret_u16_bf16(values); vals = sveor_u16_x(ptrue, vals, mask); return svreinterpret_bf16_u16(vals); - }; + } Vectorized round() const; Vectorized tan() const; Vectorized tanh() const; @@ -349,47 +349,47 @@ Vectorized inline Vectorized::frac() const { return convert_float_bfloat16(v1, v2); \ } -DEFINE_BF16_FUNC_VIA_FLOAT(isnan); -DEFINE_BF16_FUNC_VIA_FLOAT(angle); -DEFINE_BF16_FUNC_VIA_FLOAT(acos); -DEFINE_BF16_FUNC_VIA_FLOAT(acosh); -DEFINE_BF16_FUNC_VIA_FLOAT(asin); -DEFINE_BF16_FUNC_VIA_FLOAT(atan); -DEFINE_BF16_FUNC_VIA_FLOAT(atanh); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign); -DEFINE_BF16_FUNC_VIA_FLOAT(erf); -DEFINE_BF16_FUNC_VIA_FLOAT(erfc); -DEFINE_BF16_FUNC_VIA_FLOAT(exp); -DEFINE_BF16_FUNC_VIA_FLOAT(exp2); -DEFINE_BF16_FUNC_VIA_FLOAT(expm1); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot); -DEFINE_BF16_FUNC_VIA_FLOAT(i0); -DEFINE_BF16_FUNC_VIA_FLOAT(i0e); -DEFINE_BF16_FUNC_VIA_FLOAT(digamma); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter); -DEFINE_BF16_FUNC_VIA_FLOAT(log); -DEFINE_BF16_FUNC_VIA_FLOAT(log2); -DEFINE_BF16_FUNC_VIA_FLOAT(log10); -DEFINE_BF16_FUNC_VIA_FLOAT(log1p); -DEFINE_BF16_FUNC_VIA_FLOAT(sin); -DEFINE_BF16_FUNC_VIA_FLOAT(sinh); -DEFINE_BF16_FUNC_VIA_FLOAT(cos); -DEFINE_BF16_FUNC_VIA_FLOAT(cosh); -DEFINE_BF16_FUNC_VIA_FLOAT(ceil); -DEFINE_BF16_FUNC_VIA_FLOAT(floor); -DEFINE_BF16_FUNC_VIA_FLOAT(round); -DEFINE_BF16_FUNC_VIA_FLOAT(tan); -DEFINE_BF16_FUNC_VIA_FLOAT(tanh); -DEFINE_BF16_FUNC_VIA_FLOAT(trunc); -DEFINE_BF16_FUNC_VIA_FLOAT(lgamma); -DEFINE_BF16_FUNC_VIA_FLOAT(sqrt); -DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal); -DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow); +DEFINE_BF16_FUNC_VIA_FLOAT(isnan) +DEFINE_BF16_FUNC_VIA_FLOAT(angle) +DEFINE_BF16_FUNC_VIA_FLOAT(acos) +DEFINE_BF16_FUNC_VIA_FLOAT(acosh) +DEFINE_BF16_FUNC_VIA_FLOAT(asin) +DEFINE_BF16_FUNC_VIA_FLOAT(atan) +DEFINE_BF16_FUNC_VIA_FLOAT(atanh) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign) +DEFINE_BF16_FUNC_VIA_FLOAT(erf) +DEFINE_BF16_FUNC_VIA_FLOAT(erfc) +DEFINE_BF16_FUNC_VIA_FLOAT(exp) +DEFINE_BF16_FUNC_VIA_FLOAT(exp2) +DEFINE_BF16_FUNC_VIA_FLOAT(expm1) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot) +DEFINE_BF16_FUNC_VIA_FLOAT(i0) +DEFINE_BF16_FUNC_VIA_FLOAT(i0e) +DEFINE_BF16_FUNC_VIA_FLOAT(digamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter) +DEFINE_BF16_FUNC_VIA_FLOAT(log) +DEFINE_BF16_FUNC_VIA_FLOAT(log2) +DEFINE_BF16_FUNC_VIA_FLOAT(log10) +DEFINE_BF16_FUNC_VIA_FLOAT(log1p) +DEFINE_BF16_FUNC_VIA_FLOAT(sin) +DEFINE_BF16_FUNC_VIA_FLOAT(sinh) +DEFINE_BF16_FUNC_VIA_FLOAT(cos) +DEFINE_BF16_FUNC_VIA_FLOAT(cosh) +DEFINE_BF16_FUNC_VIA_FLOAT(ceil) +DEFINE_BF16_FUNC_VIA_FLOAT(floor) +DEFINE_BF16_FUNC_VIA_FLOAT(round) +DEFINE_BF16_FUNC_VIA_FLOAT(tan) +DEFINE_BF16_FUNC_VIA_FLOAT(tanh) +DEFINE_BF16_FUNC_VIA_FLOAT(trunc) +DEFINE_BF16_FUNC_VIA_FLOAT(lgamma) +DEFINE_BF16_FUNC_VIA_FLOAT(sqrt) +DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal) +DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow) Vectorized inline Vectorized::operator==( const Vectorized& other) const { diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index aaed43106461..20f235076220 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D #ifndef USE_ROCM at::Half halpha; at::Half hbeta; + uint32_t mask = -1; #endif void * alpha_ptr = α void * beta_ptr = β @@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS(); if (fp16_reduction != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { - uint32_t mask = + mask = fp16_reduction == at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | @@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS(); if (bf16_reduction != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { - uint32_t mask = + mask = bf16_reduction == at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | @@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; - TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( - ltHandle, - computeDesc.descriptor(), - Adesc.descriptor(), - Bdesc.descriptor(), - Cdesc.descriptor(), - Cdesc.descriptor(), - preference.descriptor(), - 1, - &heuristicResult, - &returnedResult)); + // on Blackwell+, we fake a n > 1 matmul when querying heuristics + // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance +#ifndef USE_ROCM + const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10; +#else + const bool lie_to_cublaslt = false; +#endif + if (lie_to_cublaslt) { + CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T); + CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc); + + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + FakeBdesc.descriptor(), + FakeCdesc.descriptor(), + FakeCdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + } else { + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + Bdesc.descriptor(), + Cdesc.descriptor(), + Cdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + } if (returnedResult == 0) { cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; } diff --git a/aten/src/ATen/cuda/NumericLimits.cuh b/aten/src/ATen/cuda/NumericLimits.cuh index 7081e94837ca..ebbc00438238 100644 --- a/aten/src/ATen/cuda/NumericLimits.cuh +++ b/aten/src/ATen/cuda/NumericLimits.cuh @@ -55,6 +55,14 @@ struct numeric_limits { static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } }; +template <> +struct numeric_limits { + static inline __host__ __device__ uint16_t lowest() { return 0; } + static inline __host__ __device__ uint16_t max() { return UINT16_MAX; } + static inline __host__ __device__ uint16_t lower_bound() { return 0; } + static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; } +}; + template <> struct numeric_limits { static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } @@ -63,6 +71,14 @@ struct numeric_limits { static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } }; +template <> +struct numeric_limits { + static inline __host__ __device__ uint32_t lowest() { return 0; } + static inline __host__ __device__ uint32_t max() { return UINT32_MAX; } + static inline __host__ __device__ uint32_t lower_bound() { return 0; } + static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; } +}; + template <> struct numeric_limits { static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } @@ -71,6 +87,21 @@ struct numeric_limits { static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } }; +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ uint64_t lowest() { return 0; } + static inline __host__ __device__ uint64_t max() { return _UI64_MAX; } + static inline __host__ __device__ uint64_t lower_bound() { return 0; } + static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; } +#else + static inline __host__ __device__ uint64_t lowest() { return 0; } + static inline __host__ __device__ uint64_t max() { return UINT64_MAX; } + static inline __host__ __device__ uint64_t lower_bound() { return 0; } + static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; } +#endif +}; + template <> struct numeric_limits { #ifdef _MSC_VER diff --git a/aten/src/ATen/cuda/cub.h b/aten/src/ATen/cuda/cub.h index 7430edaf8a3d..bca9b1faff52 100644 --- a/aten/src/ATen/cuda/cub.h +++ b/aten/src/ATen/cuda/cub.h @@ -24,7 +24,13 @@ namespace detail { // radix_sort_pairs doesn't interact with value_t other than to copy // the data, so we can save template instantiations by reinterpreting // it as an opaque type. +// We use native integer types for 1/2/4/8-byte values to reduce +// register usage in CUDA kernels. For sizes > 8 fall back to char array. template struct alignas(N) OpaqueType { char data[N]; }; +template <> struct alignas(1) OpaqueType<1> { uint8_t data; }; +template <> struct alignas(2) OpaqueType<2> { uint16_t data; }; +template <> struct alignas(4) OpaqueType<4> { uint32_t data; }; +template <> struct alignas(8) OpaqueType<8> { uint64_t data; }; template void radix_sort_pairs_impl( diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f5d5edb6439a..2fa6bcc6dc9a 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -1009,12 +1009,25 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) { } } +static Tensor send_to_meta(const Tensor& self, const Device& device) { + Tensor out_meta; + if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) { + out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device)); + out_meta.unsafeGetTensorImpl()->set_wrapped_number(true); + } else { + out_meta = self.to(device); + } + return out_meta; +} + Tensor mul_zerotensor(const Tensor& self, const Tensor& other) { auto out_device = correct_out_device(self, other); // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_)); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta); return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device)); } @@ -1023,7 +1036,9 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) { // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_)); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta); if (self._is_zerotensor()) { if (other._is_zerotensor()) { @@ -1052,8 +1067,9 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::add_Tensor::redispatch( - meta_dks, self.to(device_), other.to(device_), alpha); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha); auto get_out_like = [&] (const Tensor& tensor) { diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 1da245972f0c..fbabba84dbb2 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() { // `_flatten_nd_linear` flattens all but the last dimension of the input tensor // before passing it to linear operation static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) { - const auto input_sizes = input.sym_sizes(); - // can't use -1 in reshape because it errors when a dimension is 0 - c10::SymInt flattened_dim = 1; - for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) { - flattened_dim = flattened_dim * input_sizes[i]; + const auto input_sizes = input.sym_sizes(); + + const auto result_flattened = [&]() -> Tensor { + const auto input_ncols = input_sizes.back(); + const auto input_flattened_nrows = [&]() -> c10::SymInt { + // can't use -1 in reshape because it errors when a dimension is 0 + auto flattened_nrows = c10::SymInt{1}; + for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) { + flattened_nrows *= size; + } + return flattened_nrows; + }(); + + const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols}); + if (weight.layout() == c10::kStrided) { + return at::addmm(bias, input_flattened, weight.t()); + } else { + // weight is sparse, and addmm for sparse expects matmul lhs to be sparse, + // so we transpose the problem. + // NOTE: at::matmul handles (dense @ sparse) similarly. + const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1); + return at::addmm(bias_t, weight, input_flattened.t()).t(); } - auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)}); - const auto result = at::addmm(bias, inp_reshape, weight.t()); - auto new_size = input_sizes.slice(0, input_sizes.size() - 1); - c10::SymDimVector sizes_vec(new_size.begin(), new_size.end()); - sizes_vec.push_back(result.sym_size(1)); - return result.view_symint(sizes_vec); + }(); + + // Unflatten flattened row dims + auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()}; + result_sizes.back() = result_flattened.sym_size(1); + return result_flattened.view_symint(result_sizes); } @@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optionaldefined() && !input.is_xla()) { - // Also hit the fused path for contiguous 3D input, if not using xla + + const auto is_bias_likely_fusable = ( + bias->defined() && + // cuBLASLt: will fuse in the epilogue without copies + // when input/weight/bias are all strided. + // When weight is not strided, bias will not be fused, + // but we can still dispatch here to avoid at::matmul + // path which will probably use a very similar + // flattening optimization. + ((bias->dim() == 1 || bias->squeeze().dim() == 1) && bias->is_contiguous_or_false()) + ); + if (is_bias_likely_fusable && !input.is_xla()) { + // Also hit the fused path for contiguous nD input, if not using xla // backend. Reshaping/flattening has some performance implications on xla. - bool is_contiguous = input.is_contiguous_or_false(); - if (is_contiguous && input_dim == 3) { - return _flatten_nd_linear(input, weight, *bias); - } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { + if (input.is_contiguous_or_false()) { return _flatten_nd_linear(input, weight, *bias); - } else if (parseLinearFlatten3d() && input_dim == 3) { + } else if (parseLinearFlatten3d()) { // If user forces flattening via env var const Tensor input_cont = input.contiguous(); return _flatten_nd_linear(input_cont, weight, *bias); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6df7761d822d..0079a530b3d0 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,5 +1,6 @@ #include #include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1710,11 +1711,37 @@ Tensor narrow_symint( "], but got ", start, ")") - if (start < 0) { - start = start + cur_size; - } + + auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0)); + auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0)); + + if (cond1 || cond2) { + if (cond1) { + start = start + cur_size; + } + + TORCH_SYM_CHECK( + start.sym_le(cur_size - length), + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + return at::slice_symint(self, dim, start, start + length, 1); + } + + // Unbacked start handling! + + // Bounds check without converting start: + // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + + // length <= 0 + // - If start >= 0: need start + length <= cur_size + auto end = start + length; TORCH_SYM_CHECK( - start.sym_le(cur_size - length), + (start.sym_lt(0).sym_and((end).sym_le(0))) + .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), "start (", start, ") + length (", @@ -1722,7 +1749,28 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - return at::slice_symint(self, dim, start, start + length, 1); + + if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) { + return at::slice_symint(self, dim, start, end, 1); + } else { + // Cannot statically determine the condition due to unbacked. + // This is an interesting situation; when start is negative and + // start + length == 0, slice and narrow do different things. + // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to + // pass curr_size instead of 0. Otherwise, they would do the same thing. + // This says at runtime: if start < 0 and end == 0, then pass curr_size + // instead of 0. + + auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); + auto result = + at::slice_symint(self, dim, start, end + use_different * cur_size, 1); + + // Ensure slice allocated unbacked size is specialized to length. + SymInt new_size = result.sym_size(dim); + TORCH_SYM_CHECK(new_size.sym_eq(length), "") + + return result; + } } // This overload exists purely for XLA, because they wanted to pass in @@ -1736,8 +1784,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - int64_t st = start.item(); - return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); + c10::SymInt st = start.item().toSymInt(); + return at::narrow_symint(self, dim, std::move(st), std::move(length)); } std:: diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 7587988528eb..73f8c136794c 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -293,7 +293,7 @@ struct ComputeLocationBase { , empty(size <= 0) {} inline Vec unnormalize(const Vec &in) const { - return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5); + return (in + Vec(static_cast(1))) * Vec(scaling_factor) - Vec(static_cast(0.5)); } inline Vec clip_coordinates(const Vec &in) const { @@ -831,7 +831,7 @@ struct ApplyGridSample(-0.75)); ApplyGridSample(const TensorAccessor& input) : inp_H(input.size(2)) diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp index c7eaa802af12..c5dbf05039eb 100644 --- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -78,12 +79,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, upper_bound(), [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] { + AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, upper_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return minimum(a, b); }); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); } } @@ -103,12 +104,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, lower_bound(), [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] { + AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, lower_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); }); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); } } @@ -199,7 +200,7 @@ void aminmax_allreduce_kernel( } ); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] { + AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] { using Vec = Vectorized>; using scalar_t_pair = std::pair; reduce_all_impl_vec_two_outputs( @@ -214,7 +215,7 @@ void aminmax_allreduce_kernel( [=](Vec a, Vec b) -> Vec { return minimum(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf); } } diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 2e6293650194..3bad49a32d98 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -347,34 +348,35 @@ struct MinValuesOps: public at::native::MinOps { }; void min_values_kernel_impl(TensorIterator& iter) { - if (iter.dtype() == kLong) { - // This case is special because of Vectorized does not - // handle upper_bound(). - // See: https://github.com/pytorch/pytorch/issues/43254 - using scalar_t = int64_t; - binary_kernel_reduce( - iter, - MinValuesOps{}, - std::pair(upper_bound(), -1)); + // This case is special because of Vectorized does not + // handle upper_bound(). + // See: https://github.com/pytorch/pytorch/issues/43254 + if (iter.dtype() == kLong || iter.dtype() == kUInt64) { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { + binary_kernel_reduce( + iter, + MinValuesOps{}, + std::pair(upper_bound(), -1)); + }), kLong, kUInt64); return; } - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, static_cast(upper_bound())); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_values_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] { + AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); }, [](Vectorized a, Vectorized b) { return maximum(a, b); }, lower_bound()); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void argmax_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index c479e1610cbe..22c85735ad6a 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -106,7 +107,7 @@ void min_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -128,7 +129,7 @@ void min_kernel_impl( *indice_data = index; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); } void max_kernel_impl( @@ -139,7 +140,7 @@ void max_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -161,7 +162,7 @@ void max_kernel_impl( *indice_data = index; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); } void aminmax_kernel( @@ -186,7 +187,7 @@ void aminmax_kernel( return; } - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] { compare_base_kernel(min_result, max_result, self, wrap_dim, keepdim, [&] ( scalar_t* min_result_data, scalar_t* max_result_data, const scalar_t* self_data, auto self_dim_stride) { @@ -209,7 +210,7 @@ void aminmax_kernel( *max_result_data = max_number; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half); } void where_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index f64eb317d0cc..18ae048cfc96 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -22,6 +22,9 @@ #include #include #include +#ifdef USE_ROCM +#include +#endif #include #ifdef USE_FBGEMM_GENAI @@ -666,12 +669,19 @@ std::optional out_dtype) { // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used. // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm bool use_fast_path = false; + if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) { + use_fast_path = true; + } #endif const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); if (use_fast_path) { // fast path, no d2h sync needed +#ifndef USE_ROCM at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); +#else + at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out); +#endif } else { _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); } diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 927af661396c..db85f62c8d12 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -74,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co char* const out_ptr = static_cast(iter.data_ptr(0)); char* const in_ptr = static_cast(iter.data_ptr(1)); - if (is_gather_like && num_indices==1) { const size_t element_size = iter.element_size(0); constexpr size_t alignment = 16; @@ -84,16 +82,9 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co auto ind_dim_size = index_size[0]; auto inp_stride_bytes = index_stride[0]; auto out_stride_bytes = iter.strides(0)[1]; - // avoid grid overflow in the fast kernel - const int64_t vec_chunks = ceil_div(slice_size, alignment); - const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd); - const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - // if it's an eligible grid we use the fast path, otherwise default to slower path - if (blocks_per_slice_upper <= max_grid_y) { - at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind, - slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true); - return; - } + at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind, + slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true); + return; } } diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.cu b/aten/src/ATen/native/cuda/IndexKernelUtils.cu index 8343c6041895..1e998251dd7b 100644 --- a/aten/src/ATen/native/cuda/IndexKernelUtils.cu +++ b/aten/src/ATen/native/cuda/IndexKernelUtils.cu @@ -13,11 +13,12 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, if (allow_neg_indices) { ind = (ind < 0) ? ind + ind_dim_size : ind; } - CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind); - int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits - if (off >= slice_size) return; - auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); - at::native::memory::st_vec(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits + CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"); + // off is guaranteed to be within int32 limits + for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) { + auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); + at::native::memory::st_vec(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits + } } @@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int auto num_threads = at::round_up( at::ceil_div(slice_size_in_bytes, Alignment), static_cast(C10_WARP_SIZE)); - dim3 grid = {static_cast(num_ind), static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1}; + uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + grid_y = std::min(static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y); + dim3 grid = {static_cast(num_ind), grid_y, 1}; auto block = std::min(max_num_threads, num_threads); vectorized_gather_kernel<<>>(out, inp, idx, num_ind, slice_size_in_bytes, ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices); diff --git a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu index cdd5daab2d98..0b7823863047 100644 --- a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include +#include #include #include #include @@ -28,22 +29,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { } void aminmax_allreduce_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { + AT_DISPATCH_V2( + iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] { _min_max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void aminmax_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { + AT_DISPATCH_V2( + iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MinMaxOps{}, thrust::pair( at::numeric_limits::upper_bound(), at::numeric_limits::lower_bound())); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu index e8d1e88ebb3e..bcbc4c035994 100644 --- a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include +#include #include #include #include @@ -33,27 +34,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) { } void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { + AT_DISPATCH_V2( + iter.dtype(), "max_values_cuda", AT_WRAP([&]() { max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { + AT_DISPATCH_V2( + iter.input_dtype(), "max_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MaxOps{}, thrust::pair( at::numeric_limits::lower_bound(), 0)); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { + AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] { max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda) diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu index e01ca6c88ebc..0006a24dbc46 100644 --- a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -33,24 +34,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) { } void min_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { + AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() { min_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void min_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { + AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MinOps{}, thrust::pair(at::numeric_limits::upper_bound(), 0)); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void min_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { + AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] { min_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda) diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 0d2963874abb..9065d7992936 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -59,6 +59,24 @@ // forward declare class cublasCommonArgs; +#ifndef _WIN32 +namespace fbgemm_gpu { + +// NOTE(slayton58): FBGemm_GPU kernels come from within the FBGemm repo. +// To update supported ops means a submodule bump, which is.. painful. Instead, we +// can simply forward-declare the methods we want to use.. Works at least as a short-term +// thing, but should still be fixed somewhere/somehow. +at::Tensor f4f4bf16( + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + std::optional, + bool use_mx); + +} // namespace fbgemm_gpu +#endif + using at::blas::ScalingType; using at::blas::SwizzleType; @@ -1087,26 +1105,47 @@ _scaled_mxfp4_mxfp4( const std::optional& bias, const c10::ScalarType out_dtype, Tensor& out) { -#ifndef USE_ROCM - TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only"); -#endif +#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)) + TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only"); +#else // Restrictions: // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); - auto scale_a_elems = ceil_div(2 * mat_a.size(0), 32) * mat_a.size(1); - auto scale_b_elems = ceil_div(2 * mat_b.size(1), 32) * mat_b.size(0); + // Packed FP4 format means actual-K = 2 * reported-K -- adjust + auto K_multiplier = 2; +#ifdef USE_ROCM + // AMD + auto scale_a_elems = ceil_div(K_multiplier * mat_a.size(0), 32) * mat_a.size(1); + auto scale_b_elems = ceil_div(K_multiplier * mat_b.size(1), 32) * mat_b.size(0); +#else + // NVIDIA + auto scale_a_elems = round_up(mat_a.size(0), 128) * round_up(ceil_div(K_multiplier * mat_a.size(1), 32), 4); + auto scale_b_elems = round_up(mat_b.size(1), 128) * round_up(ceil_div(K_multiplier * mat_b.size(0), 32), 4); +#endif TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); +#ifdef USE_ROCM + // AMD + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)"); + TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)"); +#else + // NVIDIA + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format"); + TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format"); +#endif + TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), "For Blockwise scaling both scales should be contiguous"); TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype); +#ifdef USE_ROCM + // AMD auto scaling_choice_a = ScalingType::BlockWise1x32; auto scaling_choice_b = ScalingType::BlockWise1x32; @@ -1121,11 +1160,30 @@ _scaled_mxfp4_mxfp4( TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half, "Block-wise scaling only supports BFloat16 or Half output types"); -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); #endif return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); +#else + // NVIDIA + // NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor, + // but we have one we need to use. Two clear options are to copy into + // our output (slow), or use a move-assignment-operator (faster). + // However, the compiler can complain about the explicit move preventing + // copy elision because the return from f4f4bf16 is a temporary object. + // So we don't explicitly move, and trust the compiler here... + // In the longer term this should be fixed on the FBGemm side. + out = fbgemm_gpu::f4f4bf16( + mat_a, + mat_b.transpose(-2, -1), + scale_a, + scale_b, + std::nullopt, /* global_scale */ + true /* use_mx */ + ); + + return out; +#endif +#endif } Tensor& @@ -1250,17 +1308,20 @@ _scaled_mm_cuda_v2_out( mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")"); } + // Handle fp4 packed-K dimension + int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1; + TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1], " but got ", bias->numel()); TORCH_CHECK_VALUE( - mat_a.sizes()[1] % 16 == 0, + K_multiplier * mat_a.sizes()[1] % 16 == 0, "Expected trailing dimension of mat1 to be divisible by 16 ", "but got mat1 shape: (", mat_a.sizes()[0], "x", - mat_a.sizes()[1], + K_multiplier * mat_a.sizes()[1], ")."); - TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x", + TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x", mat_b.sizes()[1], ") must be divisible by 16"); // TODO(slayton): Existing checks, not sure if they should really be here. diff --git a/aten/src/ATen/native/hip/ck_group_gemm.h b/aten/src/ATen/native/hip/ck_group_gemm.h new file mode 100644 index 000000000000..c50307c9f8ea --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace hip { +namespace detail { +void group_gemm_ck( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + const std::optional& bias, + at::Tensor& out); + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/aten/src/ATen/native/hip/ck_group_gemm.hip b/aten/src/ATen/native/hip/ck_group_gemm.hip new file mode 100644 index 000000000000..c436ad660c1c --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.hip @@ -0,0 +1,462 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at { +namespace hip { +namespace detail { + +namespace CkTypes { + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using F32 = float; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; +} + +template +using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< + ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, + DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType, + CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough, + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + 1, 1, + S<1,32,1,8>, 4 +>; + +template +void launch_grouped_bgemm_ck_impl_dispatch( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + at::Tensor& out) +{ + using DeviceOp = GroupedGemmKernel; + using PassThrough = CkTypes::PassThrough; + + std::vector gemm_descs; + std::vector p_a_ptrs, p_b_ptrs; + std::vector p_e_ptrs; + // Note: d_ptrs will be resized after we populate the other vectors + + const int mat_a_dim = mat_a.dim(); + const int mat_b_dim = mat_b.dim(); + + const char* a_ptr_base = reinterpret_cast(mat_a.data_ptr()); + const char* b_ptr_base = reinterpret_cast(mat_b.data_ptr()); + char* out_ptr_base = reinterpret_cast(out.data_ptr()); + const size_t a_element_size = mat_a.element_size(); + const size_t b_element_size = mat_b.element_size(); + const size_t out_element_size = out.element_size(); + + // for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses. + if (mat_a_dim == 2 && mat_b_dim == 2) { + // 2D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + const int M = mat_a.size(0); // number of rows in A + const int N = mat_b.size(1); // number of columns in B + const int K = mat_a.size(1); // columns in A == rows in B + // for 2d*2d input, output is 3d. + // for each group, A columns (K) are sliced. M and N dimensions are not sliced. + for (int i = 0; i < num_groups; ++i) { + int start_k = (i == 0) ? 0 : offs_accessor[i-1]; + int end_k = offs_accessor[i]; + int k = end_k - start_k; + + //K dimension are sliced, hence select stride(1) always. + //K dimension is always dimension 1, regardless of memory layout (row/column major) + const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size; + const void* group_b_ptr; + int ldb; + + if (std::is_same::value) { + // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size; + // Leading dimension = distance between rows = stride(0) + ldb = mat_b.stride(0); + } else { + // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size; + // Leading dimension = distance between columns = stride(1) + ldb = mat_b.stride(1); + } + + // Calculate output pointer for group i in 3D tensor [num_groups, M, N] + // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + int lda, ldc; + if (std::is_same::value) { + // Row-major A [M,K]: leading dimension = distance between rows = stride(0) + lda = mat_a.stride(0); + } else { + // Column-major A [M,K]: leading dimension = distance between columns = stride(1) + lda = mat_a.stride(1); + } + // Output is always row-major in 3D tensor [num_groups, M, N] + // Leading dimension for each group's [M,N] slice = stride(1) = N + ldc = out.stride(1); + size_t output_group_bytes = M * N * out_element_size; + void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes; + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(k), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 2 && mat_b_dim == 3) { + // 2D*3D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + + // 2d*3d input, output is 2d. + // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n] + // Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B + const int K = mat_a.size(1); // columns in A + // For 2D-3D case: The output determines N (result width) + const int N = out.size(1); // N is the width of the output tensor + + for (int i = 0; i < num_groups; ++i) { + int start_m = (i == 0) ? 0 : offs_accessor[i - 1]; + int end_m = offs_accessor[i]; + int m = end_m - start_m; + + // Skip zero-sized groups but continue processing subsequent groups + if (m <= 0) { + continue; + } + + // Select A rows for group i: skip start_m rows + const void* group_a_ptr; + int lda; + if (std::is_same::value) { + // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + lda = mat_a.stride(0); // distance between rows + } else { + // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows) + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + + // Detect stride pattern for A tensor to determine appropriate lda calculation + bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0)); + + if (a_is_strided_tensor) { + // For strided A tensors: stride(0) gives the actual leading dimension + lda = mat_a.stride(0); + } else { + // For non-strided A tensors: use the M dimension (total rows) + lda = mat_a.size(0); // Total M dimension for column-major layout + } + } + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + int ldb; + + if (std::is_same::value) { + // Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed + ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N] + } else { + // Detect stride pattern to determine appropriate ldb calculation + bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2)); + + if (is_strided_tensor) { + // For strided tensors: stride(2) gives the actual leading dimension + ldb = mat_b.stride(2); + } else { + // For non-strided tensors: use the N dimension + ldb = mat_b.size(1); + } + } + + // Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N] + void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size; + int ldc = out.stride(0); // distance between rows in output (should be N for 2D case) + + gemm_descs.push_back({ + static_cast(m), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 3) { + // 3d*3d input, output is 3d - batched matrix multiplication + // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n] + // Each batch is processed as a separate GEMM operation + const int batch_size = mat_a.size(0); + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed) + + // Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout + int N; + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + N = mat_b.size(2); + } else if (mat_b.size(2) == K) { + // B is [batch, n, k] - transposed layout + N = mat_b.size(1); + } else { + TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[", + batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]"); + } + + for (int i = 0; i < batch_size; ++i) { + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + + // Select output batch for group i: Output[i, :, :] + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + + int lda, ldb, ldc; + + if (std::is_same::value) { + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + } else { + // Column-major A: leading dimension = distance between columns = stride(2) + lda = mat_a.stride(2); + } + + if (std::is_same::value) { + // Row-major B: leading dimension = distance between rows + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(1); // stride between K rows + } else { + // B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM + ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n]) + } + } else { + // Column-major B: leading dimension = distance between columns + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(2); // stride between N columns + } else { + // B is [batch, n, k] - transposed layout + ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]β†’[k,n]) + } + } + + // Output is typically row-major: leading dimension = distance between rows = stride(1) + ldc = out.stride(1); + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 2) { + // 3D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + // 3d*2d input, output is 3d. + // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both) + // Offset divides N dimension of B, each group gets different slice of B and different batch of A + const int batch_size = mat_a.size(0); // n_groups + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A + + // For row-major A and B case: B should be [K, total_N] + const int total_N = mat_b.size(1); // B is [K, total_N] for row-major + + for (int i = 0; i < num_groups; ++i) { + int start_n = (i == 0) ? 0 : offs_accessor[i - 1]; + int end_n = offs_accessor[i]; + int n = end_n - start_n; + + // Skip zero-sized groups but continue processing subsequent groups + if (n <= 0) { + continue; + } + + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B slice for group i: B[:, start_n:end_n] (B[K, total_N]) + const void* group_b_ptr; + int ldb; + + // Check if B is row-major or column-major + if (std::is_same::value) { + // Row-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(0); // distance between rows (should be total_N) + } else { + // Column-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(1); // distance between columns (should be K) + } + + // Select output slice for group i: Output[:, start_n:end_n] + void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size; + + int lda, ldc; + + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + // Output is row-major: leading dimension = distance between rows = stride(0) + ldc = out.stride(0); + + gemm_descs.push_back({ + static_cast(M), + static_cast(n), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim); + } + + TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups"); + + // Initialize d_ptrs with the correct size + std::vector> d_ptrs(p_a_ptrs.size()); + + static DeviceOp gemm_instance; + auto argument = gemm_instance.MakeArgument( + p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, + gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} + ); + TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument), + "CK Group GEMM: argument unsupported (shape/strides/type config)"); + size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument); + size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument); + + void* gemm_arg_buf = nullptr; + void* ws_buf = nullptr; + + hipMalloc(&gemm_arg_buf, arg_buf_size); + hipMalloc(&ws_buf, ws_size); + + gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf); + gemm_instance.SetWorkSpacePointer(&argument, ws_buf); + + auto invoker = gemm_instance.MakeInvoker(); + hipStream_t stream = c10::hip::getCurrentHIPStream(); + invoker.Run(argument, {stream}); + hipFree(gemm_arg_buf); + hipFree(ws_buf); +} + +void group_gemm_ck( + const at::Tensor& input_a, + const at::Tensor& input_b_colmajor, + const std::optional& offs, + const std::optional& /*bias*/, + at::Tensor& out) +{ + // Detect if input_a is row-major based on stride pattern + bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1); + bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1); + // Ensure tensor A is row-major and contiguous if not already + at::Tensor mat_a = input_a; + if (!a_row_major) { + // If A is not row-major, make it contiguous (row-major) + mat_a = input_a.contiguous(); + } + // Force tensor B to be column-major using double transpose trick + // This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape + at::Tensor mat_b = input_b_colmajor; + if (!b_col_major) { + mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1); + } + + // For 3D tensors, check the last dimension stride for row-major detection + a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1); + bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1); + + if (mat_a.dtype() == at::kBFloat16) { + // bf16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kHalf) { + // fp16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kFloat) { + // fp32 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype"); + } + +} + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index c995b8fc237f..f0bbcdabfa5c 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -212,17 +212,12 @@ loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({})); TORCH_CHECK(loss.is_mps()); - Tensor loss_squeezed = loss.squeeze(); - Tensor input_squeezed = input.squeeze(); - Tensor target_squeezed = target.squeeze(); - @autoreleasepool { - std::string key = - op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight}); + std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed); - newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed); + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); MPSGraphTensor* bceLossUnweighted = nil; // if grad_output is defined, then it's a backward pass @@ -252,12 +247,12 @@ newCachedGraph->gradInputTensor = bceLoss; } } else { - newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size()); + newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size()); } }); - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed); - Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); + Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss); NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl index 180442b4b09a..fecce634ec08 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl @@ -1,7 +1,7 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX") # Shared by internal and OSS BUCK def define_qnnpack(third_party, labels = []): @@ -21,7 +21,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", @@ -82,7 +82,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -129,7 +129,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -184,7 +184,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -236,7 +236,7 @@ def define_qnnpack(third_party, labels = []): ], ), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", ], @@ -291,7 +291,7 @@ def define_qnnpack(third_party, labels = []): ("src", "qnnpack/*.h"), ("include", "*.h"), ]), - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", @@ -398,7 +398,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -465,7 +465,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", "-Wno-unused-command-line-argument", @@ -525,7 +525,7 @@ def define_qnnpack(third_party, labels = []): ("src", "qnnpack/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", diff --git a/benchmarks/operator_benchmark/pt/addmm_test.py b/benchmarks/operator_benchmark/pt/addmm_test.py index a98628944b3e..3e94a9cd7f3d 100644 --- a/benchmarks/operator_benchmark/pt/addmm_test.py +++ b/benchmarks/operator_benchmark/pt/addmm_test.py @@ -53,10 +53,8 @@ def forward(self, input_one, mat1, mat2): return torch.addmm(input_one, mat1, mat2) -op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark) -op_bench.generate_pt_gradient_test( - addmm_long_configs + addmm_long_configs, AddmmBenchmark -) +op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark) +op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark) """Mircobenchmark for addbmm operator.""" @@ -107,9 +105,7 @@ def forward(self, input_one, batch1, batch2): ) op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark) -op_bench.generate_pt_gradient_test( - addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark -) +op_bench.generate_pt_gradient_test(addbmm_long_configs, AddbmmBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/buckbuild.bzl b/buckbuild.bzl index 4c1affd10e1b..9f18ad4849dd 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -8,7 +8,7 @@ load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX") +load("//tools/build_defs:platform_defs.bzl", "IOS", "MACOSX") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build") load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build", is_profile_build_ios = "is_profile_build") @@ -1090,7 +1090,7 @@ def define_buck_targets( srcs = [ "caffe2/core/common.cc", ], - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = get_pt_compiler_flags(), labels = labels, # @lint-ignore BUCKLINT link_whole diff --git a/build_variables.bzl b/build_variables.bzl index 70121e19d809..258e739300c1 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1025,6 +1025,7 @@ libtorch_python_core_sources = [ libtorch_python_distributed_core_sources = [ "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/python_comm_hook.cpp", + "torch/csrc/distributed/c10d/python_callback_work.cpp", ] libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [ diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 72e72f49a5e4..107530e9e28a 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -59,6 +59,9 @@ constexpr DispatchKeySet nested_dispatch_keyset = {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); +constexpr DispatchKeySet functorch_batched_dispatch_keyset = + DispatchKeySet(DispatchKey::FuncTorchBatched); + DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); switch (t) { @@ -77,6 +80,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { return backend_dispatch_keyset; case DispatchKey::CompositeExplicitAutogradNonFunctional: return non_functional_backend_dispatch_keyset; + case DispatchKey::FuncTorchBatchedDecomposition: + return functorch_batched_dispatch_keyset; default: return DispatchKeySet(t); } diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index d804eb9d2740..48c407b8b069 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace c10 { @@ -111,4 +112,17 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } +SymInt SymBool::toSymInt() const { + // If concrete bool, return concrete SymInt + if (auto ma = maybe_as_bool()) { + return SymInt(*ma ? 1 : 0); + } + + // Symbolic case: use sym_ite to convert bool to int (0 or 1) + auto node = toSymNodeImpl(); + auto one_node = node->wrap_int(1); + auto zero_node = node->wrap_int(0); + return SymInt(node->sym_ite(one_node, zero_node)); +} + } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index d5d509e239b1..a27a28a5bf8a 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,6 +12,8 @@ namespace c10 { +class SymInt; + class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -80,6 +82,10 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } + // Convert SymBool to SymInt (0 or 1) + // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless + SymInt toSymInt() const; + bool is_heap_allocated() const { return ptr_; } diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 3046259b48a3..5414d838cd8c 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -106,6 +106,9 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } else if (key == "graph_capture_record_stream_reuse") { i = parseGraphCaptureRecordStreamReuse(tokenizer, i); used_native_specific_option = true; + } else if (key == "per_process_memory_fraction") { + i = parsePerProcessMemoryFraction(tokenizer, i); + used_native_specific_option = true; } else { const auto& keys = c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); @@ -146,6 +149,18 @@ size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse( return i; } +double CUDAAllocatorConfig::parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i) { + tokenizer.checkToken(++i, ":"); + double val_env = tokenizer.toDouble(++i); + TORCH_CHECK_VALUE( + val_env >= 0.0 && val_env <= 1.0, + "per_process_memory_fraction is invalid, set it in [0.0, 1.0]"); + m_per_process_memory_fraction = val_env; + return i; +} + size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i) { diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index d61f69467a2d..4e6097a406bc 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -61,6 +61,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_graph_capture_record_stream_reuse; } + static double per_process_memory_fraction() { + return instance().m_per_process_memory_fraction; + } + /** Pinned memory allocator settings */ static bool pinned_use_cuda_host_register() { return instance().m_pinned_use_cuda_host_register; @@ -152,7 +156,8 @@ class C10_CUDA_API CUDAAllocatorConfig { "pinned_use_hip_host_register", "graph_capture_record_stream_reuse", "pinned_reserve_segment_size_mb", - "pinned_num_register_threads"}; + "pinned_num_register_threads", + "per_process_memory_fraction"}; return keys; } @@ -177,6 +182,9 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t parseGraphCaptureRecordStreamReuse( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i); + double parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); std::atomic m_pinned_num_register_threads{1}; std::atomic m_pinned_reserve_segment_size_mb{0}; @@ -189,6 +197,7 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_release_lock_on_cudamalloc{false}; std::atomic m_pinned_use_cuda_host_register{false}; std::atomic m_graph_capture_record_stream_reuse{false}; + std::atomic m_per_process_memory_fraction{1.0}; }; // Keep this for backwards compatibility diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 091e580f9581..d66c3a16c000 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1100,7 +1100,7 @@ class RingBuffer { } // anonymous namespace } // namespace Native -static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { +static std::string reportProcessMemoryInfo(const cudaDeviceProp& prop) { #ifdef PYTORCH_C10_DRIVER_API_SUPPORTED void* nvml_handle = DriverAPI::get_nvml_handle(); if (!nvml_handle) { @@ -1111,9 +1111,6 @@ static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { return true; }(); - cudaDeviceProp prop{}; - C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - // NOLINTNEXTLINE(*-c-arrays) char pci_id[80]; snprintf( @@ -1215,14 +1212,16 @@ class DeviceCachingAllocator { // record used memory. size_t total_allocated_memory = 0; - size_t allowed_memory_maximum = 0; + cudaDeviceProp device_prop; + + // maximum amount of memory that device is allowed to + // allocate. This is set iff memory fraction is less than 1 + std::optional allowed_memory_maximum{std::nullopt}; // all live expandable segments std::vector expandable_segments_; std::vector devices_with_peer_access_; - bool set_fraction = false; - bool record_history = false; std::atomic context_recorder_; @@ -1264,6 +1263,9 @@ class DeviceCachingAllocator { : device_id(id), large_blocks(/*small=*/false), small_blocks(/*small=*/true) { + C10_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, id)); + + setMemoryFraction(CUDAAllocatorConfig::per_process_memory_fraction()); stats.max_split_size = static_cast(AcceleratorAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); @@ -1399,7 +1401,7 @@ class DeviceCachingAllocator { if (!block_found) { // Do garbage collection if the flag is set. if (C10_UNLIKELY( - set_fraction && + allowed_memory_maximum.has_value() && AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { garbage_collect_cached_blocks(context); @@ -1456,11 +1458,12 @@ class DeviceCachingAllocator { C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); std::string allowed_info; - if (set_fraction) { - allowed_info = format_size(allowed_memory_maximum) + " allowed; "; + if (allowed_memory_maximum.has_value()) { + allowed_info = + format_size(allowed_memory_maximum.value()) + " allowed; "; } - std::string proc_info = reportProcessMemoryInfo(device_id); + std::string proc_info = reportProcessMemoryInfo(device_prop); record_trace( TraceEntry::OOM, @@ -1518,7 +1521,7 @@ class DeviceCachingAllocator { for (const auto& obs : observers_local) { obs(device_id, alloc_size, - set_fraction ? allowed_memory_maximum : device_total, + allowed_memory_maximum.value_or(device_total), device_free); } @@ -2015,25 +2018,26 @@ class DeviceCachingAllocator { /** get memory fraction limiting maximum allocated memory **/ double getMemoryFraction() { - if (!set_fraction) { + if (!allowed_memory_maximum.has_value()) { return 1.0; } - size_t device_free = 0; - size_t device_total = 0; - C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); - return static_cast(allowed_memory_maximum) / - static_cast(device_total); + return static_cast(allowed_memory_maximum.value()) / + static_cast(device_prop.totalGlobalMem); } /** set memory fraction to limit maximum allocated memory **/ void setMemoryFraction(double fraction) { - size_t device_free = 0; - size_t device_total = 0; - C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); - allowed_memory_maximum = - static_cast(fraction * static_cast(device_total)); - set_fraction = true; + TORCH_CHECK( + 0 <= fraction && fraction <= 1, + "invalid fraction:", + fraction, + ". Please set within [0, 1]."); + allowed_memory_maximum = std::nullopt; + if (fraction < 1.0) { + allowed_memory_maximum = static_cast( + fraction * static_cast(device_prop.totalGlobalMem)); + } } /** get expandable segment size for all the streams on device **/ @@ -3010,7 +3014,7 @@ class DeviceCachingAllocator { BlockPool& pool = *p.pool; if (C10_UNLIKELY( - set_fraction && + allowed_memory_maximum.has_value() && AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. ++pool.get_free_blocks_call_count; @@ -3083,7 +3087,7 @@ class DeviceCachingAllocator { size_t gc_threshold = static_cast( AcceleratorAllocatorConfig::garbage_collection_threshold() * - static_cast(allowed_memory_maximum)); + static_cast(allowed_memory_maximum.value())); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { return; @@ -3161,8 +3165,8 @@ class DeviceCachingAllocator { bool active_pool = p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator(); - if (set_fraction && - total_allocated_memory + size > allowed_memory_maximum) { + if (allowed_memory_maximum.has_value() && + total_allocated_memory + size > allowed_memory_maximum.value()) { p.err = cudaErrorMemoryAllocation; return false; // Temporarily disable checkpointing & cudagraphs internally @@ -3859,7 +3863,6 @@ class NativeCachingAllocator : public CUDAAllocator { "Allocator not initialized for device ", device, ": did you call init?"); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); return device_allocator[device]->getMemoryFraction(); } @@ -3869,12 +3872,6 @@ class NativeCachingAllocator : public CUDAAllocator { "Allocator not initialized for device ", device, ": did you call init?"); - TORCH_CHECK( - 0 <= fraction && fraction <= 1, - "invalid fraction:", - fraction, - ". Please set within [0, 1]."); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); device_allocator[device]->setMemoryFraction(fraction); } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index fbe5dab18e0a..8fee00dd621d 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 93bce51f1b9d..674eb00035c5 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -427,7 +427,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { // on the current device each later call sees. void init(int dev_count) override { static bool called = [](int dev_count) { - ; // Are there external guarantees init will be called before // any of the allocator's other functions? // std::lock_guard lk(general_mutex); diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 64605f515359..1311867ef797 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -40,200 +41,99 @@ namespace c10 { /// /// This is intended to be trivially copyable, so it should be passed by /// value. +/// +/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct +/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of +/// the underlying constexpr calls, we rely on apparent-type dispatch for +/// inheritance. This should be fine because their memory format is the same, +/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods. +/// However, you should prefer to use ArrayRef when possible, because its use +/// of TORCH_CHECK will lead to better user-facing error messages. template -class ArrayRef final { +class ArrayRef final : public HeaderOnlyArrayRef { public: - using iterator = const T*; - using const_iterator = const T*; - using size_type = size_t; - using value_type = T; - - using reverse_iterator = std::reverse_iterator; - - private: - /// The start of the array, in an external buffer. - const T* Data; - - /// The number of elements. - size_type Length; - - void debugCheckNullptrInvariant() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - Data != nullptr || Length == 0, - "created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal"); - } - - public: - /// @name Constructors + /// @name Constructors, all inherited from HeaderOnlyArrayRef except for + /// SmallVector. As inherited constructors won't work with class template + /// argument deduction (CTAD) until C++23, we add deduction guides after + /// the class definition to enable CTAD. /// @{ - /// Construct an empty ArrayRef. - /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} - - /// Construct an ArrayRef from a single element. - // TODO Make this explicit - constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} - - /// Construct an ArrayRef from a pointer and length. - constexpr ArrayRef(const T* data, size_t length) - : Data(data), Length(length) { - debugCheckNullptrInvariant(); - } - - /// Construct an ArrayRef from a range. - constexpr ArrayRef(const T* begin, const T* end) - : Data(begin), Length(end - begin) { - debugCheckNullptrInvariant(); - } + using HeaderOnlyArrayRef::HeaderOnlyArrayRef; /// Construct an ArrayRef from a SmallVector. This is templated in order to /// avoid instantiating SmallVectorTemplateCommon whenever we /// copy-construct an ArrayRef. + /// NOTE: this is the only constructor that is not inherited from + /// HeaderOnlyArrayRef. template /* implicit */ ArrayRef(const SmallVectorTemplateCommon& Vec) - : Data(Vec.data()), Length(Vec.size()) { - debugCheckNullptrInvariant(); - } - - template < - typename Container, - typename U = decltype(std::declval().data()), - typename = std::enable_if_t< - (std::is_same_v || std::is_same_v)>> - /* implicit */ ArrayRef(const Container& container) - : Data(container.data()), Length(container.size()) { - debugCheckNullptrInvariant(); - } - - /// Construct an ArrayRef from a std::vector. - // The enable_if stuff here makes sure that this isn't used for - // std::vector, because ArrayRef can't work on a std::vector - // bitfield. - template - /* implicit */ ArrayRef(const std::vector& Vec) - : Data(Vec.data()), Length(Vec.size()) { - static_assert( - !std::is_same_v, - "ArrayRef cannot be constructed from a std::vector bitfield."); - } - - /// Construct an ArrayRef from a std::array - template - /* implicit */ constexpr ArrayRef(const std::array& Arr) - : Data(Arr.data()), Length(N) {} - - /// Construct an ArrayRef from a C array. - template - // NOLINTNEXTLINE(*c-arrays*) - /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} - - /// Construct an ArrayRef from a std::initializer_list. - /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) - : Data( - std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) - : std::begin(Vec)), - Length(Vec.size()) {} + : HeaderOnlyArrayRef(Vec.data(), Vec.size()) {} /// @} - /// @name Simple Operations + /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef /// @{ - constexpr iterator begin() const { - return Data; - } - constexpr iterator end() const { - return Data + Length; - } - - // These are actually the same as iterator, since ArrayRef only - // gives you const iterators. - constexpr const_iterator cbegin() const { - return Data; - } - constexpr const_iterator cend() const { - return Data + Length; - } - - constexpr reverse_iterator rbegin() const { - return reverse_iterator(end()); - } - constexpr reverse_iterator rend() const { - return reverse_iterator(begin()); - } - - /// Check if all elements in the array satisfy the given expression - constexpr bool allMatch(const std::function& pred) const { - return std::all_of(cbegin(), cend(), pred); - } - - /// empty - Check if the array is empty. - constexpr bool empty() const { - return Length == 0; - } - - constexpr const T* data() const { - return Data; - } - - /// size - Get the array size. - constexpr size_t size() const { - return Length; - } - /// front - Get the first element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& front() const { TORCH_CHECK( - !empty(), "ArrayRef: attempted to access front() of empty list"); - return Data[0]; + !this->empty(), "ArrayRef: attempted to access front() of empty list"); + return this->Data[0]; } /// back - Get the last element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& back() const { - TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list"); - return Data[Length - 1]; - } - - /// equals - Check for element-wise equality. - constexpr bool equals(ArrayRef RHS) const { - return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + TORCH_CHECK( + !this->empty(), "ArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; } /// slice(n, m) - Take M elements of the array starting at element N + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr ArrayRef slice(size_t N, size_t M) const { TORCH_CHECK( - N + M <= size(), + N + M <= this->size(), "ArrayRef: invalid slice, N = ", N, "; M = ", M, "; size = ", - size()); - return ArrayRef(data() + N, M); + this->size()); + return ArrayRef(this->data() + N, M); } /// slice(n) - Chop off the first N elements of the array. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr ArrayRef slice(size_t N) const { TORCH_CHECK( - N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); - return slice(N, size() - N); + N <= this->size(), + "ArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); // should this slice be this->slice? } /// @} /// @name Operator Overloads /// @{ - constexpr const T& operator[](size_t Index) const { - return Data[Index]; - } /// Vector compatibility + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& at(size_t Index) const { TORCH_CHECK( - Index < Length, + Index < this->Length, "ArrayRef: invalid index Index = ", Index, "; Length = ", - Length); - return Data[Index]; + this->Length); + return this->Data[Index]; } /// Disallow accidental assignment from a temporary. @@ -253,16 +153,48 @@ class ArrayRef final { std::enable_if_t, ArrayRef>& operator=( std::initializer_list) = delete; - /// @} - /// @name Expensive Operations - /// @{ - std::vector vec() const { - return std::vector(Data, Data + Length); - } - /// @} }; +/// Deduction guides for ArrayRef to support CTAD with inherited constructors +/// These mirror the constructors inherited from HeaderOnlyArrayRef +/// @{ + +// Single element constructor +template +ArrayRef(const T&) -> ArrayRef; + +// Pointer and length constructor +template +ArrayRef(const T*, size_t) -> ArrayRef; + +// Range constructor (begin, end) +template +ArrayRef(const T*, const T*) -> ArrayRef; + +// Generic container constructor (anything with .data() and .size()) +template +ArrayRef(const Container&) -> ArrayRef< + std::remove_pointer_t().data())>>; + +// std::vector constructor +template +ArrayRef(const std::vector&) -> ArrayRef; + +// std::array constructor +template +ArrayRef(const std::array&) -> ArrayRef; + +// C array constructor +template +ArrayRef(const T (&)[N]) -> ArrayRef; + +// std::initializer_list constructor +template +ArrayRef(const std::initializer_list&) -> ArrayRef; + +/// @} + template std::ostream& operator<<(std::ostream& out, ArrayRef list) { int i = 0; diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0e86e826405c..e1cc43350b2b 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1307,7 +1307,7 @@ endif() if(USE_MKLDNN_ACL) find_package(ACL REQUIRED) - target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS}) + target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS}) endif() target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) diff --git a/cmake/Modules/FindGloo.cmake b/cmake/Modules/FindGloo.cmake index 944cd4d8d257..0bdfe275d9c0 100644 --- a/cmake/Modules/FindGloo.cmake +++ b/cmake/Modules/FindGloo.cmake @@ -26,7 +26,7 @@ find_library(Gloo_CUDA_LIBRARY # if Gloo + HIP is desired, Gloo_HIP_LIBRARY # needs to be linked to desired target find_library(Gloo_HIP_LIBRARY - NAMES gloo_hiop + NAMES gloo_hip DOC "Gloo's HIP support/code" ) diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 218c50a69c6f..bc8855d23e61 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -28,6 +28,15 @@ endif() # Find CUDA. find_package(CUDA) if(NOT CUDA_FOUND) + # If user explicitly set USE_CUDA=1, error out instead of falling back + if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA) + message(FATAL_ERROR + "PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. " + "Please check your CUDA installation, ensure CUDA toolkit is installed, " + "and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. " + "If you want to build without CUDA, please set USE_CUDA=0.") + endif() + message(WARNING "PyTorch: CUDA cannot be found. Depending on whether you are building " "PyTorch or a PyTorch dependent library, the next warning / error will " diff --git a/docs/source/complex_numbers.md b/docs/source/complex_numbers.md index 610f9a06615a..095401879f09 100644 --- a/docs/source/complex_numbers.md +++ b/docs/source/complex_numbers.md @@ -45,7 +45,7 @@ supported for complex tensors. ## Transition from the old representation Users who currently worked around the lack of complex tensors with real tensors of shape {math}`(..., 2)` -can easily to switch using the complex tensors in their code using {func}`torch.view_as_complex` +can easily switch to using the complex tensors in their code using {func}`torch.view_as_complex` and {func}`torch.view_as_real`. Note that these functions don’t perform any copy and return a view of the input tensor. @@ -140,7 +140,7 @@ through the same optimizer on the {func}`torch.view_as_real` equivalent of the c `real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers -and capturable vs default optimizers. For more details, see [numbercial accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). +and capturable vs default optimizers. For more details, see [numerical accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their `p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 1c9d374b8ab0..ca1fe3b5e909 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective .. autofunction:: new_group ``` +```{eval-rst} +.. autofunction:: torch.distributed.distributed_c10d.shrink_group +``` + ```{eval-rst} .. autofunction:: get_group_rank ``` diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index c7d3a93f7352..2c1a2e8cbb6b 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -619,6 +619,10 @@ Available options: and reallocate buffers across multiple streams, especially when the capture DAG frequently reaches joined frontiers. +* ``per_process_memory_fraction`` option limits the amount of memory that can be allocated + on all the CUDA devices to a specified fraction of the available memory. This is a value + between 0 and 1. Attempting to allocate more memory will raise an out of memory error. + .. note:: Some stats reported by the @@ -1720,6 +1724,16 @@ and can be used to share memory across graphs as shown:: g1.replay() g2.replay() +It's also safe to share a memory pool across separate graphs that do not depend +on each other's outputs, provided they never run concurrently. +Be aware that replaying one graph can clobber another graph's outputs when +they share a pool, unless :meth:`~torch.Tensor.clone` is called on the outputs +beforehand. +This pattern is frequently used in inference servers that accept variable batch +sizes at runtime. +vLLM is a notable example; see `here `__ +and `here `__. + With :func:`torch.cuda.make_graphed_callables`, if you want to graph several callables and you know they'll always run in the same order (and never concurrently) pass them as a tuple in the same order they'll run in the live workload, and diff --git a/setup.py b/setup.py index dd8a52cbeb7c..31e78d0245d9 100644 --- a/setup.py +++ b/setup.py @@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") -def mirror_inductor_external_kernels() -> None: - """ - Copy external kernels into Inductor so they are importable. - """ - paths = [ - ( - CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", - CWD - / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", - ), - ] - for new_path, orig_path in paths: - # Create the dirs involved in new_path if they don't exist - if not new_path.exists(): - new_path.parent.mkdir(parents=True, exist_ok=True) - - # Copy the files from the orig location to the new location - if orig_path.is_file(): - shutil.copyfile(orig_path, new_path) - continue - if orig_path.is_dir(): - if new_path.exists(): - # copytree fails if the tree exists already, so remove it. - shutil.rmtree(new_path) - shutil.copytree(orig_path, new_path) - continue - raise RuntimeError( - "Check the file paths in `mirror_inductor_external_kernels()`" - ) - - # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1647,8 +1616,6 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() - mirror_inductor_external_kernels() - ( ext_modules, cmdclass, @@ -1682,7 +1649,6 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", - "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 4763621f6039..f1747acc31fc 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp @@ -44,6 +45,10 @@ endif() # Disable unused-variable warnings for variables that are only used to test compilation target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable) target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable) +# Add -Wno-dangling-pointer for GCC 13 +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + target_compile_options_if_supported(test_aoti_abi_check -Wno-dangling-pointer) +endif() foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS}) foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) diff --git a/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp b/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp new file mode 100644 index 000000000000..184c0ade8360 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp @@ -0,0 +1,52 @@ +#include + +#include + +#include + +using torch::headeronly::HeaderOnlyArrayRef; + +TEST(TestHeaderOnlyArrayRef, TestEmpty) { + HeaderOnlyArrayRef arr; + ASSERT_TRUE(arr.empty()); +} + +TEST(TestHeaderOnlyArrayRef, TestSingleton) { + float val = 5.0f; + HeaderOnlyArrayRef arr(val); + ASSERT_FALSE(arr.empty()); + EXPECT_EQ(arr.size(), 1); + EXPECT_EQ(arr[0], val); +} + +TEST(TestHeaderOnlyArrayRef, TestAPIs) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr(vec); + ASSERT_FALSE(arr.empty()); + EXPECT_EQ(arr.size(), 7); + for (size_t i = 0; i < arr.size(); i++) { + EXPECT_EQ(arr[i], i + 1); + EXPECT_EQ(arr.at(i), i + 1); + } + EXPECT_EQ(arr.front(), 1); + EXPECT_EQ(arr.back(), 7); + ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3))); +} + +TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr({1, 2, 3, 4, 5, 6, 7}); + auto res_vec = arr.vec(); + for (size_t i = 0; i < vec.size(); i++) { + EXPECT_EQ(vec[i], res_vec[i]); + } +} + +TEST(TestHeaderOnlyArrayRef, TestFromRange) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr(vec.data() + 3, vec.data() + 7); + auto res_vec = arr.vec(); + for (size_t i = 0; i < res_vec.size(); i++) { + EXPECT_EQ(vec[i + 3], res_vec[i]); + } +} diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 8261aae3b560..a92832a4d04c 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -70,6 +70,13 @@ if(NOT MSVC) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12) target_compile_options_if_supported(test_api "-Wno-error=nonnull") endif() + + # Add -Wno-error=array-bounds for GCC 13+ + # See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=113239 + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + target_compile_options_if_supported(test_api "-Wno-error=array-bounds") + endif() + endif() if(INSTALL_TEST) diff --git a/test/cpp/api/init_baseline.py b/test/cpp/api/init_baseline.py index 47b202e86311..4042657b4d5c 100644 --- a/test/cpp/api/init_baseline.py +++ b/test/cpp/api/init_baseline.py @@ -64,7 +64,7 @@ def run(initializer): def main(): initializer_parameter_map = {} - for initializer in INITIALIZERS.keys(): + for initializer in INITIALIZERS: sys.stderr.write(f"Evaluating {initializer} ...\n") initializer_parameter_map[initializer] = run(initializer) diff --git a/test/cpp/api/optim_baseline.py b/test/cpp/api/optim_baseline.py index 7e278d4e4208..e1a3c91b7128 100644 --- a/test/cpp/api/optim_baseline.py +++ b/test/cpp/api/optim_baseline.py @@ -130,7 +130,7 @@ def main(): options = parser.parse_args() optimizer_parameter_map = {} - for optimizer in OPTIMIZERS.keys(): + for optimizer in OPTIMIZERS: sys.stderr.write(f"Evaluating {optimizer} ...\n") optimizer_parameter_map[optimizer] = run( optimizer, options.iterations, options.sample_every diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 58c812b08ccc..7154322641c3 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -47,20 +47,10 @@ Tensor sgd_out_of_place( STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1"); STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1"); - int64_t *param_sizes; - int64_t *param_strides; - aoti_torch_get_sizes(param.get(), ¶m_sizes); - aoti_torch_get_strides(param.get(), ¶m_strides); + // testing Tensor strides + stride + STD_TORCH_CHECK(param.strides()[0] == param.stride(0)); - int32_t param_dtype; - aoti_torch_get_dtype(param.get(), ¶m_dtype); - - int32_t param_device_type; - aoti_torch_get_device_type(param.get(), ¶m_device_type); - - AtenTensorHandle out_ath; - aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); - auto out = Tensor(out_ath); + auto out = new_empty(param, param.sizes()); sgd_math( reinterpret_cast(param.data_ptr()), @@ -311,10 +301,9 @@ void boxed_fill_infinity( } Tensor my_pad(Tensor t) { - std::vector padding = {1, 2, 2, 1}; std::string mode = "constant"; double value = 0.0; - return pad(t, padding, mode, value); + return pad(t, {1, 2, 2, 1}, mode, value); } void boxed_my_pad( @@ -342,6 +331,11 @@ void boxed_my_narrow( } Tensor my_new_empty_dtype_variant(Tensor t) { + // Still using a std::vector below even though people can just pass in an + // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef) + // directly. + // This is to test that passing in a std::vector works for BC. (It gets + // implicitly converted to HeaderOnlyArrayRef too!) std::vector sizes = {2, 5}; auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); return new_empty(t, sizes, dtype); @@ -353,9 +347,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui } Tensor my_new_zeros_dtype_variant(Tensor t) { - std::vector sizes = {2, 5}; auto dtype = std::make_optional(at::ScalarType::Float); - return new_zeros(t, sizes, dtype); + return new_zeros(t, {2, 5}, dtype); } void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { @@ -429,8 +422,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) } Tensor my_amax_vec(Tensor t) { - std::vector v = {0,1}; - return amax(t, v, false); + return amax(t, {0,1}, false); } void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index f0316fde9f2c..1aaaf645c58d 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -208,7 +208,7 @@ def test_quantized_checkpoint_loading(self) -> None: # Create model.safetensors.index.json with weight mapping weight_map = {} - for key in quantized_checkpoint.keys(): + for key in quantized_checkpoint: weight_map[key] = "model.safetensors" index_data = { @@ -245,7 +245,7 @@ def test_quantized_checkpoint_loading(self) -> None: sorted(original_tensors.keys()), sorted(state_dict_to_load.keys()) ) - for tensor_name in original_tensors.keys(): + for tensor_name in original_tensors: original = original_tensors[tensor_name] loaded = state_dict_to_load[tensor_name] diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index dee38d040346..b4532a86e305 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -498,7 +498,7 @@ def _run_test_mixed_precision_e2e( for name, tensor in state_dict.items(): # Parameters and buffers are checkpointed in their # original dtypes, which may be different. - if name in named_buffers.keys(): + if name in named_buffers: self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE) else: self.assertEqual( diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 07442f34c894..abc37f17a74d 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -5,8 +5,16 @@ import torch import torch.distributed as dist from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard +from torch.distributed.tensor import ( + DeviceMesh, + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) from torch.distributed.tensor._dtensor_spec import ShardOrderEntry +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -42,22 +50,24 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_ids=True, record_output=True + ) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) - aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) + aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) - _c10d_functional::wait_tensor(t: f32[8, 32]) - aten::mm(t: f32[1, 8], t: f32[8, 32]) - (dt: f32[8, 32]| S(0)) - aten::sum(dt: f32[8, 32]| S(0)) - aten::sum(t: f32[1, 32])""", + redistribute_input(t$2: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] + _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] + aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + aten::sum(dt$6: f32[8, 32]| S(0)) + aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) @@ -424,6 +434,31 @@ def forward(self, x): ][-1] self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace) + def test_pretty_print_dtensor_make_fx(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + A = torch.randn(8, 32) + B = torch.randn(32, 32) + dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_() + dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_() + + def f(dA, dB): + dy = dA @ dB + loss = dy.sum() + loss.backward() + return dA.grad, dB.grad + + # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode. + # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph + # so we won't stash our DTensors properly if they don't hold Fake inner tensors + gm = make_fx(f, tracing_mode="fake")(dA, dB) + # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph + gm.graph.eliminate_dead_code() + gm.recompile() + # Colored is nice for actual viewing, not using in this test though + gm_str = gm.print_readable(colored=False, print_output=False) + self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str) + instantiate_parametrized_tests(TestDTensorDebugMode) diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index eaf3a4042060..6c3485f9d702 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,7 +3,8 @@ import itertools import random import unittest -from typing import Any, Callable, ClassVar, Optional +from collections.abc import Callable +from typing import Any, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 985e2d5f151a..2a1cb2b5580c 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1189,9 +1189,7 @@ def _test_sequence_num_incremented(self, process_group, ranks): self.assertEqual(len(set(rank_to_seq_num.values())), 2) self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2]) expected_same = { - rank_to_seq_num[i] - for i in rank_to_seq_num.keys() - if i not in [0, 2] + rank_to_seq_num[i] for i in rank_to_seq_num if i not in [0, 2] } self.assertEqual(len(expected_same), 1) self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1]) @@ -1558,7 +1556,7 @@ def test_debug_level(self): } invalid_debug_modes = ["foo", 0, 1, -1] - for mode in mapping.keys(): + for mode in mapping: os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) dist.set_debug_level_from_env() set_debug_mode = dist.get_debug_level() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index c117bc810b11..ef7ed5282816 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2,6 +2,7 @@ import copy import json +import logging import os import pickle import random @@ -21,6 +22,7 @@ import torch import torch.distributed as c10d import torch.distributed._functional_collectives as _functional_collectives +from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT if not c10d.is_available() or not c10d.is_nccl_available(): @@ -47,12 +49,15 @@ from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( + get_required_world_size, get_timeout, init_multigpu_helper, MultiProcessTestCase, requires_multicast_support, requires_nccl, + requires_nccl_shrink, requires_nccl_version, + requires_world_size, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, sm_is_or_higher_than, @@ -88,6 +93,53 @@ ) +_start_time = time.time() +_logger = logging.getLogger(__name__) + + +def _ts(): + return time.time() - _start_time + + +def configure(level=logging.INFO, force=False): + try: + logging.basicConfig( + level=level, + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + force=force, + ) + except TypeError: + logging.basicConfig( + level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s" + ) + + +def log_test_info(rank, message): + _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message) + + +def log_test_success(rank, message): + _logger.info("[%7.3fs][Rank %s] βœ… %s", _ts(), rank, message) + + +def log_test_validation(rank, message): + _logger.info("[%7.3fs][Rank %s] βœ“ %s", _ts(), rank, message) + + +def log_test_warning(rank, message): + _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message) + + +def log_test_error(rank, message): + _logger.error("[%7.3fs][Rank %s] βœ— %s", _ts(), rank, message) + + +_log_configure = configure + + +_log_configure(level=logging.INFO, force=True) + + class RendezvousEnvTest(TestCase): @retry_on_connect_failures @requires_nccl() @@ -317,7 +369,7 @@ def tearDown(self): @property def world_size(self): - return 2 + return get_required_world_size(self, 2) @property def rank_to_GPU(self): @@ -1255,6 +1307,628 @@ def test_set_process_group_desc(self): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_basic(self): + """Test basic shrink_group functionality.""" + self._perform_shrink_test([1], "Basic shrink test") + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_validation(self): + """Test input validation in shrink_group.""" + device, pg = self._setup_shrink_test("validation") + + def _test_invalid_input(ranks, description, expected_exception): + """Helper to test invalid inputs.""" + try: + c10d.shrink_group(ranks) + self.fail(f"Expected {expected_exception.__name__} for {description}") + except expected_exception: + log_test_validation(self.rank, f"βœ“ {description}") + except Exception: + if expected_exception is Exception: # Accept any exception + log_test_validation(self.rank, f"βœ“ {description}") + else: + raise + + # Test cases + _test_invalid_input([], "Empty exclusion list", ValueError) + if self.world_size > 1: + _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception) + _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception) + + log_test_success(self.rank, "All validation tests passed") + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_backend_properties(self): + """Test that backend properties are preserved after shrinking.""" + + test_name = "Backend Properties Test" + ranks_to_exclude = [0] + + # Reuse _setup_shrink_test for complete setup (device, environment, and process group) + device, pg = self._setup_shrink_test("backend_properties") + + # Follow _perform_shrink_test pattern from here + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # Store original backend property values (not references) before shrinking + original_timeout = None + original_high_priority = None + if not is_excluded: + original_backend = pg._get_backend(device) + original_timeout = original_backend.options._timeout + original_high_priority = original_backend.options.is_high_priority_stream + log_test_info( + self.rank, + f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}", + ) + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + dist.destroy_process_group() # hang without it + return + + # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test) + log_test_info(self.rank, "Non-excluded rank calling shrink_group") + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + + # Reuse _validate_shrunk_group helper (same as _perform_shrink_test) + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + # Add custom backend properties validation + new_backend = shrunk_pg._get_backend(device) + log_test_info(self.rank, "Validating backend properties are preserved") + + new_timeout = new_backend.options._timeout + new_high_priority = new_backend.options.is_high_priority_stream + + log_test_info( + self.rank, + f"Timeout comparison - original: {original_timeout}, new: {new_timeout}", + ) + self.assertEqual( + original_timeout, new_timeout, f"{test_name}: timeout not preserved" + ) + + log_test_info( + self.rank, + f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}", + ) + self.assertEqual( + original_high_priority, + new_high_priority, + f"{test_name}: high_priority_stream not preserved", + ) + + log_test_validation( + self.rank, f"{test_name}: Backend properties preserved successfully" + ) + log_test_success( + self.rank, f"{test_name} successful (shrink + backend validation)" + ) + + # Cleanup (same as _perform_shrink_test) + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_multiple_comms(self): + """Test shrink_group with multiple communicators and subgroup invalidation.""" + + device, pg = self._setup_shrink_test("multiple_comms") + + # Create subgroup [0, 1] and test shrinking it + subgroup = c10d.new_group([0, 1]) + if self.rank <= 1: + # Shrink subgroup: exclude rank 1 + if self.rank == 0: # Only rank 0 remains + shrunk_subgroup = c10d.shrink_group([1], group=subgroup) + self.assertEqual(shrunk_subgroup.size(), 1) + # Test communication on shrunk subgroup + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_subgroup) + self.assertEqual(tensor.item(), 0) # Only rank 0 + log_test_success(self.rank, "Subgroup shrinking successful") + + dist.barrier() # Sync before default group test + + # Shrink default group: exclude last rank + ranks_to_exclude = [self.world_size - 1] + if self.rank not in ranks_to_exclude: + shrunk_default = c10d.shrink_group(ranks_to_exclude) + expected_size = self.world_size - 1 + self.assertEqual(shrunk_default.size(), expected_size) + + # Test collective on shrunk default group + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_default) + expected_sum = sum( + range(self.world_size - 1) + ) # 0 + 1 + ... + (world_size-2) + self.assertEqual(tensor.item(), expected_sum) + log_test_success(self.rank, "Default group shrinking successful") + + # Note: After shrinking default group, the old subgroup is invalid + # due to global rank reassignment + + dist.destroy_process_group() + + def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude): + """Helper method to test shrink_group with a specific flag.""" + if self.world_size < 2: + log_test_info(self.rank, f"Skipping (needs β‰₯2 GPUs, got {self.world_size})") + return + ranks_to_exclude = [rank_to_exclude] + log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})") + if flag_name == "NCCL_SHRINK_ABORT": + log_test_info( + self.rank, + "ABORT flag will terminate ongoing operations before shrinking", + ) + + self._perform_shrink_test( + ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag + ) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_flags(self): + """Test shrink_group with different shrink flags.""" + # Test ABORT flags + log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag") + self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_nccl_config(self): + """Verify that passing NCCL config via pg_options influences the shrunk group's backend options.""" + device, pg = self._setup_shrink_test("config") + if self.rank == self.world_size - 1: + # excluded rank should not call shrink_group + dist.destroy_process_group() + return + + # Prepare pg_options with NCCL config overrides + # Capture parent's current backend options to ensure we can prove override vs inherit + parent_backend = pg._get_backend(torch.device("cuda")) + parent_hp = parent_backend.options.is_high_priority_stream + parent_blocking = parent_backend.options.config.blocking + + # Choose overrides that differ from the parent (flip where possible) + override_hp = not parent_hp + if parent_blocking in (0, 1): + override_blocking = 1 - parent_blocking + else: + # If undefined or unexpected, set to 1 which is a concrete value + override_blocking = 1 + + opts = c10d.ProcessGroupNCCL.Options() + opts.is_high_priority_stream = override_hp + opts.config.blocking = override_blocking + + shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts) + + # Validate backend options propagated + backend = shrunk_pg._get_backend(torch.device("cuda")) + # is_high_priority_stream should exactly match our override and differ from parent + self.assertEqual(backend.options.is_high_priority_stream, override_hp) + self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp) + # config is a struct; check representative field and difference from parent when meaningful + self.assertEqual(backend.options.config.blocking, override_blocking) + if parent_blocking in (0, 1): + self.assertNotEqual(backend.options.config.blocking, parent_blocking) + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_performance(self): + """Test shrink_group performance and regression detection.""" + import time + + ranks_to_exclude = self._get_default_ranks_to_exclude() + is_excluded = self.rank in ranks_to_exclude + + if not ranks_to_exclude: + log_test_info(self.rank, "Skipping performance test (world_size=1)") + return + + log_test_info(self.rank, f"Performance test with {self.world_size} processes") + device, pg = self._setup_shrink_test("performance") + + if not is_excluded: + log_test_info(self.rank, "Measuring shrink_group performance") + start_time = time.time() + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + end_time = time.time() + + elapsed_time = end_time - start_time + log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s") + + # Regression check: should complete within reasonable time + self.assertLess( + elapsed_time, + 30.0, + f"shrink_group took {elapsed_time:.3f}s, possible regression", + ) + + # Test collective performance + expected_size = self.world_size - len(ranks_to_exclude) + self._validate_shrunk_group(shrunk_pg, expected_size, "performance") + + collective_start = time.time() + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, "performance" + ) + collective_time = time.time() - collective_start + + log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s") + log_test_success(self.rank, "Performance test passed") + else: + log_test_info(self.rank, "Excluded rank - waiting") + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(4) + def test_shrink_group_multiple_exclusions(self): + """Test shrink_group with multiple ranks excluded at once.""" + # Scale exclusions with world size + ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2 + + self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test") + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_multiple_iterations(self): + """Test multiple shrink operations in sequence.""" + log_test_info( + self.rank, + f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}", + ) + + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + _ = self._create_process_group_nccl(store, self.opts(), device_id=device) + + # Track current effective world size throughout shrinking operations + current_world_size = self.world_size + log_test_info(self.rank, f"Initial world_size: {current_world_size}") + + # First shrinking: exclude the last rank(s) + first_exclusion = [self.world_size - 1] + if self.world_size >= 6: + first_exclusion.append( + self.world_size - 2 + ) # Exclude last two ranks for larger sizes + + log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}") + + if self.rank not in first_exclusion: + # Only non-excluded ranks should call shrink_group + first_pg = c10d.shrink_group(first_exclusion) + self.assertIsNotNone(first_pg) + # IMPORTANT: Update world size after first shrinking + current_world_size = first_pg.size() + expected_first_size = self.world_size - len(first_exclusion) + log_test_info( + self.rank, + f"After first shrinking: world_size {self.world_size} -> {current_world_size}", + ) + self.assertEqual(first_pg.size(), expected_first_size) + + # Second shrinking: exclude another rank from the remaining group + # Choose a rank that's in the middle range + if current_world_size >= 3: + second_exclusion = [ + current_world_size - 1 + ] # Exclude the new "last" rank + log_test_info( + self.rank, + f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}", + ) + + if self.rank not in second_exclusion: + # Only non-excluded ranks should call shrink_group for second iteration + second_pg = c10d.shrink_group(second_exclusion, group=first_pg) + self.assertIsNotNone(second_pg) + # IMPORTANT: Update world size after second shrinking + final_world_size = second_pg.size() + expected_final_size = current_world_size - len(second_exclusion) + log_test_info( + self.rank, + f"After second shrinking: world_size {current_world_size} -> {final_world_size}", + ) + self.assertEqual(second_pg.size(), expected_final_size) + + # Test collective on final group + tensor = torch.full((1,), self.rank).cuda(device) + log_test_info( + self.rank, + f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}", + ) + c10d.all_reduce(tensor, group=second_pg) + log_test_info( + self.rank, + f"Final all_reduce completed, result: {tensor.item()}", + ) + + # Calculate expected sum of remaining ranks + all_excluded = set(first_exclusion + second_exclusion) + remaining_ranks = [ + r for r in range(self.world_size) if r not in all_excluded + ] + expected_sum = sum(remaining_ranks) + log_test_info( + self.rank, + f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}", + ) + self.assertEqual(tensor.item(), expected_sum) + log_test_info(self.rank, "Final verification passed") + else: + log_test_info( + self.rank, + "This rank excluded in second shrinking, not calling shrink_group", + ) + else: + log_test_info( + self.rank, "Skipping second shrinking (remaining group too small)" + ) + else: + log_test_info( + self.rank, + "This rank excluded in first shrinking, not calling shrink_group", + ) + + log_test_info(self.rank, "Destroying process group") + dist.destroy_process_group() + log_test_info(self.rank, "test_shrink_group_multiple_iterations completed") + + # Helper methods for optimized shrink group tests + def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True): + """Common setup for shrink group tests.""" + os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" + world_size = world_size or self.world_size + store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size) + device = torch.device(f"cuda:{self.rank}") + c10d.init_process_group( + "nccl", + world_size=world_size, + rank=self.rank, + store=store, + pg_options=self.opts(), + device_id=device, + ) + pg = c10d.distributed_c10d._get_default_group() + + if warmup: + c10d.all_reduce(torch.ones(1).cuda(device), group=pg) + + return device, pg + + def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""): + """Validate properties of a shrunk process group.""" + self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None") + actual_size = shrunk_pg.size() + self.assertEqual( + actual_size, expected_size, f"{test_name}: group size mismatch" + ) + + new_rank = shrunk_pg.rank() + self.assertTrue( + 0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}" + ) + + log_test_info( + self.rank, + f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}", + ) + return new_rank + + def _test_collective_on_shrunk_group( + self, shrunk_pg, device, ranks_to_exclude, test_name="" + ): + """Test collective communication on shrunk group and verify correctness.""" + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + c10d.all_reduce(test_tensor, group=shrunk_pg) + + result = test_tensor.item() + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + + self.assertEqual( + result, expected_sum, f"{test_name}: collective result mismatch" + ) + log_test_info( + self.rank, f"{test_name}: collective passed ({result} == {expected_sum})" + ) + return result + + def _perform_shrink_test( + self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True + ): + """Complete shrink test flow: setup, shrink, validate, test collective, cleanup. + + Consistent API: All ranks perform setup to initialize distributed environment. + ONLY non-excluded ranks call shrink_group() for both default and non-default groups. + Excluded ranks perform setup, then exit without calling shrink_group() or waiting. + """ + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # All ranks (including excluded ones) perform setup to initialize distributed environment + device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_")) + is_default_group = pg == c10d.distributed_c10d._get_default_group() + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + if shrink_flags & NCCL_SHRINK_ABORT: + log_test_info(self.rank, f"Using abort for excluded rank {self.rank}") + pg._get_backend(torch.device(device)).abort() + log_test_info( + self.rank, f"cleanup resources for excluded rank {self.rank}" + ) + dist.destroy_process_group() + log_test_info(self.rank, f"Excluded rank {self.rank} - exit") + else: + log_test_info( + self.rank, f"Using regular destroy for excluded rank {self.rank}" + ) + dist.destroy_process_group() + return None + + # Only non-excluded ranks proceed with shrink + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group})", + ) + shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags) + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done", + ) + + # Non-excluded ranks: validate and test the new group + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + if with_collective: + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, test_name + ) + log_test_success(self.rank, f"{test_name} successful (shrink + collective)") + else: + log_test_success(self.rank, f"{test_name} successful (shrink only)") + + dist.destroy_process_group() + return shrunk_pg + + def _get_default_ranks_to_exclude(self): + """Get default ranks to exclude based on world size.""" + if self.world_size <= 1: + return [] + return [self.world_size - 1] # Exclude last rank by default + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_vs_abort_reinit_performance(self): + """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability).""" + log_test_info(self.rank, "=== TEST 1: abort+reinit ===") + + device, pg1 = self._setup_shrink_test("_perf_reinit") + torch.cuda.synchronize(device) + + # Test 1: Traditional abort + reinit + start_time = time.perf_counter() + dist.destroy_process_group() + + device, new_pg = self._setup_shrink_test("perf_shrink_test1") + reinit_time = time.perf_counter() - start_time + + # Test collective with original rank values for fair comparison (non-blocking mode) + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True) + work.wait() + + torch.cuda.synchronize(device) + + # Verify correctness + expected_sum = sum(r for r in range(self.world_size)) + self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed") + + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + dist.destroy_process_group(new_pg) + + # Test 2: shrink_group with NCCL_SHRINK_ABORT + log_test_info(self.rank, "=== TEST 2: shrink_group ===") + + ranks_to_exclude = [self.world_size - 1] + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix + + shrink_time = 0 + if not is_excluded: + torch.cuda.synchronize(device) # Ensure accurate timing + start_time = time.perf_counter() + shrunk_pg = c10d.shrink_group( + ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT + ) + c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg) + shrink_time = time.perf_counter() - start_time + + # Test collective communication on shrunk group (non-blocking mode) + test_tensor = torch.full( + (1,), self.rank, device=device, dtype=torch.float32 + ) + work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True) + work.wait() + + # Verify correctness + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + self.assertEqual( + test_tensor.item(), + expected_sum, + "shrink_test: collective result mismatch", + ) + + torch.cuda.synchronize(device) # Ensure operations complete + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + dist.destroy_process_group() + else: + log_test_info(self.rank, "Excluded from shrink test - exiting immediately") + dist.destroy_process_group() + return + + # Performance analysis (only for participating ranks) + if shrink_time > 0 and reinit_time > 0: + speedup = reinit_time / shrink_time + time_saved = reinit_time - shrink_time + + log_test_info(self.rank, "=== PERFORMANCE RESULTS ===") + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s") + log_test_info(self.rank, f"speedup: {speedup:.2f}x") + + if speedup > 1.1: + log_test_success(self.rank, "shrink_group significantly faster") + elif speedup > 0.9: + log_test_info(self.rank, "β‰ˆ comparable performance") + else: + log_test_warning(self.rank, "abort+reinit faster") + + log_test_info(self.rank, "Performance test completed") + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_deterministic_mode_no_break(self): @@ -5115,6 +5789,229 @@ def test_coalescing_manager_collective(self, timing_enabled): else: self.assertTrue("duration_ms" not in t["entries"][0]) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_circular_buffer_full(self, timing_enabled): + """ + Test that when the circular buffer in entries_ is full and we call reset, + then fill the buffer with new entries, dump_entries returns only the new + entries and not the old ones. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill the buffer completely with 10 entries + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify buffer is full with 10 entries + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 10) + + # Now reset the flight recorder + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Add new entries after reset - fill the buffer completely again + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify we get exactly 10 new entries, not 20 + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 10) + + # Verify all entries have the expected properties (from after reset) + # After reset, record IDs should start from 0 again + for i, entry in enumerate(t["entries"]): + self.assertIn("profiling_name", entry) + self.assertEqual(entry["profiling_name"], "nccl:all_reduce") + self.assertIn("record_id", entry) + # Record IDs should be sequential starting from 0 after reset + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_partial_overwrite(self, timing_enabled): + """ + Test that when the circular buffer is full, we reset, and then add fewer + entries than the buffer size, we only get the new entries. + This tests that old entries at the end of the circular buffer are properly + filtered out based on reset_epoch. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill the buffer completely + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Reset the flight recorder + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Add only 3 new entries (much less than buffer size) + for _ in range(3): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify we only get the 3 new entries, not 10 + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 3) + + # Verify record IDs start from 0 after reset + for i, entry in enumerate(t["entries"]): + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_wraparound(self, timing_enabled): + """ + Test that when we reset in the middle of the circular buffer and then + wrap around, dump_entries correctly returns only entries from the current + epoch in the correct order. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill half the buffer + for _ in range(5): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Reset at this point (reset happens at index 5) + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Now add 8 entries, which will wrap around + # (5->9 fills rest of buffer, then 0->2 wraps around) + for _ in range(8): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Should get exactly 8 entries, properly ordered + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 8) + + # Entries should be in chronological order + # The dump_entries() method returns entries from next_ to end, then 0 to next_ + # After filtering old entries, we should have 8 entries in order + # Verify record IDs start from 0 after reset (id_ is reset in reset_all()) + for i, entry in enumerate(t["entries"]): + self.assertIn("profiling_name", entry) + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_multiple_resets(self, timing_enabled): + """ + Test multiple consecutive resets to ensure each reset properly increments + the epoch and filters out entries from previous epochs. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # First batch: 2 entries + for _ in range(2): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # First reset + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Second batch: 3 entries + for _ in range(3): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Second reset + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Third batch: 4 entries + for _ in range(4): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Should only see the last 4 entries + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 4) + + # Verify record IDs start from 0 after the last reset + for i, entry in enumerate(t["entries"]): + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + def check_if_test_is_skipped(fn): def wrapper(self, *args, **kwargs): @@ -5446,6 +6343,14 @@ def test_comm_recursive_split_group(self): if self.rank == 6 or self.rank == 7: dist.broadcast(tensor2, 6, group=ng2) self.assertEqual(tensor2, torch.full((1,), 6)) + + # Test the case when the split changes the pg option of split group + # while the parent pg option is not changed. + new_pg = c10d.new_group([0, 1, 2, 3, 4, 5, 6, 7], device_id=device) + backend_new_pg = new_pg._get_backend(torch.device(device)) + self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8) + c10d.split_group(new_pg, [[0, 2, 4, 6], [1, 3, 5, 7]]) + self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8) # a barrier and a cuda sync before destroying all pgs. dist.barrier(pg) torch.cuda.synchronize() diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index ac3103e09341..daa9bf2e309f 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1985,6 +1985,7 @@ def _reorder_communication_preserving_peak_memory( "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ + _reorder_communication_preserving_peak_memory, sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], @@ -2046,11 +2047,6 @@ def _reorder_communication_preserving_peak_memory( assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) self.assertEqual(len(node_stats), 4) - it = iter(node_stats.values()) - node_stat0 = next(it) - self.assertTrue(node_stat0.limiting_factor == "None") - node_stat1 = next(it) - self.assertTrue("collective ordering" in node_stat1.limiting_factor) @skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index 114780627e33..fa081243c281 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -7,6 +7,8 @@ import torch.distributed as dist from torch.distributed._local_tensor import ( local_tensor_mode, + LocalIntNode, + LocalRunnerMode, LocalTensor, LocalTensorMode, ) @@ -17,8 +19,10 @@ Partial, Replicate, Shard, + zeros, ) from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import reduce_local_int class LocalTensorTestBase(TestCase): @@ -124,14 +128,14 @@ def test_basic_arithmetic_operations(self): self.assertEqual(len(result_add._local_tensors), 2) # Verify the operation was applied to each local tensor - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] + identical_local_tensors[rank] self.assertEqual(result_add._local_tensors[rank], expected) # Test multiplication result_mul = lt1 * 2.0 self.assertIsInstance(result_mul, LocalTensor) - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] * 2.0 self.assertEqual(result_mul._local_tensors[rank], expected) @@ -159,7 +163,7 @@ def test_mixed_operations_with_regular_tensors(self): result = lt + regular_tensor self.assertIsInstance(result, LocalTensor) - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] + regular_tensor self.assertEqual(result._local_tensors[rank], expected) @@ -208,14 +212,14 @@ def test_collectives_within_local_tensor_mode(self): dist.all_reduce(lt_sum, group=fake_pg) expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]]) - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) # Test broadcast within mode lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.broadcast(lt_broadcast, src=0, group=fake_pg) - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0]) # Test that regular operations still work @@ -289,21 +293,21 @@ def test_collective_reduction_operations(self): lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg) expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) # Test MAX reduction lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg) expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) # Max across all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_max._local_tensors[rank], expected_max) # Test MIN reduction lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg) expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_min._local_tensors[rank], expected_min) def test_all_reduce_collective(self): @@ -324,7 +328,7 @@ def test_all_reduce_collective(self): # Verify all ranks have the sum of all tensors (after adding 1 to each) expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]]) - for rank in different_tensors.keys(): + for rank in different_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) def test_broadcast_collective(self): @@ -344,7 +348,7 @@ def test_broadcast_collective(self): # Verify all ranks have rank 1's original tensor expected_broadcast = different_tensors[1] - for rank in different_tensors.keys(): + for rank in different_tensors: self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast) def test_all_gather_collective(self): @@ -411,5 +415,78 @@ def test_dtensor_addmm(self): self.assertEqual(full_tensor, local_res) +from torch.distributed._local_tensor._c10d import local_p2p_op, wait_all + + +class TestLocalRunner(LocalTensorTestBase): + world_size = 6 + + @staticmethod + def _get_pp_peer(pp_index, mesh, dim, dir): + pp_meshes = mesh._get_all_submeshes(dim) + pp_ret = {} + for pp_mesh in pp_meshes: + global_rank = pp_mesh.mesh[pp_index].item() + global_peer = pp_mesh.mesh[(pp_index + dir) % pp_mesh.size()].item() + pp_ret[global_rank] = global_peer + + return torch.SymInt(LocalIntNode(pp_ret)) + + def _run_dp_pp( + self, + mesh: DeviceMesh, + pp_index: int, + actual: list[torch.Tensor | None], + expected: list[torch.Tensor | None], + ) -> None: + ltm = LocalTensorMode(mesh.size()) + with ltm: + dp_mesh = mesh["dp"] + pp_mesh = mesh["pp"] + + x = torch.rand(2, 4) + xd = distribute_tensor(x, dp_mesh, [Shard(0)]) + xd = xd * 2 + x = x * 2 + + yd = zeros(*xd.shape, device_mesh=dp_mesh, placements=[Shard(0)]) + + if pp_index != pp_mesh.size(0) - 1: + # Send to next pp rank + pp_next_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", +1) + local_p2p_op(pp_next_rank, xd, dist.isend) + expected[pp_index + 1] = ltm.tensor_map( + x, + lambda r, t: t + if reduce_local_int(pp_next_rank, lambda vals: r in vals.values()) + else torch.zeros_like(t), + ) + + if pp_index != 0: + # Receive from prev pp rank + pp_prev_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", -1) + rw = local_p2p_op(pp_prev_rank, yd, dist.irecv) + wait_all(rw) + + y = yd.full_tensor() + actual[pp_index] = y + + def test_dp_pp(self): + pp_size = 3 + mesh = init_device_mesh( + "cpu", (self.world_size // pp_size, pp_size), mesh_dim_names=("dp", "pp") + ) + actual: list[torch.Tensor | None] = [None] * pp_size + expected: list[torch.Tensor | None] = [None] * pp_size + with LocalRunnerMode( + self.world_size, + pp_size, + lambda pp_index: self._run_dp_pp(mesh, pp_index, actual, expected), + ): + pass + + self.assertEqual(actual, expected) + + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index 8810a30aaf3b..8ebf35f3f0d3 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -283,7 +283,7 @@ def test_fn(): ) def test_bisect_pre_grad_graph(self): def f(x): - for i in range(5): + for _ in range(5): x = x + 1 return x.relu() diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 966acd1d8139..4a4d2ff87718 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -36,6 +36,15 @@ class DummyUserDict(UserDict): pass +class FakeMapping: + def __init__(self, value: Any) -> None: + self._value = value + self.keys = lambda: ["a", "b", "c"] # not required to be a method + + def __getitem__(self, key: str) -> Any: + return self._value + + class DictTests(torch._dynamo.test_case.TestCase): def test_dict_subclass_instantiation(self): def fn(x): @@ -666,6 +675,18 @@ def fn(): for k1, m2 in zip(modules, module_dict.children()): self.assertTrue(modules[k1] is m2) + # FIXME: see comment in torch/_dynamo/polyfills/__init__.py:mutable_mapping_update + @unittest.expectedFailure + def test_dict_construct_from_mapping_like(self): + def fn(x): + fm = FakeMapping(x) + d = dict(fm, x=x) + return d + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + def test_dict_subclass_initialization_in_graph(self): for super_class in ( OrderedDict, @@ -1087,12 +1108,52 @@ def f(x): self.assertEqual(ref, res) - @unittest.expectedFailure + def test_newly_constructed_default_dict_no_default_factory(self): + def f1(x): + d = defaultdict() + try: + d[1] += 42 + except KeyError: + d[1] = 1 + return x + 1, d + + x = torch.ones(2) + ref = f1(x) + res = torch.compile(f1, backend="eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + + def f2(x): + d = defaultdict(None) + try: + d[1] += 42 + except KeyError: + d[1] = 1 + return x + 1, d + + ref = f2(x) + res = torch.compile(f2, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + + def f3(x): + d = defaultdict(None, {1: 10}) + d[1] += 42 + try: + d[2] += 24 + except KeyError: + d[2] = 1 + return x + 1, d + + ref = f3(x) + res = torch.compile(f3, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + def test_newly_constructed_default_dict_with_dict(self): def f(x): - d = defaultdict(dict, {2: {"a": 1}}) - d[0] = {"b": 2} - return x + 1, d + d = dict([("a", 1), ("b", 2)], c=3) # noqa: C406 + dd = defaultdict(list, d, d=4, e=5) + dd["x"].append(42) + return x + 1, d, dd x = torch.ones(2) ref = f(x) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 004aee88a863..fc9284a3c954 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -8,21 +8,11 @@ from torch._dynamo.graph_utils import _detect_cycles from torch._dynamo.output_graph import FakeRootModule from torch._dynamo.test_case import TestCase -from torch._dynamo.testing import ( - AotEagerAndRecordGraphs, - extract_graph_and_tracker, - normalize_gm, -) +from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm from torch.compiler import allow_in_graph from torch.utils._ordered_set import OrderedSet -def extract_graph(fn, *args, **kwargs): - backend = AotEagerAndRecordGraphs() - result = torch.compile(backend=backend)(fn)(*args, **kwargs) - return result, backend.graphs, backend.fw_graphs - - def graph_str(gm): return normalize_gm(gm.print_readable(print_output=False)) @@ -40,7 +30,7 @@ def tearDown(self): super().tearDown() def run_and_return_graphs(self, fn, *args, **kwargs): - return extract_graph(fn, *args, **kwargs) + return extract_graph(fn, *args, **kwargs)[0:3] def run_and_get_simple_graph(self): def fn(x, y): diff --git a/test/dynamo/test_install_free_tensors.py b/test/dynamo/test_install_free_tensors.py index 3858b827bd59..fd9e14c4c3f7 100644 --- a/test/dynamo/test_install_free_tensors.py +++ b/test/dynamo/test_install_free_tensors.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] import unittest -from collections.abc import Sequence -from typing import Any, Callable, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import torch import torch._dynamo diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 169f43ce0a07..b3e9df6a25cf 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -13194,6 +13194,30 @@ def fn(x, y): self.assertEqual(actual, expected) + @parametrize_pytree_module + def test_pytree_tree_map_dict_order(self, pytree): + def fn(tree): + new_tree = pytree.tree_map(lambda x: x, tree) + return list(new_tree.keys()), list(new_tree.values()) + + x = torch.randn(3, 2) + fn_opt = torch.compile(fullgraph=True)(fn) + + tree1 = {"b": x + 2, "a": x, "c": x - 1} + expected1 = fn(tree1) + actual1 = fn_opt(tree1) + self.assertEqual(actual1, expected1) + + tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)]) + expected2 = fn(tree2) + actual2 = fn_opt(tree2) + self.assertEqual(actual2, expected2) + + tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1}) + expected3 = fn(tree3) + actual3 = fn_opt(tree3) + self.assertEqual(actual3, expected3) + @parametrize_pytree_module def test_pytree_tree_map_only(self, pytree): if not callable(getattr(pytree, "tree_map_only", None)): @@ -13219,6 +13243,27 @@ def mapper(x): self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) + def test_pytree_register_constant_with_side_effect(self): + class Foo: + pass + + class Bar: + def __eq__(self, other): + return super().__eq__(other) + + def __hash__(self): + return 0 + + python_pytree.register_constant(Bar) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x, obj): + obj.attr = {3: Bar()} + return x + 1 + + inp = torch.ones(3) + self.assertEqual(fn(inp, Foo()), inp + 1) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index a615c653f56c..a6117bb4093a 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,5 +1,5 @@ # Owner(s): ["module: dynamo"] -from typing import Callable, NamedTuple, Optional +from typing import NamedTuple, Optional, TYPE_CHECKING import torch import torch._dynamo @@ -7,6 +7,10 @@ from torch._dynamo.testing import CompileCounter, same +if TYPE_CHECKING: + from collections.abc import Callable + + """ This is an example of a pure-python version of autograd implemented by @zdevito. It represents a rather challenging test case for TorchDynamo diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c6138f7574fd..f3766fe0c973 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -48,6 +48,7 @@ CompileCounter, CompileCounterWithBackend, EagerAndRecordGraphs, + expectedFailureDynamic, rand_strided, same, skipIfNotPy312, @@ -7455,6 +7456,93 @@ def forward(self, x): msg, ) + @expectedFailureDynamic + def test_dynamo_default_lru_cache_behavior(self): + @torch.compile(backend="eager") + def fn(x): + return x + 10 + + torch._dynamo.reset() + assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + + # Step 1: Compile a static shapes graph + x = torch.randn(10, 10) + fn(x) + a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(a), 1) + static_shapes_cache_entry = a[0] + + # Step 2: Compile a dynamic shapes graph + y = torch.randn(20, 20) + fn(y) + b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(b), 2) + self.assertEqual(b[1], static_shapes_cache_entry) + dynamic_shapes_cache_entry = b[0] + + # Step 3: Run with Step 1's inputs + # LRU cache will match against dynamic shape graph first + fn(x) + c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(c), 2) + self.assertEqual(c[0], dynamic_shapes_cache_entry) + self.assertEqual(c[1], static_shapes_cache_entry) + + @expectedFailureDynamic + def test_dynamo_disable_lru_cache_behavior(self): + @torch.compile(backend="eager") + def fn(x): + return x + 10 + + def run(): + torch._dynamo.reset() + assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + + # Step 1: Compile a static shapes graph + x = torch.randn(10, 10) + fn(x) + a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(a), 1) + static_shapes_cache_entry = a[0] + + # Step 2: Compile a dynamic shapes graph + y = torch.randn(20, 20) + fn(y) + b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(b), 2) + self.assertEqual(b[0], static_shapes_cache_entry) + dynamic_shapes_cache_entry = b[1] + + # Step 3: Run with Step 1's inputs + # LRU cache is disabled, we should still have static entry first + fn(x) + c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(c), 2) + self.assertEqual(c[0], static_shapes_cache_entry) + self.assertEqual(c[1], dynamic_shapes_cache_entry) + + try: + torch._C._dynamo.eval_frame._set_lru_cache(False) + run() + finally: + torch._C._dynamo.eval_frame._set_lru_cache(True) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index e05e1304d286..1b81597977d7 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -1,11 +1,17 @@ # Owner(s): ["module: dynamo"] import functools +import re import unittest import weakref import torch import torch._dynamo.test_case import torch._dynamo.testing +from torch._dynamo.graph_bytecode_inputs import ( + reset_user_object_tracking, + store_user_object_weakrefs, +) +from torch._dynamo.testing import extract_graph, remove_trailing_space from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda @@ -15,6 +21,14 @@ ) +def remove_file_comment(gm_str: str) -> str: + return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str)) + + +def print_graph(graph: torch.fx.GraphModule) -> str: + return remove_file_comment(graph.print_readable()) + + class TestStreams(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -36,9 +50,7 @@ def test_event_weakref(self): @requires_cuda def test_stream_enter_exit(self): - def fn(x, y): - s2 = torch.Stream() - s1 = torch.Stream() + def fn(x, y, s1, s2): with s1: z1 = torch.add(x, y) with s2: @@ -47,13 +59,36 @@ def fn(x, y): return y - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream()) expected = fn(*inp) - fn_opt = torch.compile(fn, fullgraph=True) - actual = fn_opt(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': None} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': None} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': None} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None + return (add_3,) +""", + ) @requires_cuda + @unittest.skip("Needs graph break support with annotation context") def test_stream_context_graph_break(self): def fn(x, y): s2 = torch.Stream() @@ -70,9 +105,16 @@ def fn(x, y): inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) expected = fn(*inp) - fn_opt = torch.compile(fn) - actual = fn_opt(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) self.assertEqual(expected, actual) + self.assertEqual(len(fw_graphs), 2) + self.assertExpectedInline(print_graph(fw_graphs[0]), """""") + self.assertExpectedInline(print_graph(fw_graphs[1]), """""") @requires_cuda def test_stream_input(self): @@ -155,35 +197,310 @@ def fn(x, s0, s1): self.assertEqual(s_act, s_exp) def test_nested_stream_enter_exit(self): - pass - + def fn(x, y, s0, s1, s2): + with s1: + with s2: + z1 = torch.add(x, y) + with s0: + z0 = torch.add(x, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = ( + torch.ones(2, 2) + 1, + torch.ones(2, 2), + torch.Stream(), + torch.Stream(), + torch.Stream(), + ) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': None} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': None} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': None} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None + return (add_1, add_2) +""", + ) + + @unittest.skip("Needs graph break support with annotation context") def test_stream_enter_exit_graph_break(self): pass + @unittest.skip("Needs graph break support with annotation context") def test_nested_stream_enter_exit_graph_break(self): pass def test_local_stream_enter_exit(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + with s1: + z1 = torch.add(x, y) + with s2: + z = torch.add(x, y) + y = z + 2 + z1 + + return y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 1} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 0} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': 0} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None + return (add_3,) +""", + ) def test_local_stream_nested_enter_exit(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + s0 = torch.Stream() + with s1: + with s2: + z1 = torch.add(x, y) + with s0: + z0 = torch.add(x, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 2} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': 0} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None + return (add_1, add_2) +""", + ) def test_stream_with_mutation(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + s0 = torch.Stream() + with s1: + with s2: + x.add_(y) + with s0: + z1 = torch.add(y, y) + z0 = torch.add(z1, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 2} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1) + + # Annotation: {'stream': 2} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None + + # Annotation: {'stream': 0} + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + + # + copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None + return (add_2, add_3) +""", + ) + + def test_stream_backward(self) -> None: + def fn(x, y): + s2 = torch.Stream() + s0 = torch.Stream() + with s0: + y0 = 2 * x + y + with s2: + z = 2 * x + y + + return y0, z + + inp = ( + torch.ones(2, 2, requires_grad=True) + 1, + torch.ones(2, 2, requires_grad=True), + ) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + bw_graphs, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): + # Annotation: {'stream': 1} + mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2) + + # Annotation: {'stream': 0} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None + return (add, add_1) +""", + ) + + actual[1].sum().backward() + self.assertExpectedInline( + print_graph(bw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): + # Annotation: {'stream': 0} + mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2) + + # + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None + + # Annotation: {'stream': 1} + mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + + # + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + return (add_3, add_2) +""", + ) @requires_cuda - def test_run_opcheck(self): + def test_run_opcheck_fork_join(self): from torch._dynamo.variables.streams import fork_stream, join_stream from torch.library import opcheck - sample_inputs = [ - (0, torch.device("cuda:0"), 1, torch.device("cuda:1")), - (2, torch.device("cuda:2"), 3, torch.device("cuda:1")), - ] - for args in sample_inputs: - opcheck(fork_stream, args) - opcheck(join_stream, args) + original_stream = torch.accelerator.current_stream() + try: + s0 = torch.Stream() + s1 = torch.Stream() + store_user_object_weakrefs(s0, s1) + + sample_inputs = [ + (0, 1), + (1, 0), + ] + for args in sample_inputs: + opcheck(fork_stream, args) + opcheck(join_stream, args) + finally: + torch.accelerator.set_stream(original_stream) + reset_user_object_tracking() + + @requires_cuda + def test_run_opcheck_wait_record(self): + from torch._dynamo.variables.streams import record_event, wait_event + from torch.library import opcheck + + original_stream = torch.accelerator.current_stream() + try: + s0 = torch.Stream() + s1 = torch.Stream() + e0 = torch.Event() + e1 = torch.Event() + store_user_object_weakrefs(s0, s1, e0, e1) + + sample_inputs = [ + (2, 0), + (3, 1), + ] + for args in sample_inputs: + opcheck(wait_event, args) + opcheck(record_event, args) + finally: + torch.accelerator.set_stream(original_stream) + reset_user_object_tracking() + + def test_is_marked_side_effectful(self): + self.assertIn( + torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions + ) + self.assertIn( + torch.ops.streams.join.default, torch.fx.node._side_effectful_functions + ) + self.assertIn( + torch.ops.streams.wait_event.default, + torch.fx.node._side_effectful_functions, + ) + self.assertIn( + torch.ops.streams.record_event.default, + torch.fx.node._side_effectful_functions, + ) if __name__ == "__main__": diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 39a0dc628bae..5d31fa28880a 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -4036,7 +4036,7 @@ def backend(gm, args): @parametrize( "nt_view_name", - [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"], + [k for k in VIEW_TEST_CASES if k != "subclass_dense_subclass_dense"], ) def test_inputs_to_compiled_fn_are_views(self, nt_view_name): self._input_view_test(nt_view_name) diff --git a/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory b/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron b/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977e..12f6ba2228db 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/export/test_export.py b/test/export/test_export.py index 3908f03b11e5..cdc18b1d4c56 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,26 +6093,19 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) - # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_tensorsplit(torch.nn.Module): @@ -6166,7 +6159,12 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + if y.item() < 0: + return ( + torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() + ) + else: + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6196,7 +6194,18 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + if box.content < 0: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) + else: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) with self.assertRaisesRegex( error_type, diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py index 7949d2bb46cb..13277fccaea1 100644 --- a/test/functorch/test_aot_joint_with_descriptors.py +++ b/test/functorch/test_aot_joint_with_descriptors.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.utils._pytree as pytree from torch._decomp import decomposition_table -from torch._dynamo.functional_export import _dynamo_graph_capture_for_export +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.testing import normalize_gm from torch._functorch._aot_autograd.descriptors import ( BufferAOTInput, @@ -48,17 +48,13 @@ def graph_capture(model, inputs, with_export): gm = model - fake_mode = None + tracing_context = None if with_export: - with ( - torch._dynamo.config.patch(install_free_tensors=True), - fx_traceback.preserve_node_meta(), - ): - # TODO: switch to use the official graph_capture API once it is ready - gm = _dynamo_graph_capture_for_export(model)(*inputs) - fake_mode = gm.meta.get("fake_mode", None) - - with tracing(TracingContext(fake_mode)): + with fx_traceback.preserve_node_meta(): + gm = dynamo_graph_capture_for_export(model)(*inputs) + tracing_context = gm.meta.get("tracing_context", None) + + with tracing(tracing_context): with ExitStack() as stack: joint_with_descriptors = aot_export_joint_with_descriptors( stack, @@ -325,7 +321,7 @@ def forward(self, x, *, scale): inputs = (torch.randn(4, 3),) kwargs = {"scale": torch.tensor(2.0)} - gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs) + gm = dynamo_graph_capture_for_export(model)(*inputs, **kwargs) with ExitStack() as stack: # Export joint with descriptors @@ -356,8 +352,8 @@ def forward( primals, tangents, ): - primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight') - primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias') + primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight') + primals_2: "f32[2]" # ParamAOTInput(target='linear.bias') primals_3: "f32[4, 3]" # PlainAOTInput(idx=0) primals_4: "f32[]" # PlainAOTInput(idx=1) tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0)) @@ -379,8 +375,8 @@ def forward( transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None return pytree.tree_unflatten([ mul_2, # PlainAOTOutput(idx=0) - transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight')) - as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias')) + transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight')) + as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias')) None, # None None, # None ], self._out_spec)""", @@ -1063,9 +1059,11 @@ def forward(self, x): str(custom_metadata), """\ ('call_function', 'new_empty', {'pp_stage': 0}) +('get_attr', '_tensor_constant0', {'pp_stage': 0}) ('call_function', 'index_put', {'pp_stage': 0}) ('call_function', 'slice_2', {'pp_stage': 0}) ('call_function', 'slice_backward', {'pp_stage': 0}) +('get_attr', '_tensor_constant0_1', {'pp_stage': 0}) ('call_function', 'index', {'pp_stage': 0})""", ) @@ -1082,7 +1080,7 @@ def forward(self, x): model = SimpleLinear() inputs = (torch.randn(4, 3),) - gm = _dynamo_graph_capture_for_export(model)(*inputs) + gm = dynamo_graph_capture_for_export(model)(*inputs) fake_mode = gm.meta.get("fake_mode", None) with tracing(TracingContext(fake_mode)): diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index fba7a96288ca..6cae42d8929d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -167,6 +167,14 @@ def _pack_fp8_wrap(x): if not x.dtype.is_floating_point: return x + if type(x) is not torch.Tensor: + # Check only during compilation + # Test calls hooks to get reference output + ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context() + assert ctx["_fw_graph"] is not None + assert ctx["_bw_graph"] is not None + assert ctx["_node"] is not None + return (x.dtype, x.to(torch.float8_e5m2)) @@ -176,6 +184,13 @@ def _unpack_fp8_wrap(x): return x dtype, tensor = x + if type(tensor) is not torch.Tensor: + # Check only during compilation + # Test calls hooks to get reference output + ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context() + assert ctx["_fw_graph"] is not None + assert ctx["_bw_graph"] is not None + assert ctx["_node"] is not None return tensor.to(dtype) @@ -8111,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - xfail("narrow"), + skip("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 5034661fa3e0..f83f05966314 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -942,9 +942,7 @@ def false_fn(x): b = torch.randn(4, requires_grad=True) c = torch.randn(4, requires_grad=True) - for pred, fn in zip( - [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] - ): + for pred in [torch.tensor(False), torch.tensor(True)]: with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", @@ -3066,13 +3064,9 @@ def run_test_and_get_grads_loss(model, initial_hs, inputs): ).to(DEVICE) # Test 3 models: RNNScanList, RNNScanTensor, RNNLoop - models = [ - ("ScanList", RNNScanList), - ("ScanTensor", RNNScanTensor), - ("Loop", RNNLoop), - ] + models = [RNNScanList, RNNScanTensor, RNNLoop] - for model_name, model_class in models: + for model_class in models: # Create uncompiled model model_uc = model_class().to(DEVICE) uncompiled_grads, uncompiled_loss = run_test_and_get_grads_loss( @@ -7538,7 +7532,7 @@ def foo(x): inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3)) for inp in inps: - gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4)) + gm = make_fx(foo, tracing_mode="symbolic")(inp) self.assertExpectedInline( gm.code.strip(), """\ diff --git a/test/functorch/xfail_suggester.py b/test/functorch/xfail_suggester.py index cab6b018d578..8efd8dfe398f 100644 --- a/test/functorch/xfail_suggester.py +++ b/test/functorch/xfail_suggester.py @@ -73,7 +73,7 @@ def parse_namespace(base): "sparse_": "sparse", "special_": "special", } - for heading in mappings.keys(): + for heading in mappings: if base.startswith(heading): return mappings[heading], base[len(heading) :] return None, base diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 5f37d8e1768d..fbb21633260e 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,8 +4,9 @@ import functools import unittest +from collections.abc import Callable from contextlib import contextmanager, ExitStack -from typing import Any, Callable, Optional +from typing import Any, Optional import torch import torch._dynamo @@ -15,6 +16,7 @@ import torch.fx.traceback as fx_traceback import torch.nn.functional as F from torch import nn +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable from torch._functorch.aot_autograd import aot_export_joint_with_descriptors from torch._subclasses.fake_tensor import FakeTensorMode @@ -51,24 +53,6 @@ def enable_local_map_wrapping(): yield -def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module: - from torch._dynamo.functional_export import _dynamo_graph_capture_for_export - from torch.export._trace import _restore_state_dict - - """ - Thin wrapper around graph capture output that restores the - original calling convention and attribute fqn. TODO: - 1) Use bytecode for calling convention instead of pytree for more - seamless UX. - 2) Attach guards - 3) Be more careful about tensor constants names. - """ - with torch._dynamo.config.patch(install_free_tensors=True): - gm = _dynamo_graph_capture_for_export(model)(*inputs) - _restore_state_dict(model, gm) - return gm - - def ap_style_initial_capture( model: torch.nn.Module, inputs_fn: Callable ) -> torch.nn.Module: @@ -90,7 +74,7 @@ def ap_style_initial_capture( enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), ): - torch_ir_with_fqn = _export(model, inputs) + torch_ir_with_fqn = dynamo_graph_capture_for_export(model)(*inputs) unused = ExitStack() joint_with_descriptors = aot_export_joint_with_descriptors( unused, diff --git a/test/higher_order_ops/test_print.py b/test/higher_order_ops/test_print.py new file mode 100644 index 000000000000..aef538854864 --- /dev/null +++ b/test/higher_order_ops/test_print.py @@ -0,0 +1,44 @@ +# Owner(s): ["module: higher order operators"] +import io +from unittest.mock import patch + +import torch +from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestHopPrint(TestCase): + def test_base_print(self): + def f(x): + x = x + x + torch._higher_order_ops.print("moo") + x = x * x + torch._higher_order_ops.print("moo") + return x + + counters.clear() + x = torch.randn(3, 3) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "moo\nmoo") + + def test_para_print(self): + def f(x): + x = x + x + torch._higher_order_ops.print("moo {x} {y}", x=1, y=2) + x = x * x + return x + + counters.clear() + x = torch.randn(3, 3) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "moo 1 2") + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index bcb66beea700..aa4c3a1f229f 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, Generator, Sequence, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,6 +37,7 @@ if TYPE_CHECKING: + from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 50a389e8663f..7237d5a01c6b 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -4,6 +4,7 @@ import tempfile from threading import Event +import torch._inductor.config as config from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, SubprocException, @@ -16,9 +17,12 @@ class TestCompileWorker(TestCase): + def make_pool(self, size): + return SubprocPool(size) + @skipIfWindows(msg="pass_fds not supported on Windows.") def test_basic_jobs(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(operator.add, 100, 1) b = pool.submit(operator.sub, 100, 1) @@ -29,7 +33,7 @@ def test_basic_jobs(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_exception(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(raise_testexc) with self.assertRaisesRegex( @@ -42,7 +46,7 @@ def test_exception(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_crash(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: with self.assertRaises(Exception): a = pool.submit(os._exit, 1) @@ -58,7 +62,7 @@ def test_crash(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_quiesce(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(operator.add, 100, 1) pool.quiesce() @@ -75,7 +79,7 @@ def test_logging(self): os.environ["ROLE_RANK"] = "0" with tempfile.NamedTemporaryFile(delete=True) as temp_log: os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name - pool = SubprocPool(2) + pool = self.make_pool(2) try: pool.submit(operator.add, 100, 1) self.assertEqual(os.path.exists(temp_log.name), True) @@ -83,6 +87,12 @@ def test_logging(self): pool.shutdown() +@config.patch("quiesce_async_compile_time", 0.1) +class TestCompileWorkerWithTimer(TestCompileWorker): + def make_pool(self, size): + return SubprocPool(size, quiesce=True) + + class TestTimer(TestCase): def test_basics(self): done = Event() diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index df93e7e1e4d6..ebee5149476b 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -320,7 +320,7 @@ def build_opt_kwarg_db(): continue if has_tensor_lr: - for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys(): + for scheduler_cls in LR_SCHEDULER_TO_KWARGS: name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}" compiled_opt_db.append( ( diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 4e1c48496ebc..ca520ab66bcc 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -2697,6 +2697,32 @@ def forward(self, x): self.common(mod, (u,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("bs", (5,)) + @parametrize("Mdim", (16,)) + @parametrize("Kdim", (32,)) + @parametrize("Ndim", (64,)) + @dtypes(torch.float) + def test_bmm_with_broadcasted_mat1(self, bs, Mdim, Kdim, Ndim, dtype): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, w): + assert x.dim() == 2, f"Expected x to be 2D, got {x.dim()}D" + x_expanded = x.unsqueeze(0).expand(bs, -1, -1) + return x_expanded @ w + + counters.clear() + u = torch.randn(Mdim, Kdim).to(dtype=dtype) + v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype) + mod = M().to(dtype=dtype).eval() + with verify(dtype) as (atol, rtol): + self.common(mod, (u, v), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + @patches @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index adc46a0f390a..c148c6946890 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -216,115 +216,6 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8): test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}" ) - @skipIfXpu - def test_mlp_custom_op_autotune(self): - """Test MLP autotuning with method parameter controlling different decomposition variants. - - Validates parametric tuning where the same decomposition function uses different - algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights). - """ - test_op_name = f"test_lib::mlp_{id(self)}" - - def mlp_variants( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ) -> torch.Tensor: - """MLP implementation with different computational approaches controlled by method parameter.""" - - if method == 0: - gate_proj = torch.matmul(input_tensor, gate_weight) - up_proj = torch.matmul(input_tensor, up_weight) - gated = torch.relu(gate_proj) * up_proj - return torch.matmul(gated, down_weight) - - elif method == 1: - batch_shape = input_tensor.shape[:-1] - hidden_dim = input_tensor.shape[-1] - output_dim = down_weight.shape[-1] - - input_2d = input_tensor.view(-1, hidden_dim) - - gate_proj = torch.mm(input_2d, gate_weight) - up_proj = torch.mm(input_2d, up_weight) - - gated = torch.relu(gate_proj) * up_proj - output_2d = torch.mm(gated, down_weight) - - return output_2d.view(*batch_shape, output_dim) - - @torch.library.custom_op(test_op_name, mutates_args=()) - def test_mlp_op( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ) -> torch.Tensor: - return mlp_variants( - input_tensor, gate_weight, up_weight, down_weight, method=method - ) - - @test_mlp_op.register_fake - def _( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ): - return torch.empty( - input_tensor.shape[:-1] + (down_weight.shape[-1],), - device=input_tensor.device, - dtype=input_tensor.dtype, - ) - - # Use explicit config with method parameter as tuning knob - register_custom_op_autotuning( - test_mlp_op, - configs=[ - CustomOpConfig(method=0), - CustomOpConfig(method=1), - ], - name="test_mlp_autotuned", - input_gen_fns={ - "input_tensor": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.1, - "gate_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - "up_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - "down_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - }, - ) - - # Create test inputs - input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs() - - # Test that all method variants produce numerically equivalent results - expected = mlp_variants( - input_tensor, gate_weight, up_weight, down_weight, method=0 - ) - - # Test autotuning - self._run_autotune_test( - test_mlp_op, - (input_tensor, gate_weight, up_weight, down_weight), - expected, - "MLP", - ) - def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values.""" # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] @@ -335,12 +226,12 @@ def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): @skipIfXpu def test_decompose_k_custom_op_autotune(self): - """Test decompose_k autotuning with parametric tuning for k_splits values. + """Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale). - Validates numerical parameter sweep where k_splits controls how the K dimension - is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]). + Validates that the custom op encapsulates the entire fused operation with parametric + tuning for k_splits values controlling how the K dimension is decomposed. """ - test_op_name = f"test_lib::decompose_k_{id(self)}" + test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}" def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 @@ -363,19 +254,23 @@ def decompose_k_implementation( return torch.sum(result, dim=0) # [m, n] @torch.library.custom_op(test_op_name, mutates_args=()) - def test_decompose_k_op( - a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 + def matmul_relu_epilogue_op( + a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: - """Matrix multiply with k-way decomposition - custom op using the decomposition.""" - return decompose_k_implementation(a, b, k_splits) - - @test_decompose_k_op.register_fake - def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): + """Matmul with decompose_k + bias + relu + scale (complete epilogue fusion).""" + matmul_result = decompose_k_implementation(a, b, k_splits) + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled + + @matmul_relu_epilogue_op.register_fake + def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype) - # Register autotuning with different k_splits values using decomposition function + # Register autotuning with different k_splits values register_custom_op_autotuning( - test_decompose_k_op, + matmul_relu_epilogue_op, configs=[ CustomOpConfig(k_splits=2), CustomOpConfig(k_splits=4), @@ -385,7 +280,7 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): CustomOpConfig(k_splits=64), CustomOpConfig(k_splits=128), ], - name="test_decompose_k_autotuned", + name="matmul_relu_epilogue_autotuned", input_gen_fns={ "a": lambda fake_tensor: torch.randn_like( fake_tensor, device=self.device @@ -395,12 +290,45 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): fake_tensor, device=self.device ) * 0.1, + "bias": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, }, ) + # Create test inputs a, b = self._create_decompose_k_inputs() - expected = a @ b - self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK") + bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1 + + # Compile the model using the custom op + @torch.compile + def test_model(a, b, bias): + return matmul_relu_epilogue_op(a, b, bias) + + torch._dynamo.reset() + + with config.patch( + max_autotune=True, + benchmark_fusion=True, + ): + compiled_result = test_model(a, b, bias) + + def reference_model(a, b, bias): + matmul_result = a @ b + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled + + expected = reference_model(a, b, bias) + + torch.testing.assert_close( + compiled_result, + expected, + rtol=2e-1, + atol=5e-1, + ) @skipIfXpu def test_multi_parameter_tuning(self): diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py deleted file mode 100644 index c26def3a5409..000000000000 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ /dev/null @@ -1,154 +0,0 @@ -# Owner(s): ["module: inductor"] - - -import unittest - -import torch -from torch import Tensor -from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch -from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch._inductor.utils import ensure_cute_available -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -@unittest.skipIf( - not (ensure_cute_available() and is_datacenter_blackwell_arch()), - "CuTeDSL library or Blackwell device not available", -) -@instantiate_parametrized_tests -class TestCuTeDSLGroupedGemm(InductorTestCase): - def _get_inputs( - self, - group_size: int, - M_hint: int, - K: int, - N: int, - device: str, - dtype: torch.dtype, - alignment: int = 16, - ) -> tuple[Tensor, Tensor, Tensor]: - # --- Random, tile-aligned M sizes --- - M_sizes = ( - torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) - * alignment - ) - - M_total = torch.sum(M_sizes).item() - - # --- Construct input tensors --- - A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 - B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 - - # --- Build offsets (no leading zero, strictly increasing) --- - offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) - - return (A, B, offsets) - - @parametrize("group_size", (2, 8)) - @parametrize("M_hint", (256, 1024)) - @parametrize("K", (64, 128)) - @parametrize("N", (128, 256)) - def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): - device = "cuda" - dtype = torch.bfloat16 - - A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # Eager execution - c_eager = grouped_gemm_fn(A, B, offsets) - - # Test with Cute backend - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) - @parametrize("layout_B", ("contiguous", "broadcasted")) - def test_grouped_gemm_assorted_layouts( - self, - layout_A: str, - layout_B: str, - ): - device = "cuda" - dtype = torch.bfloat16 - - G, K, N = 8, 64, 128 - M_sizes = [128] * G - sum_M = sum(M_sizes) - offsets = torch.tensor( - [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device - ) - - A_base = torch.randn(sum_M, K, device=device, dtype=dtype) - A = A_base - - if layout_A == "offset": - # allocate bigger buffer than needed, use nonzero storage offset - storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) - offset = 128 # skip first 128 elements - A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) - elif layout_A == "padded": - # simulate row pitch > K (row_stride = K + pad) - row_pitch = K + 8 - storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) - A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) - elif layout_A == "view": - A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) - A = A_storage.view(sum_M, K) - assert A._base is not None - assert A.shape == (sum_M, K) - - B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 - - if layout_B == "broadcasted": - # Broadcast B across groups (zero stride along G) - B = B[0].expand(G, K, N) - assert B.stride(0) == 0 - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # --- eager --- - c_eager = grouped_gemm_fn(A, B, offsets) - - # --- compiled (CUTE backend) --- - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - -if __name__ == "__main__": - run_tests() diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a1e5aa3cebc4..816d3b93ecfe 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5807,11 +5807,11 @@ def causal_mask(b, h, q_idx, kv_idx): from torch.utils._pytree import GetAttrKey - for key, tensor in tensors_with_keys: + for key, _tensor in tensors_with_keys: self.assertIsInstance(key, GetAttrKey) self.assertIsNotNone(key) - for key, value in context_with_keys: + for key, _value in context_with_keys: self.assertIsInstance(key, GetAttrKey) self.assertIsNotNone(key) diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index f26a2347e4e8..f1067b8ffebb 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -6,6 +6,7 @@ import torch from torch import Tensor +from torch._C import FileCheck from torch._inductor import config, utils from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.test_case import run_tests, TestCase @@ -29,7 +30,6 @@ HAS_CPU, HAS_CUDA_AND_TRITON, ) -from torch.testing._internal.jit_utils import FileCheck from torch.utils._triton import has_triton_tma_device @@ -953,6 +953,240 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @torch._inductor.config.patch("emulate_precision_casts", True) + def test_mx_fusion(self): + # Register fake_scaled_mm custom op scoped to this test + with torch.library._scoped_library("test_fp8", "FRAGMENT") as lib: + # Define the op schema + lib.define( + "fake_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, " + "Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, " + "bool use_fast_accum=False) -> Tensor" + ) + input_values = [] + + # Register CUDA implementation + @torch.library.impl(lib, "fake_scaled_mm", "CUDA") + def fake_scaled_mm_impl( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + ): + """Software-emulated scaled_mm for testing without CUDA 12.8""" + out_dtype = out_dtype or torch.bfloat16 + # just using add, because without real dtypes, + # was seeing overflow/instability + nonlocal input_values + input_values.append((mat_a, mat_b, scale_a, scale_b)) + result = mat_a.to(torch.float32) + mat_b.to(torch.float32) + if bias is not None: + result = result + bias.to(torch.float32) + return result.to(out_dtype) + + # Register fake implementation + @torch.library.impl(lib, "fake_scaled_mm", "Meta") + def fake_scaled_mm_meta( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + ): + """FakeTensor implementation""" + out_dtype = out_dtype or torch.bfloat16 + M, K = mat_a.shape + K2, N = mat_b.shape + torch._check( + K == K2, + lambda: f"Incompatible shapes: {mat_a.shape} @ {mat_b.shape}", + ) + return torch.empty((M, N), dtype=out_dtype, device=mat_a.device) + + def forward( + arg0_1, + arg1_1, + ): + view = torch.ops.aten.reshape.default(arg0_1, [8192, 256, 32]) + abs_1 = torch.ops.aten.abs.default(view) + amax = torch.ops.aten.amax.default(abs_1, [-1]) + unsqueeze = torch.ops.aten.unsqueeze.default(amax, -1) + view_1 = torch.ops.aten.view.dtype(unsqueeze, torch.int32) + bitwise_right_shift = torch.ops.aten.bitwise_right_shift.Tensor_Scalar( + view_1, 23 + ) + bitwise_and = torch.ops.aten.bitwise_and.Scalar( + bitwise_right_shift, 255 + ) + sub = torch.ops.aten.sub.Tensor(bitwise_and, 127) + sub_1 = torch.ops.aten.sub.Tensor(sub, 8) + clamp_min = torch.ops.aten.clamp_min.default(sub_1, -127) + clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 128) + add = torch.ops.aten.add.Tensor(clamp_max, 127) + convert_element_type = torch.ops.prims.convert_element_type.default( + add, torch.uint8 + ) + isnan = torch.ops.aten.isnan.default(unsqueeze) + scalar_tensor = torch.ops.aten.scalar_tensor.default( + 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + ) + where = torch.ops.aten.where.self( + isnan, scalar_tensor, convert_element_type + ) + convert_element_type_1 = torch.ops.prims.convert_element_type.default( + where, torch.int32 + ) + bitwise_left_shift = torch.ops.aten.bitwise_left_shift.Tensor_Scalar( + convert_element_type_1, 23 + ) + view_2 = torch.ops.aten.view.dtype(bitwise_left_shift, torch.float32) + clamp_min_1 = torch.ops.aten.clamp_min.default( + view_2, 1.1754943508222875e-38 + ) + div = torch.ops.aten.div.Tensor(view, clamp_min_1) + clamp_min_2 = torch.ops.aten.clamp_min.default(div, -448.0) + clamp_max_1 = torch.ops.aten.clamp_max.default(clamp_min_2, 448.0) + convert_element_type_2 = torch.ops.prims.convert_element_type.default( + clamp_max_1, torch.float8_e4m3fn + ) + view_3 = torch.ops.aten.reshape.default( + convert_element_type_2, [8192, 8192] + ) + convert_element_type_2 = None + view_4 = torch.ops.aten.view.dtype(where, torch.float8_e8m0fnu) + squeeze = torch.ops.aten.squeeze.dim(view_4, -1) + + view_5 = torch.ops.aten.reshape.default(arg1_1, [8192, 256, 32]) + abs_2 = torch.ops.aten.abs.default(view_5) + amax_1 = torch.ops.aten.amax.default(abs_2, [-1]) + unsqueeze_1 = torch.ops.aten.unsqueeze.default(amax_1, -1) + view_6 = torch.ops.aten.view.dtype(unsqueeze_1, torch.int32) + bitwise_right_shift_1 = ( + torch.ops.aten.bitwise_right_shift.Tensor_Scalar(view_6, 23) + ) + bitwise_and_1 = torch.ops.aten.bitwise_and.Scalar( + bitwise_right_shift_1, 255 + ) + sub_2 = torch.ops.aten.sub.Tensor(bitwise_and_1, 127) + sub_3 = torch.ops.aten.sub.Tensor(sub_2, 8) + clamp_min_3 = torch.ops.aten.clamp_min.default(sub_3, -127) + clamp_max_2 = torch.ops.aten.clamp_max.default(clamp_min_3, 128) + add_1 = torch.ops.aten.add.Tensor(clamp_max_2, 127) + convert_element_type_3 = torch.ops.prims.convert_element_type.default( + add_1, torch.uint8 + ) + isnan_1 = torch.ops.aten.isnan.default(unsqueeze_1) + unsqueeze_1 = None + scalar_tensor_1 = torch.ops.aten.scalar_tensor.default( + 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + ) + where_1 = torch.ops.aten.where.self( + isnan_1, scalar_tensor_1, convert_element_type_3 + ) + convert_element_type_4 = torch.ops.prims.convert_element_type.default( + where_1, torch.int32 + ) + bitwise_left_shift_1 = torch.ops.aten.bitwise_left_shift.Tensor_Scalar( + convert_element_type_4, 23 + ) + convert_element_type_4 = None + view_7 = torch.ops.aten.view.dtype(bitwise_left_shift_1, torch.float32) + bitwise_left_shift_1 = None + clamp_min_4 = torch.ops.aten.clamp_min.default( + view_7, 1.1754943508222875e-38 + ) + div_1 = torch.ops.aten.div.Tensor(view_5, clamp_min_4) + clamp_min_5 = torch.ops.aten.clamp_min.default(div_1, -448.0) + clamp_max_3 = torch.ops.aten.clamp_max.default(clamp_min_5, 448.0) + convert_element_type_5 = torch.ops.prims.convert_element_type.default( + clamp_max_3, torch.float8_e4m3fn + ) + view_8 = torch.ops.aten.reshape.default( + convert_element_type_5, [8192, 8192] + ) + view_9 = torch.ops.aten.view.dtype(where_1, torch.float8_e8m0fnu) + squeeze_1 = torch.ops.aten.squeeze.dim(view_9, -1) + + permute = torch.ops.aten.permute.default(view_8, [1, 0]) + + view_13 = torch.ops.aten.reshape.default(squeeze, [64, 128, 64, 4]) + permute_2 = torch.ops.aten.permute.default(view_13, [0, 2, 1, 3]) + clone = torch.ops.aten.clone.default( + permute_2, memory_format=torch.contiguous_format + ) + view_14 = torch.ops.aten.reshape.default(clone, [4096, 4, 32, 4]) + permute_3 = torch.ops.aten.permute.default(view_14, [0, 2, 1, 3]) + clone_1 = torch.ops.aten.clone.default( + permute_3, memory_format=torch.contiguous_format + ) + view_15 = torch.ops.aten.reshape.default(clone_1, [4096, 32, 16]) + + view_16 = torch.ops.aten.reshape.default(view_15, [2097152]) + + view_18 = torch.ops.aten.reshape.default(squeeze_1, [64, 128, 64, 4]) + permute_5 = torch.ops.aten.permute.default(view_18, [0, 2, 1, 3]) + clone_2 = torch.ops.aten.clone.default( + permute_5, memory_format=torch.contiguous_format + ) + view_19 = torch.ops.aten.reshape.default(clone_2, [4096, 4, 32, 4]) + permute_6 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]) + clone_3 = torch.ops.aten.clone.default( + permute_6, memory_format=torch.contiguous_format + ) + view_20 = torch.ops.aten.reshape.default(clone_3, [4096, 32, 16]) + + view_21 = torch.ops.aten.reshape.default(view_20, [2097152]) + + _scaled_mm = torch.ops.test_fp8.fake_scaled_mm.default( + view_3, permute, view_16, view_21, None, None, torch.float32 + ) + return (_scaled_mm,) + + # Run with largest shape + M, K, N = 8192, 8192, 8192 + device = "cuda" + + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + f_c = torch.compile(fullgraph=True)(forward) + + _, code = run_and_get_code(f_c, A, B) + + FileCheck().check(".run(").check(".run(").check("fake_scaled_mm").run( + code[0] + ) + + for seed in range(5): + input_values.clear() + torch.manual_seed(seed) + # without dividing, outputs get way too large + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + + # Uses fake_scaled_mm custom op (no CUDA 12.8 needed!) + torch._dynamo.reset() + torch.compile(forward)(A, B) + + torch._dynamo.reset() + with config.patch({"loop_index_inversion_in_fusion": False}): + torch.compile(forward)(A, B) + + assert len(input_values) == 2 + for i in range(4): + self.assertEqual( + input_values[0][i], + input_values[1][i], + msg=f"idx {i} seed {seed}", + ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index ebe98373e622..63342502d3cd 100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index c77b3574b222..051a5f590599 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -16,6 +16,7 @@ from torch._inductor import config as inductor_config, ir, metrics from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.graph import GraphLowering +from torch._inductor.invert_expr_analysis import generate_inverse_formula from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_operators import realize @@ -1188,6 +1189,113 @@ def fn(nodes): torch.compile(f)(x) +class TestIndexInversion(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + gm = torch.fx.symbolic_trace(lambda: 0) + graph = GraphLowering(gm) + graph.scheduler = MockScheduler + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context(V.set_graph_handler(graph)) + + def _check_expr(self, expr, reconstruction, val_range): + import numpy as np + from sympy import lambdify + + assert len(expr.free_symbols) == 1 + p0 = next(iter(expr.free_symbols)) + + def floordiv_replacement(a, b): + """Replace FloorDiv(a, b) with a // b""" + return a // b + + def modularindexing_replacement(x, base, divisor): + """Replace ModularIndexing(x, base, divisor) with (x // base) % divisor""" + return (x // base) % divisor + + # Replace custom functions with sympy equivalents + expr_numpy_ready = expr.replace(FloorDiv, floordiv_replacement).replace( + ModularIndexing, modularindexing_replacement + ) + reconstruction_numpy_ready = reconstruction.replace( + FloorDiv, floordiv_replacement + ).replace(ModularIndexing, modularindexing_replacement) + + # Now lambdify with standard numpy + forward_func = lambdify(p0, expr_numpy_ready, modules="numpy") + inverse_func = lambdify(p0, reconstruction_numpy_ready, modules="numpy") + + test_values = np.arange(0, val_range, dtype=np.int64) + forward_values = forward_func(test_values).astype(np.int64) + recovered_values = inverse_func(forward_values).astype(np.int64) + torch.testing.assert_close(test_values, recovered_values) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._exit_stack.close() + + def test_original_complex_expression(self): + """Test the original motivating complex expression.""" + p0 = sympy.Symbol("p0") + expr = ( + 32768 * FloorDiv(p0, 32768) + + 8192 * FloorDiv(ModularIndexing(p0, 1, 16), 4) + + ModularIndexing(p0, 1, 4) + + 256 * ModularIndexing(p0, 16, 32) + + 4 * ModularIndexing(p0, 512, 64) + ) + + reconstruction = generate_inverse_formula(expr, p0) + self.assertIsNotNone(reconstruction) + self._check_expr(expr, reconstruction, 2097152) + + def test_inversion_cases(self): + """Test various expressions for correct inversion behavior.""" + p = sympy.Symbol("p") + + cases = [ + # (expression, should_be_invertible, test_range) + # Simple 2-term base-10 style: 10 = 1 Γ— 10 βœ“ + (10 * ModularIndexing(p, 10, 10) + ModularIndexing(p, 1, 10), True, 100), + # Simple 2-term base-2 style: 2 = 1 Γ— 2 βœ“ + (2 * ModularIndexing(p, 2, 2) + ModularIndexing(p, 1, 2), True, 4), + # 3-term decimal: 100 = 10Γ—10, 10 = 1Γ—10 βœ“ + ( + 100 * FloorDiv(p, 100) + + 10 * FloorDiv(ModularIndexing(p, 1, 100), 10) + + ModularIndexing(p, 1, 10), + True, + 1000, + ), + (4 * p, False, 64), # expr and inverse not bijections + # when sorted, invertible + (ModularIndexing(p, 1, 10) + 10 * ModularIndexing(p, 10, 10), True, None), + # Wrong coefficient ratios: 4 β‰  1Γ—2 + (4 * ModularIndexing(p, 1, 8) + ModularIndexing(p, 8, 2), False, None), + ( + 100 * FloorDiv(p, 100) + 7 * ModularIndexing(p, 1, 100), + False, + None, + ), # Wrong ratios + (FloorDiv(p, 100) + FloorDiv(p, 10) + p, False, None), # Overlapping ranges + (p**2 + 10 * p + 1, False, None), # Quadratic + (sympy.sin(p) + sympy.cos(p), False, None), # Trigonometric + ] + + for expr, should_invert, test_range in cases: + reconstruction = generate_inverse_formula(expr, p) + + if should_invert: + self.assertIsNotNone(reconstruction, f"Expected invertible: {expr}") + # Test correctness on sample values + self._check_expr(expr, reconstruction, test_range) + else: + self.assertIsNone(reconstruction, f"Expected non-invertible: {expr}") + + if __name__ == "__main__": if HAS_GPU: run_tests() diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 230a2514b917..0dcc37ee359d 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -117,6 +117,22 @@ def outer_red(): metrics.codegen_mix_order_reduction, ) + @inductor_config.patch(coordinate_descent_tuning=True) + def test_XBLOCK_coordest_tuning(self): + """ + We should skip XBLOCK coordinate descent tuning for + mix order reduction. + """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x): + return x.sum(dim=-1), x.sum(dim=0) + + x = torch.randn(32768, 256, dtype=torch.float, device=GPU_TYPE) + self.check_numeric(f, (x,)) + self.assertEqual(metrics.codegen_mix_order_reduction, 1) + @inductor_config.patch(unroll_reductions_threshold=1) def test_3layer_split_reduction(self): """ diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index 1870a0e373be..c37f844e41ea 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from typing import Callable +from collections.abc import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c67bde87a369..5e599110d29d 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -500,8 +500,13 @@ def test_LinearAndSoftmax_codegen(self, bias=True): forward_wrapper = wrapper_codes[0] # make sure the load for softmax is aligned + if bias: + # addmm -> mm + bias and bias is fused with softmax + softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)" + else: + softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)" self.assertTrue( - "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper, + softmax_load_str in forward_wrapper, f"forward_wrapper: {forward_wrapper}", ) diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py new file mode 100644 index 000000000000..2d4e6af002ab --- /dev/null +++ b/test/inductor/test_pallas.py @@ -0,0 +1,354 @@ +# Owner(s): ["oncall: pt2"] +import functools +import sys +import unittest + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.testing import make_test_cls_with_patches +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS +from torch.testing._internal.inductor_utils import HAS_PALLAS +from torch.utils._triton import has_triton + + +if IS_WINDOWS and IS_CI: + sys.stderr.write( + "Windows CI does not have necessary dependencies for test_torchinductor yet\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + + +try: + from . import test_torchinductor +except ImportError: + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library + + +test_classes = {} + + +def make_pallas(cls): + """Create a test class variant that uses Pallas backend.""" + suffix = "_pallas" + cls_prefix = "Pallas" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "cuda_backend", "pallas"), + xfail_prop="_expected_failure_pallas", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +class PallasTests(TestCase): + """Basic tests for Pallas backend functionality.""" + + def test_simple_add(self): + """Test basic element-wise addition.""" + + def fn(a, b): + return a + b + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + def test_simple_mul(self): + """Test basic element-wise multiplication.""" + + def fn(a, b): + return a * b + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + def test_sin(self): + """Test sin operation.""" + + def fn(x): + return torch.sin(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_fused_ops(self): + """Test fused operations (sin + add).""" + + def fn(x, y): + return x.sin() + y + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + y = torch.randn(1024, device="cuda") + result = compiled(x, y) + expected = fn(x, y) + self.assertEqual(result, expected) + + def test_exp_log(self): + """Test exp and log operations.""" + + def fn(x): + return torch.log(torch.exp(x)) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_sqrt(self): + """Test sqrt operation.""" + + def fn(x): + return torch.sqrt(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_tanh(self): + """Test tanh operation.""" + + def fn(x): + return torch.tanh(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_abs_neg(self): + """Test abs and neg operations.""" + + def fn(x): + return torch.abs(-x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_maximum_minimum(self): + """Test maximum and minimum operations.""" + + def fn(a, b): + return torch.maximum(a, b) + torch.minimum(a, b) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + @unittest.skipUnless(has_triton(), "requires triton") + @unittest.skip("Random ops not yet implemented in Pallas backend") + def test_random_consistency(self): + """Test that random number generation is consistent across backends.""" + seed = 1234 + shape = (3, 3) + dtype = torch.float32 + + for rand_fn in [ + functools.partial(torch.rand, shape, dtype=dtype, device="cuda"), + functools.partial(torch.randn, shape, dtype=dtype, device="cuda"), + ]: + + @torch.compile(backend="inductor", options={"cuda_backend": "pallas"}) + def get_rand_pallas(): + return rand_fn() + + @torch.compile(backend="inductor", options={"cuda_backend": "triton"}) + def get_rand_triton(): + return rand_fn() + + torch.manual_seed(seed) + pallas_output = get_rand_pallas() + torch.manual_seed(seed) + triton_output = get_rand_triton() + + self.assertEqual(pallas_output, triton_output) + + def test_compile_options(self): + """Test that Pallas backend is properly configured.""" + + @torch.compile( + backend="inductor", + options={"cuda_backend": "pallas"}, + ) + def pallas_fn(a, b): + return a.sin() + b.cos() + + _, (code,) = run_and_get_code( + pallas_fn, + torch.randn(64, device="cuda"), + torch.randn(64, device="cuda"), + ) + # Verify Pallas-specific code generation + self.assertIn("import jax", code) + self.assertIn("import jax.numpy as jnp", code) + self.assertIn("from jax.experimental import pallas as pl", code) + + def test_2d_tensor(self): + """Test with 2D tensors (though current implementation flattens).""" + + def fn(x, y): + return x + y + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(32, 32, device="cuda") + y = torch.randn(32, 32, device="cuda") + result = compiled(x, y) + expected = fn(x, y) + self.assertEqual(result, expected) + + def test_different_shapes(self): + """Test with different tensor shapes.""" + + def fn(x): + return x * 2.0 + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + for shape in [(64,), (128,), (256,), (1024,)]: + x = torch.randn(shape, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_contiguous_index_validation(self): + """Test that contiguous index validation works correctly end-to-end.""" + + # Test 1: Contiguous operations should work + def contiguous_add(a, b): + return a + b + + compiled = torch.compile( + contiguous_add, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = contiguous_add(a, b) + self.assertEqual(result, expected) + + # Test 2: Operations on contiguous tensors should work + def contiguous_mul(x): + return x * 2.0 + + compiled = torch.compile( + contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(128, 8, device="cuda") + result = compiled(x) + expected = contiguous_mul(x) + self.assertEqual(result, expected) + + # Test 3: Non-contiguous views will fail at runtime with JAX/Pallas + # This demonstrates that the Pallas backend requires contiguous memory layout + def operate_on_tensor(x): + return x.sin() + + compiled = torch.compile( + operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"} + ) + + # Create a transposed (non-contiguous) view + x = torch.randn(64, 32, device="cuda") + x_t = x.t() # Non-contiguous view + self.assertFalse(x_t.is_contiguous()) + + # This will fail because JAX/Pallas cannot handle non-contiguous layout via DLPack + # The error indicates that our contiguous-only approach is correct + with self.assertRaises((RuntimeError, Exception)) as cm: + result = compiled(x_t) + + # Verify the error is related to layout/contiguous issues + error_msg = str(cm.exception) + self.assertTrue( + "layout" in error_msg.lower() + or "contiguous" in error_msg.lower() + or "non-default" in error_msg.lower(), + f"Expected layout/contiguous error, got: {error_msg}", + ) + + # But if we make it contiguous first, it should work + x_t_contiguous = x_t.contiguous() + self.assertTrue(x_t_contiguous.is_contiguous()) + result = compiled(x_t_contiguous) + expected = operate_on_tensor(x_t_contiguous) + self.assertEqual(result, expected) + + +# Create test variants using the main test suite +# Note: Only enable GPU tests since Pallas primarily targets GPU +if test_torchinductor.HAS_GPU and HAS_PALLAS: + # Uncomment these to run full test suite with Pallas backend + # make_pallas(test_torchinductor.SweepInputsGPUTest) + # make_pallas(test_torchinductor.GPUTests) + pass + +if __name__ == "__main__": + if HAS_PALLAS: + run_tests(needs="filelock") diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 675d912c0c01..ed8993a1c9a3 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5876,6 +5876,22 @@ def fn(x, y): reference_in_float=False, ) + @skipIfMPS + def test_linalg_eig_stride_consistency(self): + def fn(x): + eigenvals, eigenvecs = torch.linalg.eig(x) + return eigenvecs + + x = torch.randn(5, 5, device=self.device, dtype=torch.float32) + + self.common( + fn, + [x], + exact_stride=True, + exact_dtype=True, + check_lowp=False, + ) + def test_view_as_complex(self): class Repro(torch.nn.Module): def __init__(self) -> None: @@ -15280,7 +15296,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_native_layer_norm_relu", + "triton_poi_fused_addmm_native_layer_norm", (torch.randn(4, 4, device=GPU_TYPE),), ), ] @@ -15293,7 +15309,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_LayerNorm_ReLU", + "triton_poi_fused_LayerNorm_Linear_ReLU", (torch.randn(4, 4, device=GPU_TYPE),), ), ] diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 2244af38f635..e73f82ab6491 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -159,6 +159,9 @@ def run(*ex, **kwargs): # "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_linalg_eig_stride_consistency_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu") + ), "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 25fb60674e59..fc128ba61907 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -916,8 +916,7 @@ def judge(expected_event_count, prof): ) for key, count in expected_event_count.items(): self.assertTrue( - (key in actual_event_count.keys()) - and (count == actual_event_count[key]) + (key in actual_event_count) and (count == actual_event_count[key]) ) with _profile(use_kineto=kineto_available()) as prof: @@ -1406,10 +1405,7 @@ def test_profiler_fwd_bwd_link(self): s_ts_2 = flow_s_to_ts[2] f_ts_2 = flow_f_to_ts[2] self.assertTrue( - all( - ts in ts_to_name.keys() - for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2] - ) + all(ts in ts_to_name for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]) ) self.assertTrue( ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits" diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index c6316fe3cd7e..e8d28d7eff03 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -624,8 +624,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch/nn/modules/module.py(...): __getattr__ aten::linear - aten::reshape - aten::view + aten::view aten::t aten::transpose aten::as_strided @@ -671,8 +670,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch/nn/modules/module.py(...): __getattr__ aten::linear - aten::reshape - aten::view + aten::view aten::t aten::transpose aten::as_strided diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index b2b2b402327a..f2cdbfd2d631 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -1840,7 +1840,7 @@ def test_cell_api(self, dtype): 'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic, 'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic} - for rnn_type in cell_dict.keys(): + for rnn_type in cell_dict: if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")): # fp16 dynamic quant is not supported for qnnpack or onednn kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype} @@ -1903,7 +1903,7 @@ def test_rnn_cell(self): 'RNNTanh': nnqr.RNNCell, 'RNNReLU': nnqr.RNNCell} - for rnn_type in cell_dict.keys(): + for rnn_type in cell_dict: kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias} if rnn_type == 'RNNReLU': kwargs['nonlinearity'] = "relu" diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 9ea8d38828a6..93993fe33a49 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -650,7 +650,7 @@ def test_record_observer(self): observer_dict = {} _get_observer_dict(model, observer_dict) - self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(), + self.assertTrue('fc1.module.activation_post_process' in observer_dict, 'observer is not recorded in the dict') self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()), 2 * len(self.calib_data)) diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index f69852760e8a..78e7799c864b 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -368,8 +368,8 @@ def _test_forward_per_tensor_cachemask_impl(self, device): float_types = (torch.float32, torch.float16, torch.float64, torch.bfloat16) torch_types = (torch.qint8, torch.quint8) Xs = (torch.randn(4, 8, device=device), torch.randn(4, 16, device=device)[:, ::2]) - tensor_qparam = (True, False) - for float_type, torch_type, X, tensor_qparams in itertools.product(float_types, torch_types, Xs, tensor_qparam): + tensor_qparams = (True, False) + for float_type, torch_type, X, tensor_qparam in itertools.product(float_types, torch_types, Xs, tensor_qparams): # pick the scale + zp so that some values get clipped X = X.to(float_type) obs = torch.ao.quantization.MinMaxObserver(torch_type) diff --git a/test/quantization/eager/test_bias_correction_eager.py b/test/quantization/eager/test_bias_correction_eager.py index 5f0c475f934d..071ea6e2a768 100644 --- a/test/quantization/eager/test_bias_correction_eager.py +++ b/test/quantization/eager/test_bias_correction_eager.py @@ -39,7 +39,7 @@ def correct_artificial_bias_quantize(self, float_model, img_data): torch.ao.quantization.convert(artificial_model, inplace=True) # manually changing bias - for name, submodule in artificial_model.named_modules(): + for submodule in artificial_model.modules(): if type(submodule) in _supported_modules: x = get_param(submodule, "bias") weight = get_param(submodule, "weight") diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index cd922d94c60c..9c0526fde698 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,7 +204,8 @@ import operator import unittest import io -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): @@ -8806,7 +8807,7 @@ def forward(self, indices, offsets): # check it works in None and static qconfig for qconfig in [None, default_qconfig]: - qconfig_dict = {"": default_qconfig} + qconfig_dict = {"": qconfig} m = M().eval() m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ @@ -9662,10 +9663,10 @@ def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, .set_global(get_default_qat_qconfig(qengine)) \ .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig) - train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] - eval_output = [[torch.randint(0, 10, (12, 1))]] + train_indices = [[torch.randint(0, 10, (12, 12), device=device), torch.randn((12, 1), device=device)] for _ in range(2)] + eval_output = [[torch.randint(0, 10, (12, 1), device=device)]] - model = EmbeddingBagLinear().train() + model = EmbeddingBagLinear().to(device).train() prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index dfd591cb9419..5b9aa34158b5 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2016,7 +2016,7 @@ def test_qat_conv2d_unary(self): } with override_quantized_engine("x86"): - for unary_op in unary_map.keys(): + for unary_op in unary_map: m = TestHelperModules.Conv2dUnaryModule( unary_map[unary_op][0], with_bn=True ) diff --git a/test/run_test.py b/test/run_test.py index 4b7030d46152..764b20dc9adc 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -73,7 +73,22 @@ ShardedTest, THRESHOLD, ) -from tools.testing.upload_artifacts import zip_and_upload_artifacts + + +try: + from tools.testing.upload_artifacts import ( + parse_xml_and_upload_json, + zip_and_upload_artifacts, + ) +except ImportError: + # some imports in those files might fail, e.g., boto3 not installed. These + # functions are only needed under specific circumstances (CI) so we can + # define dummy functions here. + def parse_xml_and_upload_json(): + pass + + def zip_and_upload_artifacts(failed: bool): + pass # Make sure to remove REPO_ROOT after import is done @@ -1672,7 +1687,7 @@ def get_selected_tests(options) -> list[str]: ] ) - if sys.version_info[:2] < (3, 13): + if sys.version_info[:2] < (3, 13) or sys.version_info[:2] >= (3, 14): # Skip tests for older Python versions as they may use syntax or features # not supported in those versions options.exclude.extend( @@ -1826,9 +1841,14 @@ def run_test_module( test_name = test.name # Printing the date here can help diagnose which tests are slow - print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]") + start = time.perf_counter() + print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]") handler = CUSTOM_HANDLERS.get(test_name, run_test) return_code = handler(test, test_directory, options) + end = time.perf_counter() + print_to_stderr( + f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min" + ) assert isinstance(return_code, int) and not isinstance(return_code, bool), ( f"While running {str(test)} got non integer return code {return_code}" ) @@ -1882,6 +1902,7 @@ def run_tests( def handle_complete(failure: Optional[TestFailure]): failed = failure is not None if IS_CI and options.upload_artifacts_while_running: + parse_xml_and_upload_json() zip_and_upload_artifacts(failed) if not failed: return False diff --git a/test/test_as_strided.py b/test/test_as_strided.py new file mode 100644 index 000000000000..a5bcb8e27924 --- /dev/null +++ b/test/test_as_strided.py @@ -0,0 +1,176 @@ +# Owner(s): ["oncall: pt2"] + +from collections import deque +from typing import Optional + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase + + +def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]: + """Extract (sizes, strides) tuple from a tensor.""" + return (tuple(t.size()), tuple(t.stride())) + + +def enumerate_reachable_states( + initial_size: int, +) -> set[tuple[tuple[int, ...], tuple[int, ...]]]: + """ + Use BFS with DP to enumerate all reachable (size, stride) states from + a 1D contiguous tensor via valid view operations. + + We only explore states with offset=0 (you can retroactively change the offset). + We reject states with size=0 or size=1 dimensions as they are degenerate. + """ + # Create initial 1D contiguous tensor + initial_tensor = torch.arange(initial_size) + + initial_state = get_state(initial_tensor) + + # Map from state to tensor for that state + state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = { + initial_state: initial_tensor + } + visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state} + queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state]) + + while queue: + state = queue.popleft() + t = state_to_tensor[state] + sizes, strides = state + ndim = len(sizes) + + def add_state(new_t: torch.Tensor) -> None: + new_state = get_state(new_t) + sizes, strides = new_state + # Skip if has size-0 or size-1 dimensions + if any(s == 0 or s == 1 for s in sizes): + return + # Only accept states where strides are in descending order + if list(strides) != sorted(strides, reverse=True): + return + if new_state not in visited: + visited.add(new_state) + queue.append(new_state) + state_to_tensor[new_state] = new_t + + # 1. Unflatten: try factoring each dimension + for dim in range(ndim): + size = sizes[dim] + assert size > 1 + # Try all factorizations x * y = size where both x, y >= 2 + # We only need to check x up to size // 2 since when x > size // 2, + # y = size // x < 2, which we reject + for x in range(2, size // 2 + 1): + if size % x == 0: + y = size // x + add_state(t.unflatten(dim, (x, y))) + + # 2. Slice: exhaustively check all possible slicing parameters + for dim in range(ndim): + size = sizes[dim] + for start in range(size): + for stop in range(start + 1, size + 1): + for step in range(1, size + 1): + slices = [slice(None)] * ndim + slices[dim] = slice(start, stop, step) + add_state(t[tuple(slices)]) + + # 3. Flatten: merge adjacent dimensions + for dim in range(ndim - 1): + add_state(t.flatten(dim, dim + 1)) + + return visited + + +class TestAsStrided(TestCase): + def test_size_10_exhaustive(self) -> None: + """Test that size 10 produces exactly the expected 54 states.""" + expected_states = { + ((2,), (1,)), + ((2,), (2,)), + ((2,), (3,)), + ((2,), (4,)), + ((2,), (5,)), + ((2,), (6,)), + ((2,), (7,)), + ((2,), (8,)), + ((2,), (9,)), + ((2, 2), (2, 1)), + ((2, 2), (3, 1)), + ((2, 2), (3, 2)), + ((2, 2), (4, 1)), + ((2, 2), (4, 2)), + ((2, 2), (4, 3)), + ((2, 2), (5, 1)), + ((2, 2), (5, 2)), + ((2, 2), (5, 3)), + ((2, 2), (5, 4)), + ((2, 2), (6, 1)), + ((2, 2), (6, 2)), + ((2, 2), (6, 3)), + ((2, 2), (8, 1)), + ((2, 2, 2), (4, 2, 1)), + ((2, 2, 2), (5, 2, 1)), + ((2, 3), (3, 1)), + ((2, 3), (4, 1)), + ((2, 3), (5, 1)), + ((2, 3), (5, 2)), + ((2, 3), (6, 1)), + ((2, 4), (4, 1)), + ((2, 4), (5, 1)), + ((2, 5), (5, 1)), + ((3,), (1,)), + ((3,), (2,)), + ((3,), (3,)), + ((3,), (4,)), + ((3, 2), (2, 1)), + ((3, 2), (3, 1)), + ((3, 2), (3, 2)), + ((3, 2), (4, 1)), + ((3, 3), (3, 1)), + ((4,), (1,)), + ((4,), (2,)), + ((4,), (3,)), + ((4, 2), (2, 1)), + ((5,), (1,)), + ((5,), (2,)), + ((5, 2), (2, 1)), + ((6,), (1,)), + ((7,), (1,)), + ((8,), (1,)), + ((9,), (1,)), + ((10,), (1,)), + } + + actual_states = enumerate_reachable_states(10) + + self.assertEqual(len(actual_states), 54) + self.assertEqual(actual_states, expected_states) + + def test_subset_property(self) -> None: + """ + Test that for sizes 2..10, each smaller tensor results in a strict + subset of possible states compared to the next one. + """ + prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None + for size in range(2, 11): + current_states = enumerate_reachable_states(size) + + if prev_states is not None: + # Check that prev_states is a strict subset of current_states + self.assertTrue( + prev_states.issubset(current_states), + f"States from size {size - 1} are not a subset of size {size}", + ) + # Check that it's a strict subset (not equal) + self.assertTrue( + len(prev_states) < len(current_states), + f"States from size {size - 1} should be strictly fewer than size {size}", + ) + + prev_states = current_states + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_cuda.py b/test/test_cuda.py index a7e373da6382..dfbcdc1b4040 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1474,13 +1474,6 @@ def test_huge_index(self): res_cpu = src.cpu()[idx.cpu()] self.assertEqual(res.cpu(), res_cpu) - def test_fast_index_overflow(self): - src = torch.randint(0, 20, (4, 87, 1056, 736), device="cuda") - indices = torch.tensor([True, False, False, True], device="cuda") - res = src[indices] - res_cpu = src.cpu()[indices.cpu()] - self.assertEqual(res.cpu(), res_cpu) - def test_randint_randomness_for_large_range(self) -> None: # For large ranges, randint generation is slightly different. This lead to a subtle bug where some Philox # offsets were not calculated correctly, resulting in reused random states. @@ -4633,6 +4626,52 @@ def check_output(script: str) -> str: rc = check_output(test_script) self.assertEqual(rc, "cudaMallocAsync") + def test_allocator_memory_fraction_setting(self): + def make_env(fraction): + env = os.environ.copy() + var = "PYTORCH_CUDA_ALLOC_CONF" + key = "per_process_memory_fraction" + value = [ + x + for x in env.get(var, "").split(",") + if len(x) > 0 and not x.startswith(f"{key}:") + ] + value.append(f"{key}:{fraction}") + env[var] = ",".join(value) + return env + + def run_test(value): + test_script = """\ +import os +import torch +device = torch._C._cuda_getDevice() +value = torch.cuda.memory.get_per_process_memory_fraction(device) +print(value, end="") + """ + return subprocess.run( + [sys.executable, "-c", test_script], + env=make_env(value), + text=True, + check=True, + capture_output=True, + ) + + self.assertEqual(run_test(0.0).stdout, "0.0") + self.assertEqual(run_test(0.5).stdout, "0.5") + self.assertEqual(run_test(1.0).stdout, "1.0") + + with self.assertRaises(subprocess.CalledProcessError) as e: + run_test(-0.1) + assert "per_process_memory_fraction is invalid" in e.exception.stderr, ( + e.exception.stderr + ) + + with self.assertRaises(subprocess.CalledProcessError) as e: + run_test(1.1) + assert "per_process_memory_fraction is invalid" in e.exception.stderr, ( + e.exception.stderr + ) + def test_cachingAllocator_raw_alloc(self): # Test that raw_alloc respects the setting that # activates/deactivates the caching allocator @@ -7413,6 +7452,140 @@ def test_graph_external_wait_and_record(self): ) +class TestFXMemoryProfiler(TestCase): + """Tests for memory profiler augmentation with original stack traces.""" + + def collect_frames( + self, augmented_snapshot, collect_device_traces=True, collect_segments=True + ): + """Collects all frames that has node metadata from a memory snapshot.""" + # Collect all frames with FX metadata + fx_frames = [] + + # Check device traces for FX debug fields + if collect_device_traces and "device_traces" in augmented_snapshot: + for trace_list in augmented_snapshot["device_traces"]: + for trace_entry in trace_list: + if isinstance(trace_entry, dict) and "frames" in trace_entry: + for frame in trace_entry["frames"]: + if isinstance(frame, dict): + # Check for FX debug fields + if "fx_node_op" in frame or "fx_node_name" in frame: + fx_frames.append(frame) + + # Check segments/blocks for FX debug fields + if collect_segments and "segments" in augmented_snapshot: + for segment in augmented_snapshot["segments"]: + if "blocks" in segment: + for block in segment["blocks"]: + if "frames" in block: + for frame in block["frames"]: + if isinstance(frame, dict): + if "fx_node_op" in frame or "fx_node_name" in frame: + fx_frames.append(frame) + return fx_frames + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_fx_memory_profiler_augmentation(self): + """Test that memory snapshots are augmented with FX debug information.""" + + # Create a simple model + class MLPModule(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + a = self.net1(x) + b = self.relu(a) + c = self.net2(b) + return c + + device = "cuda" + mod = MLPModule(device) + with tempfile.TemporaryDirectory() as tmpdir: + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot( + augment_with_fx_traces=True + ) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + torch.cuda.empty_cache() + + fx_frames = self.collect_frames(augmented_snapshot) + if TEST_WITH_ROCM: + self.assertGreater(len(fx_frames), 0) + else: + self.assertEqual(len(fx_frames), 12) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("a = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("c = self.net2(b)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("b = self.relu(a)", frame["fx_original_trace"]) + + # Test that when we have two graphs with the same src_code, they're not hashed + # to the same metadata + class MLPModule2(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + d = self.net1(x) + e = self.relu(d) + f = self.net2(e) + return f + + mod = MLPModule2(device) + with tempfile.TemporaryDirectory() as tmpdir: + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot( + augment_with_fx_traces=True + ) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + + # avoid collecting segments from previous run for unit test purpose + fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False) + self.assertGreater(len(fx_frames), 0) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("d = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("f = self.net2(e)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("e = self.relu(d)", frame["fx_original_trace"]) + + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_parametrized_tests(TestCompileKernel) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 5a535e7e0066..cab86e42734f 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -1136,7 +1136,7 @@ def test_fork_iterdatapipe(self): ) break with warnings.catch_warnings(record=True) as wa: - for i, (n1, n2) in enumerate(zip(dp1, dp2)): + for n1, n2 in zip(dp1, dp2): output1.append(n1) output2.append(n2) self.assertEqual(len(wa), 1) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index fb1d22805d50..d3f9e415ff94 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,6 +4401,70 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_narrow_unbacked_start(self): + def func(x, start, length): + # unbacked start + u0 = start.item() + return torch.narrow(x, 0, u0, length) + + compiled_func = torch.compile(func, fullgraph=True, backend="inductor") + + x = torch.tensor([1, 2, 3, 4, 5, 6]) + + # Test cases: (start, length) + test_cases = [ + # Negative starts + (-2, 2), # Start from second-to-last element + (-1, 1), # Start from last element + (-3, 3), # Start from third-to-last element + (-6, 2), # Start from beginning (negative) + (-4, 1), # Start from fourth-to-last element + # Positive starts + (0, 2), # Start from beginning + (1, 3), # Start from second element + (2, 2), # Start from third element + (4, 2), # Start near end + # Edge cases + (0, 6), # Full tensor + (0, 1), # Single element from start + (5, 1), # Single element from end + ] + + for start_val, length in test_cases: + with self.subTest(start=start_val, length=length): + start = torch.tensor([start_val]) + + # Test with compiled function + result_compiled = compiled_func(x, start, length) + + # Test with eager function (expected behavior) + result_eager = func(x, start, length) + + # Compare results + self.assertEqual(result_compiled, result_eager) + + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_narrow_unbacked_start_cpp_wrapper(self): + """Test narrow with unbacked start with cpp_wrapper""" + self.test_narrow_unbacked_start() + + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_narrow_with_tensor_start(self): + @torch.compile(backend="inductor", fullgraph=True) + def f(x, start, end): + return torch.narrow(x, 0, start, end) + + x = torch.tensor( + [False], device="cuda:0" if torch.cuda.is_available() else "cpu" + ) + start = torch.tensor(0) + res = f(x, start, 0) + self.assertEqual(res.shape, torch.Size([0])) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_fx.py b/test/test_fx.py index 880cc91edc06..3ad21e64c8ce 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -72,9 +72,16 @@ IS_WINDOWS, run_tests, skipIfTorchDynamo, + skipIfRocm, ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +208,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. + + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -771,6 +808,7 @@ def forward(self, a, b): gm = GraphModule(tracer.root, graph) expected = {1: 2, 2: 3, 3: 4, 4: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) + self.assertEqual(gm._prologue_start, 4) # test custom codegen def transform_code(code): @@ -780,6 +818,7 @@ def transform_code(code): gm.recompile() expected = {2: 2, 3: 3, 4: 4, 5: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) + self.assertEqual(gm._prologue_start, 4) def test_graph_unique_names_manual(self): graph: torch.fx.Graph = torch.fx.Graph() @@ -2032,6 +2071,31 @@ def forward(self, x): self.assertEqual(interpreter.run(input), gm(input)) self.assertEqual(interpreter.run(input), m(input)) + def test_interpreter_boxed_run_argument_validation(self): + class AddModule(torch.nn.Module): + def forward(self, lhs, rhs): + return lhs + rhs + + gm = torch.fx.symbolic_trace(AddModule()) + interpreter = Interpreter(gm) + + lhs = torch.tensor(1.0) + rhs = torch.tensor(2.0) + good_args = [lhs.clone(), rhs.clone()] + result = interpreter.boxed_run(good_args) + torch.testing.assert_close(result, lhs + rhs) + self.assertEqual(good_args, []) + + extra_args = [lhs.clone(), rhs.clone(), torch.tensor(3.0)] + with self.assertRaisesRegex(RuntimeError, "extra arguments"): + interpreter.boxed_run(extra_args) + self.assertEqual(len(extra_args), 3) + + missing_args = [lhs.clone()] + with self.assertRaisesRegex(RuntimeError, "missing arguments"): + interpreter.boxed_run(missing_args) + self.assertEqual(len(missing_args), 1) + def test_interpreter_other_graph(self): class MyModule(torch.nn.Module): def __init__(self) -> None: @@ -4185,6 +4249,153 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. + """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. + """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. + """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch @@ -4535,7 +4746,7 @@ def check_symbols_have_bc_designation(m, seen): check_symbols_have_bc_designation(torch.fx.passes, set()) non_back_compat_strs = [ - torch.typename(obj) for obj in non_back_compat_objects.keys() + torch.typename(obj) for obj in non_back_compat_objects ] # Only want objects in torch.fx non_back_compat_strs = [ diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3018be817d9..8622d428cb4f 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1682,11 +1682,8 @@ def apply(fn): ] dtypes = ["int", "float", "bool"] values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} - devices = self.devices - for dtype_x, dtype_y, op, device in product( - dtypes, dtypes, binary_ops, devices - ): - code = ir_template.format(**locals()) + for dtype_x, dtype_y, op in product(dtypes, dtypes, binary_ops): + code = ir_template.format(dtype_x=dtype_x, dtype_y=dtype_y, op=op) # Interpret the graph try: @@ -1701,9 +1698,7 @@ def apply(fn): try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: - raise RuntimeError( - " ".join(["Compilation failed:", device, str(code)]) - ) from e + raise RuntimeError(" ".join(["Compilation failed:", str(code)])) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): @@ -1713,9 +1708,7 @@ def apply(fn): self.assertEqual(ref, res) except Exception as e: raise RuntimeError( - " ".join( - ["Failed at runtime:", device, str(x), str(y), str(code)] - ) + " ".join(["Failed at runtime:", str(x), str(y), str(code)]) ) from e def test_matmul(self): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 5e54a851812e..a8e9be4c972a 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from typing import Callable +from collections.abc import Callable import torch @@ -359,6 +359,29 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist) self.assertEqual(agrad, a.grad) self.assertEqual(bgrad, b.grad) + @onlyCUDA + @skipIfRocm + @dtypes(torch.half, torch.bfloat16) + @unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell") + @serialTest() + def test_cublas_batch_invariance_blackwell(self, device, dtype): + orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False) + with blas_library_context('cublaslt'): + N = 2048 + K = 6144 + M_max = 32 + x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16) + w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t() + full = x @ w + xx = x[:1] + out = xx @ w + self.assertEqual(full[:1], out, atol=0., rtol=0.) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) @@ -490,8 +513,6 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): @parametrize("b_row_major", [False, True]) @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): - if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]: - self.skipTest("failed using hipblaslt on rocm 6.4.2") device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 64, 4 diff --git a/test/test_mps.py b/test/test_mps.py index fad09c2f5eb2..cb0db4d96d33 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -4470,6 +4470,14 @@ def test_bce_loss_broadcasts_weights(self): self.assertEqual(out1, out2) + def test_bce_backward_with_no_reduction_and_one_in_shape(self): + # Regression test for https://github.com/pytorch/pytorch/issues/166746 + output = torch.zeros(3, 2, 1, requires_grad=True, device='mps') + target = torch.zeros(3, 2, 1, device='mps') + torch.sum(nn.BCELoss(reduction='none')(output, target)).backward() + expected_grad = torch.zeros(3, 2, 1, device='mps') + self.assertEqual(output.grad, expected_grad) + def test_cross_entropy_loss(self): # Regression test for https://github.com/pytorch/pytorch/issues/116095 loss = nn.CrossEntropyLoss() diff --git a/test/test_nn.py b/test/test_nn.py index 034cf51d49ff..bedb4b22a01b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -13516,7 +13516,7 @@ def compare_scaling(grads): # Should warning when parameters generator exhausted params = l.parameters() - for p in params: + for _p in params: pass with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/test/test_ops.py b/test/test_ops.py index 165b284b76d5..5f44a3ba0841 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2992,12 +2992,50 @@ def test_strided_layout(self, device, dtype, op): self.assertEqual(strided_result.layout, torch.strided) +class TestForwardADWithScalars(TestCase): + @ops( + [op for op in op_db if op.name in ["mul", "add", "div"]], + allowed_dtypes=(torch.float32,), + ) + def test_0d_tensor_with_python_scalar(self, device, dtype, op): + """Test that forward AD preserves dtype when combining 0D tensors with Python scalars.""" + if torch.float not in op.supported_backward_dtypes(device): + raise unittest.SkipTest("Does not support autograd") + + # skip if operator doesnt support forward AD + if not op.supports_forward_ad: + raise unittest.SkipTest("Does not support forward_ad") + + # create 0D tensors + primal0d = torch.ones((), device=device, dtype=dtype) + tangent0d = torch.ones((), device=device, dtype=dtype) + + with torch.autograd.forward_ad.dual_level(): + dual0d = torch.autograd.forward_ad.make_dual(primal0d, tangent0d) + + # Test with scalar on RHS + if op.supports_rhs_python_scalar: + result = op(dual0d, 2.0) + p, t = torch.autograd.forward_ad.unpack_dual(result) + self.assertEqual( + p.dtype, t.dtype, f"{op.name} and scalar on RHS - dtype mismatch" + ) + # Test with scalar on LHS + if op.supports_one_python_scalar: + result = op(2.0, dual0d) + p, t = torch.autograd.forward_ad.unpack_dual(result) + self.assertEqual( + p.dtype, t.dtype, f"{op.name} and scalar on LHS - dtype mismatch" + ) + + instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True) instantiate_device_type_tests(TestCompositeCompliance, globals()) instantiate_device_type_tests(TestMathBits, globals()) instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu") instantiate_device_type_tests(TestFakeTensor, globals()) instantiate_device_type_tests(TestTags, globals()) +instantiate_device_type_tests(TestForwardADWithScalars, globals()) if __name__ == "__main__": TestCase._default_dtype_check_enabled = True diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index b76895a0a91f..0487995a2d1c 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,7 +1987,6 @@ def f(t): } only_fake_tensor_failures = { - xfail('narrow'), xfail('tensor_split'), } diff --git a/test/test_pytree.py b/test/test_pytree.py index 7cc3b8affc0e..09cf0bbd47a4 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -601,6 +601,24 @@ def f(x, y, z): for case in cases: run_test(case) + @parametrize_pytree_module + def test_tree_map_dict_order(self, pytree): + d = {"b": 2, "a": 1, "c": 3} + od = OrderedDict([("b", 2), ("a", 1), ("c", 3)]) + dd = defaultdict(int, {"b": 2, "a": 1, "c": 3}) + for tree in (d, od, dd): + result = pytree.tree_map(lambda x: x, tree) + self.assertEqual( + list(result.keys()), + list(tree.keys()), + msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}", + ) + self.assertEqual( + list(result.values()), + list(tree.values()), + msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}", + ) + @parametrize_pytree_module def test_tree_map_only(self, pytree): self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]) diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 4d88ccd9cc7d..fd09afc11cec 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -209,42 +209,36 @@ def infer_scale_swizzle(mat, scale): ] == math.ceil(mat.shape[1] // 128): return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE + # if we're checking for nvfp4, need to adjust for packed-K + K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1 # NVFP4 if ( (scale.numel() - == round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4) + == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4) or scale.numel() - == round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)) + == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4)) and mat.dtype == torch.float4_e2m1fn_x2 and scale.dtype == torch.float8_e4m3fn ): return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4 - # MXFP4 w/o swizzle - if ( - (scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1] - or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]) - and mat.dtype == torch.float4_e2m1fn_x2 - and scale.dtype == torch.float8_e8m0fnu - ): - return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE - + # MX formats if not torch.version.hip: - # MXFP8 w/ swizzle + # MX w/swizzle (NVIDIA) if ( (scale.numel() - == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4) + == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4) or scale.numel() - == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)) + == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4)) and scale.dtype == torch.float8_e8m0fnu ): return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4 else: - # MXFP8 w/o swizzle + # MX w/o swizzle (AMD) if ( - (scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1] - or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]) + (scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1] + or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0]) and scale.dtype == torch.float8_e8m0fnu ): return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE @@ -1868,8 +1862,10 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: (127, 96, 1024), (1025, 128, 96) ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}") - @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"]) + @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None: + if torch.version.hip and recipe == "nvfp4": + raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping") if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum: raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping") @@ -1882,8 +1878,12 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0): raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping") - fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn - BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32) + fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn + BLOCK_SIZE = 16 if recipe == "nvfp4" else 32 + + if K % BLOCK_SIZE != 0: + raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping") + require_exact_match = True approx_match_sqnr_target = 22.0 @@ -2061,7 +2061,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, B = B.clamp(min=min_val, max=max_val) B = _bfloat16_to_float4_e2m1fn_x2(B) - approx_match_sqnr_target = 15 if torch.version.hip else 15.8 + approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8 C_ref = A_ref @ B_ref.t() diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index ba967c142f1e..96768f34affb 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -6,7 +6,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM) + (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM, serialTest) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, toleranceOverride, tol,) @@ -65,10 +65,12 @@ def test_gather(self, device, dtype): actual = torch.gather(src, 2, idx) self.assertEqual(actual, expected, atol=0, rtol=0) + @serialTest() @dtypes(torch.int8, torch.bfloat16) def test_gather_large(self, device, dtype): # test larger shapes to check vectorized implementation - for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)): + for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100), (4, 4, 16384 * 8192)): + torch.cuda.empty_cache() src = make_tensor((m, k), device=device, dtype=dtype) alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype) discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src) @@ -111,6 +113,8 @@ def test_gather_large(self, device, dtype): self.assertEqual(res_ind, ref, atol=0, rtol=0) res_gather = torch.gather(misaligned1, dim=dim, index=ind) self.assertEqual(res_gather, ref, atol=0, rtol=0) + del src, alloc0, alloc1, alloc2 + del discontig, misaligned, misaligned1 # test gather along 1st dim that can accidentally trigger fast path # because due to index dimension in the gather dim being 1 # an unexpected squashing in tensorIterator happens diff --git a/test/test_testing.py b/test/test_testing.py index c660eb83b804..09887be17c47 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -510,7 +510,7 @@ def test_trivial_passing_test(self, device): # Test without setting env var should run everything. env = dict(os.environ) for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]: - if k in env.keys(): + if k in env: del env[k] _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii')) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 3b864aae4f47..988bcf8de273 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -257,34 +257,6 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! βœ…") - @pytest.mark.xfail(reason="Issue #163971") - def test_fuzzer_issue_163971(self): - torch.manual_seed(0) - - def foo(arg0): - t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda - t1 = torch.softmax( - t0, dim=0 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - t2 = torch.nn.functional.gelu( - t1 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - t3 = torch.softmax( - t2, dim=0 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - output = t3 - return output - - arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True) - - out_eager = foo(arg0) - out_eager.sum().backward() - print("Eager Success! βœ…") - compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True) - out_compiled = compiled_foo(arg0) - out_compiled.sum().backward() - print("Compile Success! βœ…") - @pytest.mark.xfail(reason="Issue #164059") def test_fuzzer_issue_164059(self): torch.manual_seed(0) diff --git a/test/test_transformers.py b/test/test_transformers.py index 56e1365d33c4..ad7ae56307eb 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1914,6 +1914,7 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): q, k, v, None, 0.0, is_causal=True)) @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): batch_size = 2**16 query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True) @@ -1935,6 +1936,7 @@ def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4) @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) @@ -1948,6 +1950,7 @@ def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): @largeTensorTest("15GB", "cuda") @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_large_seq_len_uniform_attention(self): device = torch.device("cuda") dtype = torch.bfloat16 @@ -2854,7 +2857,7 @@ def test_cudnn_attention_broken_166211(self): # https://github.com/pytorch/pytorch/issues/166211#issue-3551350377 shape = (20, 4, 4, 32) scale = 10 - for i in range(100): + for _ in range(100): q = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale k = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale v = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index f0d6cc6fd9f9..14dda1cf3977 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,5 +1,5 @@ -from typing import Union -from typing_extensions import assert_type, TypeAlias +from typing import TypeAlias, Union +from typing_extensions import assert_type from torch import randn, Tensor diff --git a/third_party/tensorpipe b/third_party/tensorpipe index af0118d13e52..2b4cd91092d3 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit af0118d13e52f5a08841464a768e01a0bf3e3075 +Subproject commit 2b4cd91092d335a697416b2a3cb398283246849d diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index b353d5d0d598..217cc8db6886 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,7 +1,7 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX", "WINDOWS") load( "@fbsource//xplat/caffe2/third_party:xnnpack_buck_shim.bzl", "LOGGING_SRCS", @@ -55,7 +55,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_headers = { "xnnpack.h": "XNNPACK/include/xnnpack.h", }, - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -70,7 +70,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = SUBGRAPH_SRCS + ["XNNPACK/src/datatype.c"], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -97,7 +97,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = TABLE_SRCS, headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -121,7 +121,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = prod_srcs_for_arch_wrapper("scalar"), headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-fno-fast-math", @@ -147,7 +147,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -179,7 +179,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -211,7 +211,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -243,7 +243,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse2_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -275,7 +275,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -307,7 +307,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_ssse3_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -339,7 +339,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -371,7 +371,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse41_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -403,7 +403,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -443,7 +443,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx", @@ -476,7 +476,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -531,7 +531,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnnigfni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -568,7 +568,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -625,7 +625,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -660,7 +660,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = prod_srcs_for_arch_wrapper("avxvnni") if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavxvnni", @@ -697,7 +697,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avxvnni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -729,7 +729,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -770,7 +770,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_f16c_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mf16c", @@ -804,7 +804,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -853,7 +853,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fma3_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mfma", @@ -894,7 +894,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -948,7 +948,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx2_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx2", @@ -994,7 +994,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1039,7 +1039,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1108,7 +1108,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx512f", @@ -1141,7 +1141,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1206,7 +1206,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512skx_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx512f", @@ -1259,7 +1259,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-fno-fast-math", @@ -1301,7 +1301,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1350,7 +1350,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -1378,7 +1378,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1430,7 +1430,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -1460,7 +1460,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default @@ -1532,7 +1532,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1582,7 +1582,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1645,7 +1645,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1690,7 +1690,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1729,7 +1729,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1774,7 +1774,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1815,7 +1815,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1860,7 +1860,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1900,7 +1900,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1959,7 +1959,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2004,7 +2004,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ("XNNPACK/src", "**/*.S"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2053,7 +2053,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ("XNNPACK/src", "**/*.S"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2088,7 +2088,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "arm64_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2114,7 +2114,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "x86_and_x86_64_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2138,7 +2138,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "x86_and_x86_64_lib_ovr_win32", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2165,7 +2165,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "arm_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2193,7 +2193,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "armv7_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2209,7 +2209,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "prod_ukernels", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2234,7 +2234,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "XNNPACK", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, deps = [ ":tables", diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 13ca3e1389ac..4796153f24f0 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -763,6 +763,12 @@ """ ) +FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE = CodeTemplate( + """\ +update_wrapped_number(${inp_name}_tensor, ${inp_name}_t); +""" +) + FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate( """\ auto ${inp_name}_p = toNonOptPrimal(${inp}); @@ -1911,6 +1917,13 @@ def emit_fw_derivatives() -> list[str]: zeros_fn=zeros_fn, ) ) + if zeros_fn == "_efficientzerotensor_symint": + unpacked_arguments += ( + FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE.substitute( + inp_name=inp.name + ) + ) + if inp.name in (derivative.required_inputs_primal or []): unpacked_arguments += ( FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index b5802e803241..b2b0869d4835 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -2,6 +2,7 @@ import argparse import os +import re import sys import xml.etree.ElementTree as ET from multiprocessing import cpu_count, Pool @@ -19,17 +20,32 @@ ) +def should_upload_full_test_run(head_branch: str | None, head_repository: str) -> bool: + """Return True if we should upload the full test_run dataset. + + Rules: + - Only for the main repository (pytorch/pytorch) + - If head_branch is 'main', or a tag of form 'trunk/{40-hex-sha}' + """ + is_trunk_tag = bool(re.fullmatch(r"trunk/[0-9a-fA-F]{40}", (head_branch or ""))) + return head_repository == "pytorch/pytorch" and ( + head_branch == "main" or is_trunk_tag + ) + + def parse_xml_report( tag: str, report: Path, workflow_id: int, workflow_run_attempt: int, + job_id: int | None = None, ) -> list[dict[str, Any]]: """Convert a test report xml file into a JSON-serializable list of test cases.""" print(f"Parsing {tag}s for test report: {report}") - job_id = get_job_id(report) - print(f"Found job id: {job_id}") + if job_id is None: + job_id = get_job_id(report) + print(f"Found job id: {job_id}") test_cases: list[dict[str, Any]] = [] @@ -287,7 +303,8 @@ def init_value(test_case: dict[str, Any]) -> dict[str, Any]: remove_nan_inf(failed_tests_cases), ) - if args.head_branch == "main" and args.head_repository == "pytorch/pytorch": + # Upload full test_run only for trusted refs (main or trunk/{sha} tags) + if should_upload_full_test_run(args.head_branch, args.head_repository): # For jobs on main branch, upload everything. upload_workflow_stats_to_s3( args.workflow_run_id, diff --git a/tools/test/test_upload_gate.py b/tools/test/test_upload_gate.py new file mode 100644 index 000000000000..7d9a2e5fe3b0 --- /dev/null +++ b/tools/test/test_upload_gate.py @@ -0,0 +1,28 @@ +import unittest + +from tools.stats.upload_test_stats import should_upload_full_test_run + + +class TestUploadGate(unittest.TestCase): + def test_main_branch_on_pytorch_repo(self) -> None: + self.assertTrue(should_upload_full_test_run("main", "pytorch/pytorch")) + + def test_trunk_tag_valid_sha_on_pytorch_repo(self) -> None: + sha = "a" * 40 + self.assertTrue(should_upload_full_test_run(f"trunk/{sha}", "pytorch/pytorch")) + + def test_trunk_tag_invalid_sha_on_pytorch_repo(self) -> None: + # Not 40 hex chars + self.assertFalse(should_upload_full_test_run("trunk/12345", "pytorch/pytorch")) + + def test_non_main_branch_on_pytorch_repo(self) -> None: + self.assertFalse( + should_upload_full_test_run("feature-branch", "pytorch/pytorch") + ) + + def test_main_branch_on_fork_repo(self) -> None: + self.assertFalse(should_upload_full_test_run("main", "someone/fork")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py index 07b62ec9a1b7..49d68fe9959a 100644 --- a/tools/testing/upload_artifacts.py +++ b/tools/testing/upload_artifacts.py @@ -1,11 +1,16 @@ import glob import gzip +import json import os import time import zipfile from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, Optional + +from filelock import FileLock, Timeout + +from tools.stats.upload_test_stats import parse_xml_report REPO_ROOT = Path(__file__).resolve().parent.parent.parent @@ -140,3 +145,66 @@ def trigger_upload_test_stats_intermediate_workflow() -> None: }, ) print(x.text) + + +def parse_xml_and_upload_json() -> None: + """ + Parse xml test reports that do not yet have a corresponding json report + uploaded to s3, and upload the json reports to s3. Use filelock to avoid + uploading the same file from multiple processes. + """ + try: + job_id: Optional[int] = int(os.environ.get("JOB_ID", 0)) + if job_id == 0: + job_id = None + except (ValueError, TypeError): + job_id = None + + try: + for xml_file in glob.glob( + f"{REPO_ROOT}/test/test-reports/**/*.xml", recursive=True + ): + xml_path = Path(xml_file) + json_file = xml_path.with_suffix(".json") + lock = FileLock(str(json_file) + ".lock") + + try: + lock.acquire(timeout=0) # immediately fails if already locked + if json_file.exists(): + continue # already uploaded + test_cases = parse_xml_report( + "testcase", + xml_path, + int(os.environ.get("GITHUB_RUN_ID", "0")), + int(os.environ.get("GITHUB_RUN_ATTEMPT", "0")), + job_id, + ) + line_by_line_jsons = "\n".join([json.dumps(tc) for tc in test_cases]) + + gzipped = gzip.compress(line_by_line_jsons.encode("utf-8")) + s3_key = ( + json_file.relative_to(REPO_ROOT / "test/test-reports") + .as_posix() + .replace("/", "_") + ) + + get_s3_resource().put_object( + Body=gzipped, + Bucket="gha-artifacts", + Key=f"test_jsons_while_running/{os.environ.get('GITHUB_RUN_ID')}/{job_id}/{s3_key}", + ContentType="application/json", + ContentEncoding="gzip", + ) + + # We don't need to save the json file locally, but doing so lets us + # track which ones have been uploaded already. We could probably also + # check S3 + with open(json_file, "w") as f: + f.write(line_by_line_jsons) + except Timeout: + continue # another process is working on this file + finally: + if lock.is_locked: + lock.release() + except Exception as e: + print(f"Failed to parse and upload json test reports: {e}") diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 737362be62b4..b659be9ee119 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" +from collections.abc import Callable from datetime import timedelta from enum import Enum from typing import Any, Optional, overload, Union @@ -616,6 +617,11 @@ class FakeWork(Work): def wait(self, timeout: timedelta = ...) -> bool: ... def getFuture(self) -> Future: ... +class PythonCallbackWork(Work): + def __init__(self, callback: Callable[[timedelta], bool]) -> None: ... + def wait(self, timeout: timedelta = ...) -> bool: ... + def get_future(self) -> Future: ... + class ProcessGroupGloo(Backend): class Device: ... diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index c23240e13170..a35befcad392 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -5,6 +5,8 @@ from torch import Tensor # Defined in torch/csrc/functorch/init.cpp +def set_inplace_requires_grad_allowed(allowed: bool) -> None: ... +def get_inplace_requires_grad_allowed() -> bool: ... def _set_dynamic_layer_keys_included(included: bool) -> None: ... def get_unwrapped(tensor: Tensor) -> Tensor: ... def is_batchedtensor(tensor: Tensor) -> bool: ... diff --git a/torch/__init__.py b/torch/__init__.py index 05a34bdd9320..b64961a9c56f 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -303,8 +303,8 @@ def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]: return nvidia_lib_paths + lib_paths -def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] - """Preloads cuda deps if they could not be found otherwise.""" +def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] + """Preloads cuda library if it could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed assert platform.system() == "Linux", "Should only be called on Linux" @@ -320,6 +320,39 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> ctypes.CDLL(lib_path) +def _preload_cuda_deps(err: _Optional[OSError] = None) -> None: + cuda_libs: dict[str, str] = { + "cublas": "libcublas.so.*[0-9]", + "cudnn": "libcudnn.so.*[0-9]", + "cuda_nvrtc": "libnvrtc.so.*[0-9]", + "cuda_runtime": "libcudart.so.*[0-9]", + "cuda_cupti": "libcupti.so.*[0-9]", + "cufft": "libcufft.so.*[0-9]", + "curand": "libcurand.so.*[0-9]", + "nvjitlink": "libnvJitLink.so.*[0-9]", + "cusparse": "libcusparse.so.*[0-9]", + "cusparselt": "libcusparseLt.so.*[0-9]", + "cusolver": "libcusolver.so.*[0-9]", + "nccl": "libnccl.so.*[0-9]", + "nvshmem": "libnvshmem_host.so.*[0-9]", + "cufile": "libcufile.so.*[0-9]", + } + + # If error is passed, re-raise it if it's not about one of the abovementioned + # libraries + if err is not None and [ + lib for lib in cuda_libs.values() if lib.split(".", 1)[0] in err.args[0] + ]: + raise err + + # Otherwise, try to preload dependencies from site-packages + for lib_folder, lib_name in cuda_libs.items(): + _preload_cuda_lib(lib_folder, lib_name) + + # libnvToolsExt is Optional Dependency + _preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False) + + # See Note [Global dependencies] def _load_global_deps() -> None: if platform.system() == "Windows": @@ -346,43 +379,15 @@ def _load_global_deps() -> None: # libtorch_global_deps.so always depends in cudart, check if its installed and loaded if "libcudart.so" not in _maps: return - # If all above-mentioned conditions are met, preload nvrtc and nvjitlink - _preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]") - _preload_cuda_deps("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]") - _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]") + # If all above-mentioned conditions are met, preload CUDA dependencies + _preload_cuda_deps() except Exception: pass except OSError as err: - # Can only happen for wheel with cuda libs as PYPI deps + # Can happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is - cuda_libs: dict[str, str] = { - "cublas": "libcublas.so.*[0-9]", - "cudnn": "libcudnn.so.*[0-9]", - "cuda_nvrtc": "libnvrtc.so.*[0-9]", - "cuda_runtime": "libcudart.so.*[0-9]", - "cuda_cupti": "libcupti.so.*[0-9]", - "cufft": "libcufft.so.*[0-9]", - "curand": "libcurand.so.*[0-9]", - "nvjitlink": "libnvJitLink.so.*[0-9]", - "cusparse": "libcusparse.so.*[0-9]", - "cusparselt": "libcusparseLt.so.*[0-9]", - "cusolver": "libcusolver.so.*[0-9]", - "nccl": "libnccl.so.*[0-9]", - "nvshmem": "libnvshmem_host.so.*[0-9]", - "cufile": "libcufile.so.*[0-9]", - } - - is_cuda_lib_err = [ - lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] - ] - if not is_cuda_lib_err: - raise err - for lib_folder, lib_name in cuda_libs.items(): - _preload_cuda_deps(lib_folder, lib_name) - - # libnvToolsExt is Optional Dependency - _preload_cuda_deps("nvtx", "libnvToolsExt.so.*[0-9]", required=False) + _preload_cuda_deps(err) ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 706ec1768cd3..1469ca478a38 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -146,7 +146,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: backends = [ name - for name in _BACKENDS.keys() + for name in _BACKENDS if name not in _COMPILER_FNS or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined] ] diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5858a4584b3d..0c95408401c7 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -739,6 +739,12 @@ def default_debug_dir_root() -> str: # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 875f640194e4..4439c7dc09ef 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1285,7 +1285,6 @@ def _compile( # in the case of normal and exception code paths convert_frame_box: Optional[ConvertFrameBox] = None, ) -> ConvertFrameReturn: - from torch._inductor.async_compile import async_compile_pool_manager from torch.fx.experimental.validator import ( BisectValidationException, ValidationException, @@ -1479,7 +1478,6 @@ def count_args(code: CodeType) -> int: with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), - async_compile_pool_manager(), chromium_event_timed( "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True ), diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e23e049e3bbb..222647eeae9a 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,10 +39,11 @@ import unittest import warnings import weakref +from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 979950cf3bd1..16583b89201e 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,5 +1,6 @@ import weakref -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from torch._dynamo.source import Source diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 77f5d6cb05a0..50a2667c12a2 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2587,7 +2587,7 @@ def update_used_symbols( real_script_obj ): flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] - for attr in flat_dict.keys(): + for attr in flat_dict: fake_attr_val = getattr( fake_script_obj.wrapped_obj, attr ) diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index a8dcf3e00c16..59f6f76317e6 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -10,7 +10,7 @@ import types from collections import OrderedDict -from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from itertools import repeat as _repeat from operator import eq, ne from typing import Any, TYPE_CHECKING @@ -276,7 +276,7 @@ def getattr_and_trace(*args, **kwargs): return fn(*args[2:], **kwargs) -def mapping_get(obj, key, value=None): +def mapping_get(obj, key, value=None, /): try: return obj.__getitem__(key) except KeyError: @@ -293,31 +293,45 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): return obj -# Used with something like dict(obj) -def construct_dict(cls, /, *args, **kwargs): - dst = cls.__new__(cls) - - if args: - src = args[0] - - if not isinstance(src, Iterable): - raise TypeError(f"{type(src)} object is not iterable") - - # Ensure that the overridden __iter__ method is invoked - if isinstance(src, (dict, MutableMapping, types.MappingProxyType)): - for key in src: - # This will inline the __getitem__ of the src object - dst[key] = src[key] - else: - # likely a sequence like tuple of pairs - for key, value in src: - dst[key] = value +def mutable_mapping_update(self, data=(), /, **kwargs): + if isinstance(data, Mapping): + # Merge standard mapping with PyMapping_Items + for key, value in data.items(): + self[key] = value + # FIXME: Enabling the `elif`-branch below needs too many `VariableClass.call_obj_hasattr` changes. + # >>> class Foo: + # ... def __init__(self): + # ... self.keys = lambda: ['a', 'b', 'c'] # not required to be a method + # ... + # ... def __getitem__(self, key): + # ... return 0 + # ... + # >>> dict(Foo()) + # {'a': 0, 'b': 0, 'c': 0} + # + # > This is a rare case, so we comment it out for now. + # + # elif hasattr(data, "keys"): + # # Merge mapping-like object with PyMapping_Keys + PyObject_GetItem + # for key in data.keys(): + # self[key] = data[key] + else: + if not isinstance(data, Iterable): + raise TypeError(f"{type(data).__name__!r} object is not iterable") + # Likely a sequence of pairs + for key, value in data: + self[key] = value if kwargs: - for key in kwargs: - dst[key] = kwargs[key] + for key, value in kwargs.items(): + self[key] = value - return dst + +# Used with something like dict(obj) +def construct_dict(cls, data=(), /, **kwargs): + self = cls.__new__(cls) + mutable_mapping_update(self, data, **kwargs) + return self def foreach_map_fn(*args): diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index f9bdc0cce4a0..b4de3200e296 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -6,7 +6,7 @@ from collections import deque from dataclasses import dataclass, field -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import TypeIs import torch.utils._pytree as python_pytree @@ -24,9 +24,15 @@ __all__: list[str] = [] +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + if python_pytree._cxx_pytree_dynamo_traceable: import optree import optree._C + import optree.utils import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 @@ -64,7 +70,7 @@ def _(*args: Any, **kwargs: Any) -> bool: del __func del __name - @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] def tree_is_leaf( tree: PyTree, /, @@ -79,7 +85,7 @@ def tree_is_leaf( return True return False - @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) + @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] def tree_iter( tree: PyTree, /, @@ -110,7 +116,7 @@ def tree_iter( __all__ += ["tree_iter"] - @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] def tree_leaves( tree: PyTree, /, @@ -451,7 +457,7 @@ def treespec_dict( dict, metadata, entries, - unflatten_func, + unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -507,7 +513,7 @@ def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: type(node), metadata, entries, - unflatten_func, + unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) # type: ignore[arg-type] @@ -557,7 +563,7 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: __all__ += ["tree_unflatten"] - @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map( func: Callable[..., Any], tree: PyTree, @@ -578,7 +584,7 @@ def tree_map( __all__ += ["tree_map"] - @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map_( func: Callable[..., Any], tree: PyTree, @@ -600,14 +606,47 @@ def tree_map_( __all__ += ["tree_map_"] - _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr] + _none_registration = optree.register_pytree_node.get(type(None)) + assert _none_registration is not None @substitute_in_graph( # type: ignore[arg-type] - _none_unflatten, + _none_registration.unflatten_func, can_constant_fold_through=True, skip_signature_check=True, ) - def none_unflatten(_: None, children: Iterable[Any], /) -> None: + def none_unflatten(_: None, children: Iterable[_T], /) -> None: if len(list(children)) != 0: raise ValueError("Expected no children.") return None + + with optree.dict_insertion_ordered(False, namespace="torch"): + _dict_registration = optree.register_pytree_node.get(dict) + assert _dict_registration is not None + + @substitute_in_graph( # type: ignore[arg-type] + _dict_registration.flatten_func, + can_constant_fold_through=True, + skip_signature_check=True, + ) + def dict_flatten( + dct: dict[_KT, _VT], / + ) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: + sorted_keys = optree.utils.total_order_sorted(dct) + values = [dct[key] for key in sorted_keys] + original_keys = list(dct) + return values, (original_keys, sorted_keys), tuple(sorted_keys) + + @substitute_in_graph( # type: ignore[arg-type] + _dict_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, + ) + def dict_unflatten( + metadata: tuple[list[_KT], list[_KT]], + values: Iterable[_VT], + /, + ) -> dict[_KT, _VT]: + original_keys, sorted_keys = metadata + d = dict.fromkeys(original_keys) + d.update(zip(sorted_keys, values)) + return d # type: ignore[return-value] diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index bd38e9295a05..688a05f26ae6 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -42,7 +42,7 @@ ) from .codegen import PyCodegen from .exc import SideEffectsError, unimplemented_v2 -from .source import GlobalSource, LocalCellSource, LocalSource, Source +from .source import GlobalSource, LocalCellSource, Source, TempLocalSource from .utils import is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( AttributeMutation, @@ -704,7 +704,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: ) cg.extend_output(create_call_function(0, False)) cg.add_cache(var) - var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + var.source = TempLocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif var.source is None: # pyrefly: ignore [bad-assignment] var.source = LocalCellSource(var.local_name) @@ -729,7 +729,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: # `add_cache` generates STORE and consumes TOS, but we never # cleared it. TODO move this call into `add_cache` cg.clear_tos() - var.source = LocalSource(cg.tempvars[var]) + var.source = TempLocalSource(cg.tempvars[var]) elif isinstance(var, variables.AutogradFunctionContextVariable): unimplemented_v2( gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", @@ -764,7 +764,7 @@ def load_new_method() -> None: cg.extend_output(create_call_function(1 + len(var.init_args), False)) # type: ignore[attr-defined] cg.add_cache(var) - var.source = LocalSource(cg.tempvars[var]) + var.source = TempLocalSource(cg.tempvars[var]) for ctx, args in self.save_for_backward: cg(ctx.source) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 8edd8f7540e3..5be6b8ccbf41 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -151,6 +151,23 @@ def name(self) -> str: return f"L[{repr(self.local_name)}]" +@dataclasses.dataclass(frozen=True) +class TempLocalSource(Source): + # like LocalSource, but cannot be guarded on + local_name: str + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self) -> GuardSource: + return GuardSource.TEMP_LOCAL + + def name(self) -> str: + raise NotImplementedError( + "Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub." + ) + + @dataclasses.dataclass(frozen=True) class SyntheticLocalSource(Source): local_name: str diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 9d0d87c5f8a0..3943f90b0020 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -434,12 +434,15 @@ def resume_fn(self) -> ReenterWith: else: return ReenterWith(self.stack_index - 1) - def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None: + def exit( + self, tx: InstructionTranslatorBase, is_graph_break: bool + ) -> VariableTracker | None: assert self.with_context is not None if ( is_graph_break and self.with_context.exit_on_graph_break() ) or not is_graph_break: return self.with_context.exit(tx) # type: ignore[arg-type] + return None class SpeculationLogDivergence(AssertionError): @@ -3317,7 +3320,7 @@ def SET_ADD(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() - obj.call_method(self, "add", [v], {}) + obj.call_method(self, "add", [v], {}) # type: ignore[arg-type] def SET_UPDATE(self, inst: Instruction) -> None: v = self.pop() @@ -3326,7 +3329,7 @@ def SET_UPDATE(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() - obj.call_method(self, "update", [v], {}) + obj.call_method(self, "update", [v], {}) # type: ignore[arg-type] def LIST_APPEND(self, inst: Instruction) -> None: v = self.pop() @@ -3634,7 +3637,7 @@ def DICT_MERGE(self, inst: Instruction) -> None: obj = self.stack[-inst.arg].realize() assert isinstance(obj, ConstDictVariable) assert obj.is_mutable() - obj.call_method(self, "update", [v], {}) + obj.call_method(self, "update", [v], {}) # type: ignore[arg-type] DICT_UPDATE = DICT_MERGE @@ -3860,7 +3863,7 @@ def enter_ctx( else: self.block_stack.append(BlockStackEntry(inst, target, len(self.stack))) - return ctx.enter(self) + return ctx.enter(self) # type: ignore[arg-type] @staticmethod def unsupported_ctx_graph_break(ctx: VariableTracker) -> NoReturn: diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 9206f2598afc..3eeedfb65da2 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -87,6 +87,12 @@ def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def] return gm.graph, region_tracker # type: ignore[union-attr] +def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def] + backend = AotEagerAndRecordGraphs() + result = torch.compile(backend=backend)(fn)(*args, **kwargs) + return result, backend.graphs, backend.fw_graphs, backend.bw_graphs + + def collect_results( model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any ) -> list[Any]: diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 1817a5f3c7ed..0f198377605e 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2061,7 +2061,11 @@ def call_dir( return None def call_dict( - self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + self, + tx: "InstructionTranslator", + /, + *args: VariableTracker, + **kwargs: VariableTracker, ) -> VariableTracker: return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) @@ -2069,6 +2073,7 @@ def call_dict( def call_custom_dict( tx: "InstructionTranslator", user_cls: type, + /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: @@ -2093,6 +2098,7 @@ def call_custom_dict( def call_custom_dict_fromkeys( tx: "InstructionTranslator", user_cls: type, + /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 0502c58a7842..3f52c19ff0a9 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This file contains a collection of context manager classes used by Dynamo for tracking and managing various PyTorch runtime states during graph compilation. These context @@ -23,8 +21,9 @@ import inspect import sys import warnings +from collections.abc import Callable, Sequence, Sized from contextlib import ExitStack -from typing import TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union import torch._C from torch._guards import Guard @@ -67,35 +66,43 @@ class ContextWrappingVariable(VariableTracker): *VariableTracker._nonvar_fields, } - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, target_values: Any, initial_values: Optional[Any] = None, **kwargs: Any + ) -> None: super().__init__(**kwargs) self.target_values = target_values self.initial_values = initial_values - def enter(self, tx): - self._call_func(tx, self.target_values) + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + if hasattr(self, "_call_func"): + self._call_func(tx, self.target_values) self.set_cleanup_hook(tx) return variables.ConstantVariable.create(None) - def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): + def set_cleanup_hook( + self, tx: "InstructionTranslator", fn: Optional[Callable[..., Any]] = None + ) -> None: if fn is None: - def fn(): - self._call_func(tx, self.initial_values) + def fn() -> None: + if hasattr(self, "_call_func"): + self._call_func(tx, self.initial_values) - self.cleanup_fn = fn + self.cleanup_fn: Optional[Callable[..., Any]] = fn tx.output.add_cleanup_hook(self.cleanup) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() return variables.ConstantVariable.create(None) - def reconstruct_type(self, codegen: "PyCodegen"): + def reconstruct_type(self, codegen: "PyCodegen") -> None: codegen( AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: self.reconstruct_type(codegen)) target_values = self.target_values if not target_values: @@ -103,18 +110,18 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.extend_output([codegen.create_load_const(val) for val in target_values]) codegen.extend_output(create_call_function(len(target_values), False)) - def module_name(self): + def module_name(self) -> str: raise NotImplementedError("module_name called on base") - def fn_name(self): + def fn_name(self) -> str: raise NotImplementedError("fn_name called on base") def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert len(args) == 1 assert isinstance( args[0], @@ -128,28 +135,27 @@ def call_function( if isinstance(args[0], NestedUserFunctionVariable): return WrappedNestedUserFunctionVariable(args[0], self) - - if isinstance(args[0], SkipFunctionVariable): + elif isinstance(args[0], SkipFunctionVariable): return WrappedSkipFunctionVariable(args[0], self) - - if isinstance(args[0], UserMethodVariable): + elif isinstance(args[0], UserMethodVariable): return WrappedUserMethodVariable(args[0], self) - - if isinstance(args[0], UserFunctionVariable): + elif isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + else: + raise AssertionError("Unexpected arg type") - def supports_graph_breaks(self): + def supports_graph_breaks(self) -> bool: return True - def exit_on_graph_break(self): + def exit_on_graph_break(self) -> bool: return True - def cleanup(self): + def cleanup(self) -> None: if self.cleanup_fn is not None: self.cleanup_fn() self.cleanup_fn = None - def cleanup_assert(self): + def cleanup_assert(self) -> None: assert self.cleanup_fn, "multiple exits?" self.cleanup() @@ -157,7 +163,7 @@ def cleanup_assert(self): class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are # python constants. Which might not always be the case here. - def __init__(self, cm_obj, **kwargs) -> None: + def __init__(self, cm_obj: ContextManager[Any], **kwargs: Any) -> None: assert cm_obj is not None super().__init__( value=cm_obj, @@ -166,44 +172,46 @@ def __init__(self, cm_obj, **kwargs) -> None: ) self.cm_obj = cm_obj - def module_name(self): + def module_name(self) -> str: return self.cm_obj.__module__ - def fn_name(self): + def fn_name(self) -> str: return type(self.cm_obj).__name__ - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: source = None if self.source is None else AttrSource(self.source, "__enter__") return variables.UserMethodVariable( - self.cm_obj.__enter__.__func__, + self.cm_obj.__enter__.__func__, # type: ignore[attr-defined] self, source=source, ).call_function(tx, [], {}) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: source = None if self.source is None else AttrSource(self.source, "__exit__") x = variables.UserMethodVariable( - self.cm_obj.__exit__.__func__, + self.cm_obj.__exit__.__func__, # type: ignore[attr-defined] self, source=source, - ).call_function(tx, args, {}) + ).call_function(tx, list(args), {}) tx.active_generic_context_managers.pop() return x - def supports_graph_breaks(self): + def supports_graph_breaks(self) -> bool: return False - def exit_on_graph_break(self): + def exit_on_graph_break(self) -> bool: return True class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): - def __init__(self, ctx_manager_vt, mod): + def __init__(self, ctx_manager_vt: ContextWrappingVariable, mod: Any) -> None: self.cm_vt = ctx_manager_vt self.mod = mod # We don't call super().__init__() because we're delegating most methods to cm_vt - def enter(self, tx: "InstructionTranslator"): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: # Custom enter implementation with side effects self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize() @@ -212,7 +220,9 @@ def enter(self, tx: "InstructionTranslator"): tx.output.side_effects.ignore_mutations_on(self.old_buffer_var) return self.cm_vt.enter(tx) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # Custom exit implementation with side effects x = self.cm_vt.exit(tx, *args) tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var) @@ -220,7 +230,7 @@ def exit(self, tx: "InstructionTranslator", *args): return x # Forward all other method calls to self.cm_vt - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This will be called for any attribute not explicitly defined in this class return getattr(self.cm_vt, name) @@ -229,14 +239,16 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requires grad""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "GradInplaceRequiresGradCtxManagerVariable": return GradInplaceRequiresGradCtxManagerVariable( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: [enabled] = self.target_values self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() torch._C._functorch.set_inplace_requires_grad_allowed(enabled) @@ -254,7 +266,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -269,14 +283,16 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): """represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "TemporarilyPopInterpreterStackCtxManagerVariable": return TemporarilyPopInterpreterStackCtxManagerVariable( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.saved = torch._C._functorch.pop_dynamic_layer_stack() self.set_cleanup_hook( tx, @@ -290,7 +306,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -309,10 +327,12 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a jvp # call from eager that calls the compiled function, as the jvp levels # may be different. - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "JvpIncrementNestingCtxManagerVariable": var = JvpIncrementNestingCtxManagerVariable( target_values=None, initial_values=None, @@ -320,7 +340,7 @@ def create(tx: "InstructionTranslator", **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting() self.set_cleanup_hook( @@ -334,7 +354,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(jvp_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} @@ -346,14 +368,16 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable): """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "SetFwdGradEnabledContextManager": return SetFwdGradEnabledContextManager( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: [mode] = self.target_values self.prev_state = torch._C._is_fwd_grad_enabled() torch._C._set_fwd_grad_enabled(mode) @@ -369,7 +393,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -383,17 +409,17 @@ def exit(self, tx: "InstructionTranslator", *args): class DualLevelContextManager(ContextWrappingVariable): """Represents torch.autograd.forward_ad.dual_level ctx manager""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create(tx: "InstructionTranslator", **kwargs: Any) -> "DualLevelContextManager": return DualLevelContextManager( target_values=None, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) self.new_level = torch.autograd.forward_ad.enter_dual_level() self.set_cleanup_hook( @@ -407,7 +433,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(self.new_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -426,10 +454,12 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a grad # call from eager that calls the compiled function, as the grad levels # may be different. - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "GradIncrementNestingCtxManagerVariable": var = GradIncrementNestingCtxManagerVariable( target_values=None, initial_values=None, @@ -437,7 +467,7 @@ def create(tx: "InstructionTranslator", **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) grad_level = torch._C._functorch._grad_increment_nesting() self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) @@ -449,7 +479,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(grad_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", torch._C._functorch._grad_decrement_nesting, (), {} @@ -461,19 +493,29 @@ class CatchWarningsCtxManagerVariable(ContextWrappingVariable): """Delay a call to warnings.catch_warnings""" @staticmethod - def create(tx: "InstructionTranslator", catch_warnings_args): + def create( + tx: "InstructionTranslator", catch_warnings_args: dict[str, VariableTracker] + ) -> "CatchWarningsCtxManagerVariable": return CatchWarningsCtxManagerVariable( catch_warnings_args=catch_warnings_args, target_values=None, initial_values=None, ) - def __init__(self, catch_warnings_args, **kwargs) -> None: + def __init__( + self, + catch_warnings_args: dict[str, VariableTracker], + target_values: Optional[Any] = None, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: assert isinstance(catch_warnings_args, dict), catch_warnings_args - super().__init__(**kwargs) + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) self.catch_warnings_args = catch_warnings_args - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: kwargs = { k: v.as_python_constant() for k, v in self.catch_warnings_args.items() } @@ -481,7 +523,7 @@ def enter(self, tx): self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None)) return variables.ConstantVariable.create(ctx_val.__enter__()) - def reconstruct(self, cg): + def reconstruct(self, cg: "PyCodegen") -> None: cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings")) cg.foreach(self.catch_warnings_args.values()) keys = tuple(self.catch_warnings_args.keys()) @@ -496,10 +538,14 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a vmap # call from eager that calls the compiled function, as the vmap levels # may be different. - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", + target_values: Sequence[VariableTracker], + **kwargs: Any, + ) -> "VmapIncrementNestingCtxManagerVariable": var = VmapIncrementNestingCtxManagerVariable( target_values=target_values, initial_values=None, @@ -507,7 +553,7 @@ def create(tx: "InstructionTranslator", target_values, **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) batch_size, randomness = self.target_values if isinstance(batch_size, variables.SymNodeVariable): @@ -527,7 +573,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(vmap_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -541,10 +589,15 @@ def exit(self, tx: "InstructionTranslator", *args): class GradModeVariable(ContextWrappingVariable): """represents torch.{no_grad,enable_grad,set_grad_mode}()""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs): + def create( + tx: "InstructionTranslator", + target_value: Any, + initialized: bool = False, + **kwargs: Any, + ) -> "GradModeVariable": var = GradModeVariable( target_values=[target_value], initial_values=[torch.is_grad_enabled()], @@ -555,31 +608,37 @@ def create(tx: "InstructionTranslator", target_value, initialized=False, **kwarg return var def __init__( - self, target_values, initial_values=None, initialized=True, **kwargs + self, + target_values: Any, + initial_values: Optional[Sequence[bool]] = None, + initialized: bool = True, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self._call_func(tx, self.initial_values) return variables.ConstantVariable.create(None) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ): + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self._call_func(tx, self.initial_values) # undo eager initialization return super().call_function(tx, args, kwargs) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: assert len(values) == 1 value = values[0] # Coalesce grad mode mutations @@ -589,16 +648,18 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._set_grad_enabled(value) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "set_grad_enabled" class InferenceModeVariable(ContextWrappingVariable): @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: Any, **kwargs: Any + ) -> "InferenceModeVariable": var = InferenceModeVariable( [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs ) @@ -606,9 +667,9 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): def __init__( self, - target_values, - initial_values=None, - **kwargs, + target_values: Any, + initial_values: Optional[bool] = None, + **kwargs: Any, ) -> None: if initial_values is None: # This must be called here since function defaults are evaluated at import time @@ -616,9 +677,10 @@ def __init__( super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", @@ -626,8 +688,9 @@ def exit(self, tx: "InstructionTranslator", *args): (self.proxy,), {}, ) + return variables.ConstantVariable.create(None) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: disabled_inference_mode_forcibly = False if ( torch._dynamo.config.fake_tensor_disable_inference_mode @@ -642,7 +705,7 @@ def enter(self, tx): else: ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) - def cleanup_hook(): + def cleanup_hook() -> None: if disabled_inference_mode_forcibly: torch._C._set_grad_enabled(prior) else: @@ -655,11 +718,12 @@ def cleanup_hook(): (*self.target_values,), {}, ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "inference_mode" @@ -667,7 +731,9 @@ class CUDADeviceVariable(ContextWrappingVariable): """represents torch.cuda.device""" @staticmethod - def create(tx: "InstructionTranslator", device, **kwargs): + def create( + tx: "InstructionTranslator", device: Any, **kwargs: Any + ) -> "CUDADeviceVariable": var = CUDADeviceVariable( target_values=[torch.cuda._get_device_index(device, optional=True)], initial_values=None, @@ -677,16 +743,17 @@ def create(tx: "InstructionTranslator", device, **kwargs): def __init__( self, - target_values, - initial_values=None, - **kwargs, + target_values: Any, + initial_values: Optional[Any] = None, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", @@ -696,7 +763,7 @@ def exit(self, tx: "InstructionTranslator", *args): ) return variables.ConstantVariable.create(False) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: prev_idx = torch.cuda._exchange_device(*self.target_values) self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx)) self.proxy = tx.output.create_node( @@ -705,21 +772,24 @@ def enter(self, tx): (*self.target_values,), {}, ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.cuda" - def fn_name(self): + def fn_name(self) -> str: return "device" class TorchFunctionDisableVariable(ContextWrappingVariable): """represents whether torch function overrides are enabled or not""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "TorchFunctionDisableVariable": var = TorchFunctionDisableVariable( target_values=[], initial_values=[], @@ -728,10 +798,14 @@ def create(tx: "InstructionTranslator", **kwargs): return var def __init__( - self, target_values, initial_values=None, only_subclass=True, **kwargs + self, + target_values: Sized, + initial_values: Optional[Sized] = None, + only_subclass: bool = True, + **kwargs: Any, ) -> None: assert len(target_values) == 0 - assert len(initial_values) == 0 + assert initial_values is not None and len(initial_values) == 0 from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() @@ -748,10 +822,14 @@ def __init__( ) install_guard(self._guards_singleton) - def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): - if fn is None: + def set_cleanup_hook( + self, + tx: "InstructionTranslator", + cleanup_fn: Optional[Callable[..., Any]] = None, + ) -> None: + if cleanup_fn is None: - def fn(): + def cleanup_fn() -> None: tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( self.initial_torch_function_subclass_enabled ) @@ -760,19 +838,19 @@ def fn(): self.initial_torch_function_subclass_enabled ) - self.cleanup_fn = fn + self.cleanup_fn = cleanup_fn tx.output.add_cleanup_hook(self.cleanup) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sized) -> None: assert len(values) == 0 tx.symbolic_torch_function_state.torch_function_subclass_enabled = False if not self.only_subclass: tx.symbolic_torch_function_state.torch_function_mode_enabled = False - def module_name(self): + def module_name(self) -> str: return "torch._C" - def fn_name(self): + def fn_name(self) -> str: if self.only_subclass: return "DisableTorchFunctionSubclass" return "DisableTorchFunction" @@ -782,11 +860,14 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable): """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" _guards_singleton = Guard( - GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS + GlobalStateSource(), + GuardBuilder.DETERMINISTIC_ALGORITHMS, # type: ignore[arg-type] ) @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: bool, **kwargs: Any + ) -> "DeterministicAlgorithmsVariable": var = DeterministicAlgorithmsVariable( target_values=[target_value], initial_values=[torch.are_deterministic_algorithms_enabled()], @@ -796,16 +877,21 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): var.set_cleanup_hook(tx) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[bool], + initial_values: Optional[Sequence[bool]] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: assert len(values) == 1 value = values[0] tx.output.create_node( @@ -813,10 +899,10 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._set_deterministic_algorithms(value) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "use_deterministic_algorithms" @@ -824,7 +910,9 @@ class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): """represents torch.autograd.graph.disable_saved_tensors_hook.""" @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: Optional[str], **kwargs: Any + ) -> "DisabledSavedTensorsHooksVariable": var = DisabledSavedTensorsHooksVariable( target_values=[target_value], initial_values=[ @@ -836,15 +924,22 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): var.set_cleanup_hook(tx) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[Optional[str]], + initial_values: Optional[Sequence[Optional[str]]] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func( + self, tx: "InstructionTranslator", values: Sequence[Optional[str]] + ) -> None: assert len(values) == 1 value = values[0] if value is not None: @@ -865,16 +960,20 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._autograd._saved_tensors_hooks_enable() - def module_name(self): + def module_name(self) -> str: return "torch.autograd.graph" - def fn_name(self): + def fn_name(self) -> str: return "disable_saved_tensors_hooks" class AutocastModeVariable(ContextWrappingVariable): @staticmethod - def create(func, args, kwargs): + def create( + func: torch.amp.autocast_mode.autocast, + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> "AutocastModeVariable": assert func in [ torch.amp.autocast_mode.autocast, torch.cuda.amp.autocast, @@ -905,30 +1004,37 @@ def create(func, args, kwargs): var = AutocastModeVariable(target_values, initial_values=None, **kwargs) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[Any], + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", torch.amp._exit_autocast, (self.proxy,), {} ) return variables.ConstantVariable.create(None) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: ctx = torch.amp._enter_autocast(*self.target_values) self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) self.proxy = tx.output.create_node( "call_function", torch.amp._enter_autocast, (*self.target_values,), {} ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.amp.autocast_mode" - def fn_name(self): + def fn_name(self) -> str: return "autocast" @@ -937,20 +1043,22 @@ class NullContextVariable(ContextWrappingVariable): This class represents Python contextlib.nullcontext. """ - def __init__(self, target_values=None, **kwargs) -> None: + def __init__(self, target_values: Optional[Any] = None, **kwargs: Any) -> None: super().__init__(target_values=target_values, **kwargs) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: none = variables.ConstantVariable.create(None) return self.target_values if self.target_values else none - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "contextlib" - def fn_name(self): + def fn_name(self) -> str: return "nullcontext" @@ -963,22 +1071,24 @@ class ProfilerContextVariable(ContextWrappingVariable): than `None`, per implementation of the torch objects. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(target_values=None, **kwargs) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return self - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "contextlib" - def fn_name(self): + def fn_name(self) -> str: return "nullcontext" - def reconstruct(self, cg): + def reconstruct(self, cg: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.profiler object escaped from compiled region", context=str(self), @@ -995,27 +1105,37 @@ class PreserveVersionContextVariable(ContextWrappingVariable): """ @staticmethod - def _create_lambda_from_tensors(tx, tensors): + def _create_lambda_from_tensors( + tx: "InstructionTranslator", + tensors: VariableTracker, + ) -> "PreserveVersionContextVariable": if isinstance(tensors, variables.TensorVariable): versions = variables.TupleVariable( [x.var_getattr(tx, "_version") for x in [tensors]] ) - tensors = variables.TupleVariable([tensors]) + tensors_tuple = variables.TupleVariable([tensors]) else: + assert isinstance(tensors, variables.TupleVariable) versions = variables.TupleVariable( [x.var_getattr(tx, "_version") for x in tensors.items] ) - return PreserveVersionContextVariable(tensors, versions) + tensors_tuple = tensors + return PreserveVersionContextVariable(tensors_tuple, versions) @staticmethod - def constructor(tx): + def constructor(tx: "InstructionTranslator") -> VariableTracker: return variables.LambdaVariable( lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors( tx, tensors ) ) - def __init__(self, tensors, prev_versions, **kwargs) -> None: + def __init__( + self, + tensors: VariableTracker, + prev_versions: VariableTracker, + **kwargs: Any, + ) -> None: kwargs.setdefault("target_values", None) super().__init__(**kwargs) self.tensors = tensors @@ -1028,17 +1148,19 @@ def __init__(self, tensors, prev_versions, **kwargs) -> None: ): self.prev_versions = variables.TupleVariable([self.prev_versions]) - def enter(self, tx): - pass + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: from ..tensor_version_op import _unsafe_set_version_counter return variables.TorchInGraphFunctionVariable( _unsafe_set_version_counter ).call_function(tx, [self.tensors, self.prev_versions], {}) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), @@ -1053,10 +1175,15 @@ def reconstruct(self, codegen: "PyCodegen"): class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs): + def create( + tx: "InstructionTranslator", + param_group_var: Any, + target_value: Any, + **kwargs: Any, + ) -> "FSDPParamGroupUseTrainingStateVariable": var = FSDPParamGroupUseTrainingStateVariable( param_group_var=param_group_var, target_values=[target_value], @@ -1066,7 +1193,11 @@ def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs) return var def __init__( - self, param_group_var, target_values, initial_values=None, **kwargs + self, + param_group_var: Any, + target_values: Sequence[Any], + initial_values: Optional[Sequence[Any]] = None, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs @@ -1074,24 +1205,27 @@ def __init__( self.param_group_var = param_group_var install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): - self._call_func(tx, self.initial_values) + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # type: ignore[arg-type] return variables.ConstantVariable.create(None) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ): - self._call_func(tx, self.initial_values) # undo eager initialization + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # undo eager initialization + self._call_func(tx, self.initial_values) # type: ignore[arg-type] return super().call_function(tx, args, kwargs) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[Any]) -> None: assert len(values) == 1 value = values[0] if self.param_group_var.value._training_state != value: @@ -1106,10 +1240,10 @@ def _call_func(self, tx: "InstructionTranslator", values): ) self.param_group_var.value._training_state = value - def module_name(self): + def module_name(self) -> str: return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup" - def fn_name(self): + def fn_name(self) -> str: return "use_training_state" @@ -1117,7 +1251,12 @@ class SDPAKernelVariable(ContextWrappingVariable): """represents torch.nn.attention.sdpa_kernel""" @staticmethod - def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): + def create( + tx: "InstructionTranslator", + backends: Any, + set_priority: bool = False, + **kwargs: Any, + ) -> "SDPAKernelVariable": if isinstance(backends, torch.nn.attention.SDPBackend): backends = [backends] var = SDPAKernelVariable( @@ -1131,9 +1270,9 @@ def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): def __init__( self, target_values: list[torch.nn.attention.SDPBackend], - initial_values=None, + initial_values: Any = None, set_priority: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs @@ -1141,7 +1280,10 @@ def __init__( self.set_priority = set_priority @staticmethod - def _backends_to_nodes(tx, backends): + def _backends_to_nodes( + tx: "InstructionTranslator", + backends: list[Any], + ) -> list[Any]: # convert to/from string in order to bake the backend into FX graph nodes = [ tx.output.create_node( @@ -1154,7 +1296,7 @@ def _backends_to_nodes(tx, backends): ] return nodes - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends( with_priority=self.set_priority ) @@ -1176,7 +1318,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() arg = self._backends_to_nodes(tx, self.prev_backends) tx.output.create_node( @@ -1187,12 +1331,12 @@ def exit(self, tx: "InstructionTranslator", *args): ) return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.nn.attention" # use a private version of sdpa_kernel that accepts variadic arguments # since dynamo reconstructs the contents of target_values one-by-one - def fn_name(self): + def fn_name(self) -> str: return "_sdpa_kernel_variadic" @@ -1206,12 +1350,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable): __exit__ method (instead of tracing). """ - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, target_values: Any, initial_values: Any = None, **kwargs: Any + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - def enter(self, tx, *args): + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # Run the annotation ctx manager in eager. Also ensure that # preserve_node_meta context manager is setup. This is important to pass # on the metadata to the create_proxy nodes. @@ -1221,13 +1369,13 @@ def enter(self, tx, *args): self.set_cleanup_hook(tx, lambda: stack.close()) return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.fx.traceback" - def fn_name(self): + def fn_name(self) -> str: return "annotate" - def reconstruct_type(self, codegen: "PyCodegen"): + def reconstruct_type(self, codegen: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.fx.traceback.annotate escaped from compiled region", context=str(self), @@ -1243,50 +1391,52 @@ class DynamoConfigPatchVariable(ContextWrappingVariable): # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness # (though it may affect tracing behavior) - def __init__(self, target_values, **kwargs) -> None: - target_values = tuple(target_values.items()) - super().__init__(target_values=(target_values,), initial_values=None, **kwargs) - self.initial_values = {} - for key, _ in target_values: - self.initial_values[key] = torch._dynamo.config.__getattr__(key) - self.initial_values = (tuple(self.initial_values.items()),) - - def _call_func(self, tx: "InstructionTranslator", values): + def __init__(self, target_values: dict[str, Any], **kwargs: Any) -> None: + target_values_tuple = tuple(target_values.items()) + super().__init__( + target_values=(target_values_tuple,), initial_values=None, **kwargs + ) + initial_values_dict = {} + for key, _ in target_values_tuple: + initial_values_dict[key] = torch._dynamo.config.__getattr__(key) # type: ignore[attr-defined] + self.initial_values = (tuple(initial_values_dict.items()),) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: assert len(values) == 1 value = values[0] # manually patch dynamo config for key, val in value: - torch._dynamo.config.__setattr__(key, val) + torch._dynamo.config.__setattr__(key, val) # type: ignore[attr-defined] # No need to keep track of global side effects because # dynamo will properly restore this context manager for # unsupported instructions and continuation functions. # Dynamo config also should not affect the semantics of the compiled graph. - def module_name(self): + def module_name(self) -> str: return "torch._dynamo" - def fn_name(self): + def fn_name(self) -> str: return "patch_dynamo_config" class ErrorOnGraphBreakVariable(ContextWrappingVariable): """represents torch._dynamo.error_on_graph_break""" - def __init__(self, error_on_graph_break, **kwargs) -> None: + def __init__(self, error_on_graph_break: bool, **kwargs: Any) -> None: super().__init__( target_values=(error_on_graph_break,), initial_values=(_get_error_on_graph_break(),), **kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: assert len(values) == 1 _set_error_on_graph_break(values[0]) - def module_name(self): + def module_name(self) -> str: return "torch._dynamo" - def fn_name(self): + def fn_name(self) -> str: return "error_on_graph_break" @@ -1294,7 +1444,7 @@ class WithEnterFunctionVariable(VariableTracker): def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.ctx = ctx @@ -1302,16 +1452,17 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert not args assert not kwargs # NOTE: we assume that the instruction immediately after the current CALL instruction # is the first instruction of the block. + # pyrefly: ignore [bad-argument-type] return tx.enter_ctx(self.ctx, tx.current_instruction) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: try: type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}" except NotImplementedError: @@ -1339,8 +1490,8 @@ class WithExitFunctionVariable(VariableTracker): def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], - target, - **kwargs, + target: Any, + **kwargs: Any, ) -> None: super().__init__(**kwargs) assert isinstance( @@ -1352,27 +1503,29 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert not kwargs return self.ctx.exit(tx, *args) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. - self.ctx.reconstruct_type(codegen) + self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined] if codegen.tx.output.partial_convert: if sys.version_info >= (3, 11): codegen.append_output(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): codegen.append_output(create_instruction("SWAP", arg=2)) + # We rely on classes subtyping `GenericContextWrappingVariable` + # to implement these fns and have these attributes codegen.extend_output( - [codegen.create_load_const(val) for val in self.ctx.target_values] + [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type] ) codegen.extend_output( - create_call_function(len(self.ctx.target_values), False) + create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type] ) codegen.append_output(create_setup_with(self.target)) codegen.append_output(create_instruction("POP_TOP")) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 4f1f84a55b0b..fb212c332622 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Dictionary-related variable tracking classes for PyTorch Dynamo. @@ -26,7 +24,7 @@ import operator import types from collections.abc import Hashable as py_Hashable -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union from torch._subclasses.fake_tensor import is_fake @@ -59,11 +57,13 @@ # - (perhaps) Define how it is compared in _HashableTracker._eq_impl -def was_instancecheck_override(obj): +def was_instancecheck_override(obj: Any) -> bool: return type(obj).__dict__.get("__instancecheck__", False) -def raise_unhashable(arg, tx=None): +def raise_unhashable( + arg: VariableTracker, tx: Optional["InstructionTranslator"] = None +) -> None: if tx is None: from torch._dynamo.symbolic_convert import InstructionTranslator @@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None): ) -def is_hashable(x): +def is_hashable(x: VariableTracker) -> bool: # NB - performing isinstance check on a LazVT realizes the VT, accidentally # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at # the underlying value without realizing the VT. Consider updating the @@ -143,7 +143,7 @@ class _HashableTracker: Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing """ - def __init__(self, vt) -> None: + def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) # TODO Temporarily remove to figure out what keys are we breaking on @@ -153,7 +153,7 @@ def __init__(self, vt) -> None: self.vt = vt @property - def underlying_value(self): + def underlying_value(self) -> Any: if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() @@ -178,7 +178,8 @@ def underlying_value(self): elif isinstance(self.vt, variables.FrozenDataClassVariable): Hashable = ConstDictVariable._HashableTracker fields_values = { - k: Hashable(v).underlying_value for k, v in self.vt.fields.items() + k: Hashable(v).underlying_value + for k, v in self.vt.fields.items() # type: ignore[attr-defined] } return variables.FrozenDataClassVariable.HashWrapper( self.vt.python_type(), fields_values @@ -187,16 +188,16 @@ def underlying_value(self): # The re module in Python 3.13+ has a dictionary (_cache2) with # an object as key (`class _ZeroSentinel(int): ...`): # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual - return self.vt.value + return self.vt.value # type: ignore[attr-defined,union-attr] else: x = self.vt.as_python_constant() return x - def __hash__(self): + def __hash__(self) -> int: return hash(self.underlying_value) @staticmethod - def _eq_impl(a, b): + def _eq_impl(a: Any, b: Any) -> bool: # TODO: Put this in utils and share it between variables/builtin.py and here type_a, type_b = type(a), type(b) if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): @@ -212,7 +213,7 @@ def _eq_impl(a, b): else: return a == b - def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: + def __eq__(self, other: object) -> bool: Hashable = ConstDictVariable._HashableTracker assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( type(other) @@ -226,8 +227,8 @@ def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: def __init__( self, items: dict[VariableTracker, VariableTracker], - user_cls=dict, - **kwargs, + user_cls: type = dict, + **kwargs: Any, ) -> None: # .clone() pass these arguments in kwargs but they're recreated a few # lines below @@ -247,18 +248,22 @@ def __init__( for x, v in items.items() ) - def make_hashable(key): + def make_hashable( + key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], + ) -> "ConstDictVariable._HashableTracker": return key if isinstance(key, Hashable) else Hashable(key) dict_cls = self._get_dict_cls_from_user_cls(user_cls) self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) # need to reconstruct everything if the dictionary is an intermediate value # or if a pop/delitem was executed - self.should_reconstruct_all = not is_from_local_source(self.source) + self.should_reconstruct_all = ( + not is_from_local_source(self.source) if self.source else True + ) self.original_items = items.copy() self.user_cls = user_cls - def _get_dict_cls_from_user_cls(self, user_cls): + def _get_dict_cls_from_user_cls(self, user_cls: type) -> type: accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) # avoid executing user code if user_cls is a dict subclass @@ -277,10 +282,10 @@ def _get_dict_cls_from_user_cls(self, user_cls): dict_cls = dict return dict_cls - def as_proxy(self): + def as_proxy(self) -> dict[Any, Any]: return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} - def debug_repr(self): + def debug_repr(self) -> str: return ( "{" + ", ".join( @@ -289,20 +294,20 @@ def debug_repr(self): + "}" ) - def as_python_constant(self): + def as_python_constant(self) -> dict[Any, Any]: return { k.vt.as_python_constant(): v.as_python_constant() for k, v in self.items.items() } - def keys_as_python_constant(self): + def keys_as_python_constant(self) -> dict[Any, VariableTracker]: self.install_dict_keys_match_guard() return {k.vt.as_python_constant(): v for k, v in self.items.items()} - def python_type(self): + def python_type(self) -> type: return self.user_cls - def __contains__(self, vt) -> bool: + def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( @@ -322,13 +327,15 @@ def has_new_items(self) -> bool: for key, value in self.items.items() ) - def is_new_item(self, value, other): + def is_new_item( + self, value: Optional[VariableTracker], other: VariableTracker + ) -> bool: # compare the id of the realized values if both values are not lazy VTs if value and value.is_realized() and other.is_realized(): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct_kvs_into_new_dict(self, codegen): + def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None: # Build a dictionary that contains the keys and values. num_args = 0 for key, value in self.items.items(): @@ -340,7 +347,7 @@ def reconstruct_kvs_into_new_dict(self, codegen): num_args += 1 codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: if self.user_cls is collections.OrderedDict: # emit `OrderedDict(constructed_dict)` codegen.add_push_null( @@ -358,19 +365,21 @@ def reconstruct(self, codegen: "PyCodegen"): def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker - ): + ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: raise_observed_exception(KeyError, tx) return self.items[key] - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: - msg = f"Dictionary key {arg.value} not found during tracing" + msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] unimplemented_v2( gb_type="key not found in dict", - context=f"Key {arg.value}", + context=f"Key {arg.value}", # type: ignore[attr-defined] explanation=msg, hints=[ "Check if the key exists in the dictionary before accessing it.", @@ -379,13 +388,13 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): ) return self.items[key] - def maybe_getitem_const(self, arg: VariableTracker): + def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: return None return self.items[key] - def realize_key_vt(self, arg: VariableTracker): + def realize_key_vt(self, arg: VariableTracker) -> None: # Realize the LazyVT on a particular index assert arg in self key = ConstDictVariable._HashableTracker(arg) @@ -394,11 +403,13 @@ def realize_key_vt(self, arg: VariableTracker): if isinstance(original_key_vt, variables.LazyVariableTracker): original_key_vt.realize() - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: if self.source: install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: # Key guarding - These are the cases to consider # 1) The dict has been mutated. In this case, we would have already # inserted a DICT_KEYS_MATCH guard, so we can skip. @@ -439,11 +450,11 @@ def install_dict_contains_guard(self, tx, args): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # NB - Both key and value are LazyVariableTrackers in the beginning. So, # we have to insert guards when a dict method is accessed. For this to # be simple, we are conservative and overguard. We skip guard only for @@ -462,7 +473,7 @@ def call_method( tx, *args, **kwargs ) tx.output.side_effects.mutation(self) - self.items.update(temp_dict_vt.items) + self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "__getitem__": # Key guarding - Nothing to do. LazyVT for value will take care. @@ -526,7 +537,7 @@ def call_method( return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) != 2: @@ -550,7 +561,7 @@ def call_method( raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) if args[0] not in self: self.install_dict_contains_guard(tx, args) @@ -565,7 +576,7 @@ def call_method( raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) if args[0] not in self: # missing item, return the default value. Install no DICT_CONTAINS guard. @@ -599,7 +610,7 @@ def call_method( last = v.value else: raise_args_mismatch(tx, name) - k, v = self.items.popitem(last=last) + k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] else: k, v = self.items.popitem() @@ -632,17 +643,17 @@ def call_method( # NB - Guard on all the keys of the other dict to ensure # correctness. args[0].install_dict_keys_match_guard() - dict_vt = args[0] + dict_vt: ConstDictVariable = args[0] else: - dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) - self.items.update(dict_vt.items) + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] + self.items.update(dict_vt.items) # type: ignore[attr-defined] if has_kwargs: # Handle kwargs - kwargs = { + kwargs_hashable = { Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() } - self.items.update(kwargs) + self.items.update(kwargs_hashable) return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) @@ -656,7 +667,7 @@ def call_method( ) if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_contains_guard(tx, args) contains = args[0] in self @@ -671,7 +682,7 @@ def call_method( ) if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) > 2: @@ -707,7 +718,7 @@ def call_method( and "last" in kwargs and isinstance(kwargs["last"], ConstantVariable) ): - last = kwargs.get("last").value + last = kwargs.get("last").value # type: ignore[union-attr] key = Hashable(args[0]) self.items.move_to_end(key, last=last) @@ -723,7 +734,7 @@ def call_method( ) elif name == "__ne__": return ConstantVariable.create( - not self.call_method(tx, "__eq__", args, kwargs).value + not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined] ) elif name == "__or__": if len(args) != 1: @@ -750,14 +761,14 @@ def call_method( if not istype( other, (ConstDictVariable, variables.UserDefinedDictVariable) ): - msg = ( + err_msg = ( f"unsupported operand type(s) for |: '{self.python_type().__name__}'" f"and '{other.python_type().__name__}'" ) - raise_observed_exception(TypeError, tx, args=[msg]) + raise_observed_exception(TypeError, tx, args=[err_msg]) # OrderedDict overloads __ror__ - ts = {self.user_cls, other.user_cls} + ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined] user_cls = ( collections.OrderedDict if any(issubclass(t, collections.OrderedDict) for t in ts) @@ -774,8 +785,8 @@ def call_method( # NB - Guard on all the keys of the other dict to ensure # correctness. - args[0].install_dict_keys_match_guard() - new_dict_vt.items.update(args[0].items) + args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] + new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] return new_dict_vt elif name == "__ior__": self.call_method(tx, "update", args, kwargs) @@ -789,11 +800,13 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: self.install_dict_keys_match_guard() return [x.vt for x in self.items.keys()] - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: # dict not allow setting arbitrary attributes. OrderedDict and # defaultdict allow arbitrary setattr, but not deletion of default attrs if any( @@ -816,25 +829,25 @@ def call_obj_hasattr(self, tx, name): ], ) - def clone(self, **kwargs): + def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt - def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict - def python_type(self): + def python_type(self) -> type: return types.MappingProxyType - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.dv_dict.unpack_var_sequence(tx) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: # load types.MappingProxyType if self.source: msg = ( @@ -863,11 +876,11 @@ def reconstruct(self, codegen: "PyCodegen"): def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.source and tx.output.side_effects.has_existing_dict_mutation(): msg = ( "A dict has been modified while we have an existing mappingproxy object. " @@ -892,7 +905,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is types.MappingProxyType: return ConstantVariable.create(name in types.MappingProxyType.__dict__) return super().call_obj_hasattr(tx, name) @@ -900,45 +913,62 @@ def call_obj_hasattr( class NNModuleHooksDictVariable(ConstDictVariable): # Special class to avoid adding any guards on the nn module hook ids. - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: pass class DefaultDictVariable(ConstDictVariable): - def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type, + default_factory: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict + if default_factory is None: + default_factory = ConstantVariable.create(None) self.default_factory = default_factory - def is_python_constant(self): + def is_python_constant(self) -> bool: # Return false for unsupported defaults. This ensures that a bad handler # path is not taken in BuiltinVariable for getitem. if self.default_factory not in [list, tuple, dict] and not self.items: return False return super().is_python_constant() - def debug_repr(self): + def debug_repr(self) -> str: + assert self.default_factory is not None return ( f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" ) @staticmethod - def is_supported_arg(arg): + def is_supported_arg(arg: VariableTracker) -> bool: if isinstance(arg, variables.BuiltinVariable): return arg.fn in (list, tuple, dict, set) else: - return isinstance(arg, variables.functions.BaseUserFunctionVariable) + return isinstance( + arg, + ( + variables.functions.BaseUserFunctionVariable, + variables.functions.PolyfilledFunctionVariable, + ), + ) def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") @@ -946,18 +976,21 @@ def call_method( if args[0] in self: return self.getitem_const(tx, args[0]) else: - if self.default_factory is None: - raise KeyError(f"{args[0]}") + if ( + istype(self.default_factory, ConstantVariable) + and self.default_factory.value is None + ): + raise_observed_exception(KeyError, tx, args=[args[0]]) else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( - tx, "__setitem__", (args[0], default_var), kwargs + tx, "__setitem__", [args[0], default_var], kwargs ) return default_var else: return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: # emit `defaultdict(default_factory, new_dict)` codegen.add_push_null( lambda: codegen.extend_output( @@ -983,40 +1016,48 @@ class SetVariable(ConstDictVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: + # pyrefly: ignore[bad-assignment] items = dict.fromkeys(items, SetVariable._default_value()) + # pyrefly: ignore[bad-argument-type] super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "set()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" @property - def set_items(self): + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return set(self.items.keys()) @staticmethod - def _default_value(): + def _default_value() -> VariableTracker: # Variable to fill in he keys of the dictionary return ConstantVariable.create(None) - def as_proxy(self): + def as_proxy(self) -> Any: return {k.vt.as_proxy() for k in self.set_items} - def python_type(self): + def python_type(self) -> type: return set - def as_python_constant(self): + def as_python_constant(self) -> Any: return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) - def _fast_set_method(self, tx, fn, args, kwargs): + def _fast_set_method( + self, + tx: "InstructionTranslator", + fn: Any, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: try: res = fn( *[x.as_python_constant() for x in [self, *args]], @@ -1026,15 +1067,16 @@ def _fast_set_method(self, tx, fn, args, kwargs): raise_observed_exception( type(exc), tx, args=list(map(ConstantVariable.create, exc.args)) ) + # pyrefly: ignore[unbound-name] return VariableTracker.build(tx, res) def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: # We forward the calls to the dictionary model from ..utils import check_constant_args @@ -1054,10 +1096,10 @@ def call_method( return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) if name == "__init__": - temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs) + temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs) tx.output.side_effects.mutation(self) self.items.clear() - self.items.update(temp_set_vt.items) + self.items.update(temp_set_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "add": if kwargs or len(args) != 1: @@ -1068,7 +1110,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) name = "__setitem__" - args = (args[0], SetVariable._default_value()) + args = [args[0], SetVariable._default_value()] elif name == "pop": if kwargs or args: raise_args_mismatch( @@ -1079,12 +1121,14 @@ def call_method( ) # Choose an item at random and pop it via the Dict.pop method try: - result = self.set_items.pop().vt + result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] except KeyError as e: raise_observed_exception( KeyError, tx, args=list(map(ConstantVariable.create, e.args)) ) - super().call_method(tx, name, (result,), kwargs) + # pyrefly: ignore[unbound-name] + super().call_method(tx, name, [result], kwargs) + # pyrefly: ignore[unbound-name] return result elif name == "isdisjoint": if kwargs or len(args) != 1: @@ -1206,6 +1250,7 @@ def call_method( f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" ) raise_observed_exception(TypeError, tx, args=[msg]) + assert m is not None return self.call_method(tx, m, args, kwargs) elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): @@ -1219,29 +1264,34 @@ def call_method( "__ixor__": "symmetric_difference_update", "__isub__": "difference_update", }.get(name) + assert m is not None self.call_method(tx, m, args, kwargs) return self elif name == "__eq__": if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(False) r = self.call_method(tx, "symmetric_difference", args, kwargs) - return ConstantVariable.create(len(r.set_items) == 0) + return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined] elif name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: raise RuntimeError("Illegal to getitem on a set") - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: super().install_dict_contains_guard(tx, args) @@ -1249,27 +1299,27 @@ class FrozensetVariable(SetVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "frozenset()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" @property - def set_items(self): + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return self.items.keys() - def python_type(self): + def python_type(self) -> type: return frozenset - def as_python_constant(self): + def as_python_constant(self) -> Any: return frozenset({k.vt.as_python_constant() for k in self.set_items}) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( @@ -1282,11 +1332,11 @@ def reconstruct(self, codegen: "PyCodegen"): def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a frozenset") elif name == "__init__": @@ -1305,7 +1355,7 @@ def call_method( "symmetric_difference", ): r = super().call_method(tx, name, args, kwargs) - return FrozensetVariable(r.items) + return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) @@ -1313,11 +1363,11 @@ class DictKeySetVariable(SetVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "dict_keys([])" else: @@ -1327,33 +1377,35 @@ def debug_repr(self): + "])" ) - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: # Already EQUALS_MATCH guarded pass @property - def set_items(self): + def set_items(self) -> Any: return self.items - def python_type(self): + def python_type(self) -> type: return dict_keys - def as_python_constant(self): + def as_python_constant(self) -> Any: return dict.fromkeys( {k.vt.as_python_constant() for k in self.set_items}, None ).keys() def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a dict_keys") return super().call_method(tx, name, args, kwargs) @@ -1368,42 +1420,47 @@ class DictViewVariable(VariableTracker): kv: Optional[str] = None - def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert self.kv in ("keys", "values", "items") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @property - def view_items(self): + def view_items(self) -> Any: + assert self.kv is not None return getattr(self.dv_dict.items, self.kv)() @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items # Implement in the subclasses raise NotImplementedError - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.view_items_vt - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.kv is not None codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + assert self.kv is not None if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name == "__iter__": @@ -1417,24 +1474,24 @@ class DictKeysVariable(DictViewVariable): kv = "keys" @property - def set_items(self): + def set_items(self) -> set[VariableTracker]: return set(self.view_items) @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [x.vt for x in self.view_items] - def python_type(self): + def python_type(self) -> type: return dict_keys def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name in ( @@ -1449,13 +1506,13 @@ def call_method( ): # These methods always returns a set m = getattr(self.set_items, name) - r = m(args[0].set_items) + r = m(args[0].set_items) # type: ignore[attr-defined] return SetVariable(r) if name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, DictKeysVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) @@ -1465,10 +1522,10 @@ class DictValuesVariable(DictViewVariable): kv = "values" @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: return list(self.view_items) - def python_type(self): + def python_type(self) -> type: return dict_values @@ -1476,14 +1533,20 @@ class DictItemsVariable(DictViewVariable): kv = "items" @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] - def python_type(self): + def python_type(self) -> type: return dict_items - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # TODO(guilhermeleobas): This should actually check if args[0] # implements the mapping protocol. if name == "__eq__": diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index eb39dd8fa3e0..187055c26cd0 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -20,7 +20,8 @@ import functools import inspect -from typing import Any, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 5970ba0e1dda..be765cbbc8bf 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable -from typing import Any, Sequence, TYPE_CHECKING, Union +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 289cebbe8129..c09cc2163a5f 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,7 +22,8 @@ import logging import weakref -from typing import Any, Iterable, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 85977104977f..644c269a23a3 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable -from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from collections.abc import Callable, Iterable +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 75928842cf29..629bf094dc95 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from inspect import getattr_static -from typing import Any, Sequence, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fbc0eed3a99f..65b4add4232f 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,14 +1,16 @@ import collections -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable from torch._dynamo.variables.lists import TupleVariable -from torch.fx import Proxy +from torch.fx import has_side_effect, Proxy from .. import graph_break_hints from ..bytecode_transformation import create_call_function from ..exc import TYPE_CHECKING, unimplemented_v2 +from ..graph_bytecode_inputs import get_external_object_by_index from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import FxTracebackAnnotateVariable @@ -26,46 +28,93 @@ Tensor = torch.Tensor +def _get_stream_by_index(index: int) -> torch.Stream: + stream = get_external_object_by_index(index) + assert isinstance(stream, torch.Stream), ( + f"Fork/join stream expected a stream object at index {index}" + ) + return stream + + +def _get_event_by_index(index: int) -> torch.Event: + event = get_external_object_by_index(index) + assert isinstance(event, torch.Event), ( + f"Record/wait event expected an event object at index {index}" + ) + return event + + @custom_op("streams::fork", mutates_args=()) def fork_stream( - from_index: int, - from_device: torch.device, + from_index: int, # kept to make stream transitions clearer to_index: int, - to_device: torch.device, ) -> None: - pass + torch.accelerator.set_stream(_get_stream_by_index(to_index)) @fork_stream.register_fake def _( - from_index: int, - from_device: torch.device, + from_index: int, # kept to make stream transitions clearer to_index: int, - to_device: torch.device, ) -> None: pass +has_side_effect(torch.ops.streams.fork.default) + + @custom_op("streams::join", mutates_args=()) -def join_stream( +def join_stream(from_index: int, to_index: int) -> None: + torch.accelerator.set_stream(_get_stream_by_index(to_index)) + + +@join_stream.register_fake +def _( from_index: int, - from_device: torch.device, to_index: int, - to_device: torch.device, ) -> None: pass -@join_stream.register_fake +has_side_effect(torch.ops.streams.join.default) + + +@custom_op("streams::record_event", mutates_args=()) +def record_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.record_event(event) + + +@record_event.register_fake def _( - from_index: int, - from_device: torch.device, - to_index: int, - to_device: torch.device, + event_index: int, + stream_index: int, +) -> None: + pass + + +has_side_effect(torch.ops.streams.record_event.default) + + +@custom_op("streams::wait_event", mutates_args=()) +def wait_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.wait_event(event) + + +@wait_event.register_fake +def _( + event_index: int, + stream_index: int, ) -> None: pass +has_side_effect(torch.ops.streams.wait_event.default) + + class SymbolicStreamState: """Track the currently entered stream if any""" @@ -116,11 +165,7 @@ def create( **kwargs, ) - def __init__( - self, - stream: Optional["StreamVariable"], - **kwargs: dict[str, Any], - ) -> None: + def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None: self.stream = stream super().__init__( target_values={"stream": self.get_stream().user_object_index}, @@ -129,14 +174,16 @@ def __init__( ) def enter( - self, tx: "InstructionTranslator", *args: tuple[Any] - ) -> "VariableTracker": + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # to stream, from stream is the order of the arguments # we are entering the target, and leaving the initial stream tx.symbolic_stream_state.enter_stream(self.get_stream()) return super().enter(tx) - def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker": + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # to stream, from stream is the order of the arguments # we are leaving the target, and entering the initial stream tx.symbolic_stream_state.exit_stream() @@ -182,7 +229,7 @@ def call_method( name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: assert hasattr(self.value, name), f"no stream method found named {name}" from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c2e3df8e4adc..be28fe9269f4 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -408,6 +408,7 @@ def call_function( torch.cuda.amp.autocast, torch.cpu.amp.autocast, ): + # pyrefly: ignore [bad-argument-type] return AutocastModeVariable.create(self.value, args, kwargs) elif self.value in ( # NOTE any class added here must align with the semantic diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 378e9258459f..4d0f0b4fae8a 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Sequence +from collections.abc import Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree @@ -164,7 +164,8 @@ def __init__( if value is not None: super().__init__(value, **kwargs) self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code + # needed for BC with calling enter from CM code + self.cm_obj = value # type: ignore[assignment] self.source = source # type: ignore[assignment] def reconstruct(self, codegen: "PyCodegen") -> None: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 707ad7b3d9d1..9dd154dacbb9 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -419,9 +419,7 @@ def call_method( self.value in {collections.OrderedDict, collections.defaultdict} and name == "fromkeys" ): - from .builtin import BuiltinVariable - - return BuiltinVariable.call_custom_dict_fromkeys( + return variables.BuiltinVariable.call_custom_dict_fromkeys( tx, self.value, *args, **kwargs ) elif self.value is collections.OrderedDict and name == "move_to_end": @@ -501,15 +499,18 @@ def call_function( [self, *args], kwargs, ) - elif ( - self.value is collections.defaultdict - and len(args) <= 1 - and DefaultDictVariable.is_supported_arg(args[0]) - ): + elif self.value is collections.defaultdict: + if len(args) == 0: + default_factory = variables.ConstantVariable.create(None) + else: + default_factory, *args = args + dict_vt = variables.BuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs + ) return DefaultDictVariable( - {}, + dict_vt.items, collections.defaultdict, - args[0], + default_factory, mutation_type=ValueMutationNew(), ) elif is_typeddict(self.value): @@ -968,6 +969,12 @@ def __init__( # rid of these workarounds here and in `GetAttrVariable`. self.attrs_directly_modifed_on_dict = set() + import torch.utils._pytree as pytree + + self.is_pytree_constant_class = pytree.is_constant_class(self.value_type) + if pytree.is_constant_class(self.value_type) and self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + def __str__(self) -> str: inner = self.value_type.__name__ if inner in [ @@ -989,12 +996,10 @@ def python_type(self): return self.value_type def as_python_constant(self): - import torch.utils._pytree as pytree - - if pytree.is_constant_class(self.value_type): - if self.source is not None: - install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) - return self.value + if self.is_pytree_constant_class and self.source: + # NOTE pytree constants created in the torch.compile region will + # NOT be guarded (even though they have a source set) + return self.value # TODO else try reconstructing the object by, e.g., leveraging side # effects and `as_python_constant`. return super().as_python_constant() diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 89b6e3297933..58de4fd20c95 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -443,7 +443,7 @@ def __init__( self.blocks_to_lifted_attrs = blocks_to_lifted_attrs # Populate methods for the standard operators. - for k in kind_to_standard_operators.keys(): + for k in kind_to_standard_operators: handler_func_name = ir_name_to_func_name(k) # Create an indirect function call: # convert__ --> lambda node: _convert_standard_operator(node) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 9c4629f13337..e328422ec5e6 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -617,7 +617,7 @@ def get_triton_kernel_and_cache_entry(node: torch.fx.Node): return actual_kernel, matching_entries[0][1] if is_autotuner: - for sig_key, cache_entry in matching_entries: + for _sig_key, cache_entry in matching_entries: entry_metadata = cache_entry.metadata # pyrefly: ignore [missing-attribute] for config in kernel.configs: diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index ce01e37f0324..7e608933b34c 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,9 +22,10 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 60ee3bc2973b..b11eb87dc172 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -25,6 +25,9 @@ if TYPE_CHECKING: from collections.abc import Sequence +import threading +from contextlib import contextmanager + import torch import torch.utils._pytree as pytree import torch.utils.dlpack @@ -97,6 +100,43 @@ ) +_thread_local = threading.local() + + +# Saved tensor hooks context +# Compiled saved tensor hooks are convenient way to inline some logic in the graphs +# for saved nodes from forward to backward. (E.g. activations quantization) +# In base implementation user does not have any additional information about saved value +# in the hook, except FakeTensor shape, dtype, device etc. +# _get_saved_tensor_hook_context gives additional graph information about that saved value, +# that can be used to make a decisions which pack/unpack to apply for particular saved value. +# This allows user to reuse saved tensors hooks api to apply selective pack/unpack in +# graph aware way. +# Alternative to this will be making user to write a custom pass that mucks with forward outputs, +# backward input metadata, which requires significantly more effort. +# +# As for now in context we expose forward graph, backward graph and current saved node, +# which contains node.meta with additional information about that fx.Node. +# Warning: This API may change without backward compatibility. +@contextmanager +def _saved_tensor_hook_context(state: dict[str, Any]): + previous_state = getattr(_thread_local, "state", None) + try: + _thread_local.state = state + yield + finally: + # Clean up: restore previous state or remove attribute + if previous_state is not None: + _thread_local.state = previous_state + else: + if hasattr(_thread_local, "state"): + delattr(_thread_local, "state") + + +def _get_saved_tensor_hook_context() -> dict[str, Any] | None: + return getattr(_thread_local, "state", None) + + zip = strict_zip log = logging.getLogger(__name__) @@ -1097,7 +1137,11 @@ def _gen_unused_name(candidate: str): if not isinstance(val, torch.Tensor): continue - pack_out_val = pack_hook_gm(val) + def _get_extra_info() -> dict[str, Any]: + return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved} + + with _saved_tensor_hook_context(_get_extra_info()): + pack_out_val = pack_hook_gm(val) requires_sc_handling = any( is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) @@ -1109,16 +1153,17 @@ def _gen_unused_name(candidate: str): " in the pack hook, and reconstructing the subclass in the unpack hook" ) - pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) - pack_g = pack_gm.graph - maybe_log_graph( - pack_gm, - f"saved_tensors_pack_hook {saved.name}", - aot_config, - lambda: f"aot_saved_tensors_hooks_pack {saved.name}", - structured_logs, - ) - pack_out_val = pack_gm(val) + with _saved_tensor_hook_context(_get_extra_info()): + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) # Install pack hook graph as eiplogue of fw_module. # Saved tensor output becomes input of pack hook graph. @@ -1188,15 +1233,16 @@ def _gen_unused_name(candidate: str): # Install unpack hook graph as a prologue of backward graph # Saved tensors inputs are replaced with packed tensors and packed sym scalars. # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. - unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) - unpack_g = unpack_gm.graph - maybe_log_graph( - unpack_gm, - f"saved_tensors_unpack_hook {saved.name}", - aot_config, - lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", - structured_logs, - ) + with _saved_tensor_hook_context(_get_extra_info()): + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) def find_saved_in_bw_inputs(bw_inputs): for n in bw_inputs: diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 4846f1ca74ed..86202e2cd319 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -2365,8 +2365,6 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): - from torch._inductor.async_compile import async_compile_pool_manager - # compiled autograd reimplements this function at proxy_call_aot_backward assert not backward_state_indices, ( "BackwardState requires CompiledAutograd" @@ -2446,7 +2444,6 @@ def _backward_impl(ctx, all_args): with ( tracing(saved_context), compile_context(saved_compile_context), - async_compile_pool_manager(), context(), track_graph_compiling(aot_config, "backward"), metrics_context, diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 844f34bb576d..9fbb5e5fe984 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -3,6 +3,7 @@ Contains various utils for AOTAutograd, including those for handling collections. """ +import copy import dataclasses import logging import operator @@ -459,7 +460,9 @@ def _copy_metadata_to_bw_nodes_in_subgraph( node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack") node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") # TODO: better to change to a specific field of custom? - node.meta["custom"] = fwd_node.meta.get("custom") + custom = fwd_node.meta.get("custom") + if custom is not None: + node.meta["custom"] = copy.deepcopy(custom) def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None: diff --git a/torch/_guards.py b/torch/_guards.py index bac59965a3ae..32b796d71eea 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -145,6 +145,7 @@ class GuardSource(enum.Enum): GLOBAL_UNSPECIALIZED_NN_MODULE = 13 LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14 GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15 + TEMP_LOCAL = 16 def is_fsdp_module(self) -> bool: return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) @@ -903,7 +904,7 @@ def patch(**kwargs: Any) -> Generator[None, None, None]: prior = {} ctx = TracingContext.get() - for key in kwargs.keys(): + for key in kwargs: # KeyError on invalid entry prior[key] = getattr(ctx, key) for key, val in kwargs.items(): diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 516d58bdf314..452a080570eb 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -24,6 +24,7 @@ from torch._higher_order_ops.local_map import local_map_hop from torch._higher_order_ops.map import map from torch._higher_order_ops.out_dtype import out_dtype +from torch._higher_order_ops.print import print from torch._higher_order_ops.run_const_graph import run_const_graph from torch._higher_order_ops.scan import scan from torch._higher_order_ops.strict_mode import strict_mode @@ -75,4 +76,5 @@ "map", "while_loop_stack_output", "local_map_hop", + "print", ] diff --git a/torch/_higher_order_ops/print.py b/torch/_higher_order_ops/print.py new file mode 100644 index 000000000000..5a14ef23aa24 --- /dev/null +++ b/torch/_higher_order_ops/print.py @@ -0,0 +1,44 @@ +import builtins + +import torch +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator + + +class Print(HigherOrderOperator): + """ + print(format_str, **kwargs) -> None + + This Higher Order Operator (HOP) provides a functional version of print for use in PyTorch graphs. + It enables format printing with named arguments, e.g., torch._higher_order_ops.print("moo {x} {y}", x=1, y=2). + + This HOP enables printing without causing graph break. + """ + + def __init__(self) -> None: + super().__init__("print") + + def __call__(self, format_str: str, **kwargs: object) -> object: + assert isinstance(format_str, str) + return super().__call__(format_str, **kwargs) + + +print = Print() + + +@print.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) +# pyre-ignore +def print_cpu(format_str: str, **kwargs: object) -> None: + # Ensure all immutable_dict/list in kwargs are converted to regular dict/list + map_types: dict[type, type] = { + torch.fx.immutable_collections.immutable_dict: dict, + torch.fx.immutable_collections.immutable_list: list, + } + new_kwargs = pytree.tree_map_only( + tuple(map_types.keys()), + lambda a: map_types[type(a)](a), + kwargs, + lambda a: isinstance(a, tuple(map_types.keys())), + ) + # Use built-in print to avoid recursion with the HOP print + builtins.print(format_str.format(**new_kwargs)) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 8ffab3769942..0e398897a7ea 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -498,6 +498,7 @@ def get_signature_value(idx: int, arg: Any) -> str: # pyrefly: ignore # missing-attribute codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() + # pyrefly: ignore[missing-argument,bad-argument-type] ttir_module = src.make_ir(options, codegen_fns, module_map, context) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index ac0d60bdebd7..5ede0cd08501 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -2,7 +2,6 @@ from __future__ import annotations import atexit -import contextlib import functools import json import logging @@ -230,18 +229,6 @@ def remove_future(kernel_src: str) -> None: del CompiledTritonKernels._cache[key] -@contextlib.contextmanager -def async_compile_pool_manager(): - """ - Context manager to quiesce the subproc pool at the end of compilation, i.e., - when dynamo is done. - """ - try: - yield - finally: - AsyncCompile.quiesce() - - class AsyncCompile: """ Utilities to compile in thread pools or subprocess pools (in the case of Triton). @@ -277,7 +264,9 @@ def process_pool() -> AnyPool: pool: AnyPool if config.worker_start_method == "subprocess": # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(get_compile_threads()) + pool = SubprocPool( + get_compile_threads(), quiesce=config.quiesce_async_compile_pool + ) else: if config.worker_start_method == "spawn": # Avoid creating pools in the spawned subprocs themselves: @@ -333,20 +322,6 @@ def use_process_pool(cls): cls._ready_future = cls.process_pool().submit(cls._get_ready) return cls._ready_future.done() - @classmethod - def quiesce(cls) -> None: - """ - If using a SubprocPool, signal the sidecar process to shut down its - ProcessPoolExecutor. - """ - # Don't inadvertently create a process pool if it doesn't already exist: - if not cls.process_pool.cache_info().currsize: - return - if config.quiesce_async_compile_pool: - pool = cls.process_pool() - if isinstance(pool, SubprocPool): - pool.quiesce() - @classmethod def wakeup(cls) -> None: """ @@ -626,6 +601,42 @@ def task(): future = self.submit(task) return LambdaFuture(lambda: future.result()) + def pallas(self, kernel_name: str, source_code: str): + """ + Compile Pallas (JAX experimental) kernels. + + Args: + kernel_name: Name of the kernel to be defined + source_code: Source code of the Pallas kernel, as a string + + Note: + Pallas kernels are Python code that uses JAX and Pallas APIs. + We use the PyCodeCache to write the source code to a file and load it. + """ + from torch._inductor.codegen.pallas import MAIN_SUFFIX, PallasKernelWrapper + + kernel_code_log.info("Pallas Kernel:\n%s", source_code) + + def task(): + key, path = torch._inductor.codecache.PyCodeCache.write(source_code) + mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) + + # Find our special entry point named function + main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find Pallas main kernel function '{main_func_name}'. Available callables: {available}" + ) + + return PallasKernelWrapper(getattr(mod, main_func_name), kernel_path=path) + + if get_compile_threads() <= 1: + return task() + else: + future = self.submit(task) + return LambdaFuture(lambda: future.result()) + def wait(self, scope: dict[str, Any]) -> None: if get_compile_threads() > 1: with dynamo_timed( diff --git a/torch/_inductor/augmented_graph_helper.py b/torch/_inductor/augmented_graph_helper.py index 81dca605940e..5a70a34f7b64 100644 --- a/torch/_inductor/augmented_graph_helper.py +++ b/torch/_inductor/augmented_graph_helper.py @@ -164,7 +164,7 @@ def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> No self.extra_uses[new_node].add(updated_use) # Clean up erased nodes - for old_node in erased_merge_sets.keys(): + for old_node in erased_merge_sets: self.extra_deps[old_node].clear() self.extra_uses[old_node].clear() del self.merge_sets[old_node] diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index a227239356a6..bc8dba511925 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -86,7 +86,7 @@ def swap_submodules( self, submodules: dict[str, Callable[..., Any]] ) -> dict[str, Callable[..., ValueRanges[Expr]]]: result: dict[str, Callable[..., ValueRanges[Expr]]] = {} - for key in submodules.keys(): + for key in submodules: if key == "get_index": result[key] = self.get_index elif "masked_subblock" in key: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index cf17bf2e9478..958349429926 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1681,7 +1681,7 @@ def set( if config.aot_inductor.emit_multi_arch_kernel: bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"} - assert bin_type in bin_type_to_ext.keys(), ( + assert bin_type in bin_type_to_ext, ( "multi_arch_kernel_binary only supported in CUDA/XPU" ) base_path, _ = os.path.splitext(bin_path) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e6a5c5e8ec17..730c03f1c813 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -510,6 +510,7 @@ def init_backend_registration() -> None: from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .mps import MetalScheduling + from .pallas import PallasScheduling from .python_wrapper_mtia import PythonWrapperMtia from .triton import TritonScheduling from .wrapper import PythonWrapperCodegen @@ -536,6 +537,7 @@ def init_backend_registration() -> None: cuda_backends = { "triton": CUDACombinedScheduling, "halide": HalideScheduling, + "pallas": PallasScheduling, } register_backend_for_device( "cuda", diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 02129fff2416..fad4ce84f297 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -337,7 +337,7 @@ def process_args_for_input_shape(arg, arg_type, arg_signature=None): elif ( isinstance(arg_type, type(SymbolicCallArg)) and arg_signature is not None - and arg_signature in signature2dtype.keys() + and arg_signature in signature2dtype ) or arg_type in (sympy.Integer, int, sympy.Float, float): write_dummy_scalar_ivalue(arg_name) elif arg_signature and arg_signature.startswith("tensordesc<"): @@ -719,7 +719,7 @@ def process_args(arg, arg_type, arg_signature=None): elif ( isinstance(arg_type, type(SymbolicCallArg)) and arg_signature is not None - and arg_signature in signature2dtype.keys() + and arg_signature in signature2dtype ): code.writeline( f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py new file mode 100644 index 000000000000..da437a4e8ee3 --- /dev/null +++ b/torch/_inductor/codegen/pallas.py @@ -0,0 +1,426 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import hashlib +from typing import Any, Optional, TYPE_CHECKING + +import sympy # noqa: TC002 + +import torch # noqa: TC001 +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..utils import get_fused_kernel_name, get_kernel_metadata +from ..virtualized import V +from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides +from .simd import SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from ..ir import IRNode + from ..scheduler import BaseSchedulerNode + + +# Main function suffix used in generated Pallas code +MAIN_SUFFIX = "main" + +# Logger for Pallas kernel code +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class PallasKernelWrapper: + """Wrapper to provide .run() interface for Pallas kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("Pallas kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the Pallas kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +class Unsupported(RuntimeError): + """Exception raised when an operation is not supported by the Pallas backend.""" + + +class PallasKernelOverrides(OpOverrides): + """ + Map element-wise ops to JAX/Pallas operations. + + For now, we use the default Python operators which are compatible + with JAX numpy broadcasting semantics. + """ + + @staticmethod + def sin(x: str) -> str: + return f"jnp.sin({x})" + + @staticmethod + def cos(x: str) -> str: + return f"jnp.cos({x})" + + @staticmethod + def tan(x: str) -> str: + return f"jnp.tan({x})" + + @staticmethod + def sinh(x: str) -> str: + return f"jnp.sinh({x})" + + @staticmethod + def cosh(x: str) -> str: + return f"jnp.cosh({x})" + + @staticmethod + def tanh(x: str) -> str: + return f"jnp.tanh({x})" + + @staticmethod + def asin(x: str) -> str: + return f"jnp.arcsin({x})" + + @staticmethod + def acos(x: str) -> str: + return f"jnp.arccos({x})" + + @staticmethod + def atan(x: str) -> str: + return f"jnp.arctan({x})" + + @staticmethod + def exp(x: str) -> str: + return f"jnp.exp({x})" + + @staticmethod + def exp2(x: str) -> str: + return f"jnp.exp2({x})" + + @staticmethod + def expm1(x: str) -> str: + return f"jnp.expm1({x})" + + @staticmethod + def log(x: str) -> str: + return f"jnp.log({x})" + + @staticmethod + def log10(x: str) -> str: + return f"jnp.log10({x})" + + @staticmethod + def log2(x: str) -> str: + return f"jnp.log2({x})" + + @staticmethod + def log1p(x: str) -> str: + return f"jnp.log1p({x})" + + @staticmethod + def sqrt(x: str) -> str: + return f"jnp.sqrt({x})" + + @staticmethod + def rsqrt(x: str) -> str: + return f"(1.0 / jnp.sqrt({x}))" + + @staticmethod + def abs(x: str) -> str: + return f"jnp.abs({x})" + + @staticmethod + def neg(x: str) -> str: + return f"(-{x})" + + @staticmethod + def floor(x: str) -> str: + return f"jnp.floor({x})" + + @staticmethod + def ceil(x: str) -> str: + return f"jnp.ceil({x})" + + @staticmethod + def trunc(x: str) -> str: + return f"jnp.trunc({x})" + + @staticmethod + def round(x: str) -> str: + return f"jnp.round({x})" + + @staticmethod + def sigmoid(x: str) -> str: + return f"(1.0 / (1.0 + jnp.exp(-{x})))" + + @staticmethod + def relu(x: str) -> str: + return f"jnp.maximum({x}, 0)" + + @staticmethod + def pow(a: str, b: str) -> str: + return f"jnp.power({a}, {b})" + + @staticmethod + def maximum(a: str, b: str) -> str: + return f"jnp.maximum({a}, {b})" + + @staticmethod + def minimum(a: str, b: str) -> str: + return f"jnp.minimum({a}, {b})" + + @staticmethod + def where(cond: str, a: str, b: str) -> str: + return f"jnp.where({cond}, {a}, {b})" + + +class PallasKernel(SIMDKernel): + """ + Minimal Pallas kernel for simple elementwise operations. + + Strategy: + - Treat loads as full-array refs: "in_ptrX[...]" + - Compute expression with Python operators (compatible with jax.numpy broadcasting) + - Store as full-array ref assignment: "out_ptrY[...] = " + - Generate Python code that defines a Pallas kernel and a host entrypoint. + - Use async_compile.pallas path to compile and load Python code. + """ + + overrides = PallasKernelOverrides # type: ignore[assignment] + + def _get_contiguous_index_str(self, index: sympy.Expr) -> str: + """ + Validate that the index represents contiguous access and return the indexing string. + + For Pallas, we only support simple contiguous access patterns where the index + is a single symbol (e.g., xindex) representing a flattened iteration. + This ensures the load/store order is contiguous. + + Args: + index: The indexing expression to validate + + Returns: + The indexing string to use (currently always "...") + + Raises: + Unsupported: If the index is not a simple contiguous pattern + """ + # Prepare and simplify the index + prepared_index = self.prepare_indexing(index) + + # For contiguous access, we expect a single symbol (like xindex) + # or a simple integer (for scalar operations) + if isinstance(prepared_index, sympy.Symbol): + # This is the expected case: a single symbol representing contiguous iteration + return "..." + elif prepared_index.is_Integer: + # Scalar case + return "..." + else: + # If there's any complex expression (ModularIndexing, FloorDiv, etc.), + # it's not a simple contiguous pattern + raise Unsupported( + f"Pallas backend only supports contiguous access patterns. " + f"Got complex index: {prepared_index}" + ) + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: # type: ignore[override] + buf = self.args.input(name) + dtype = V.graph.get_dtype(name) + # Validate contiguous access and get index string + index_str = self._get_contiguous_index_str(index) + # Pallas refs must be unpacked with [...] to load the array + return self.cse.generate( + self.compute, + f"{buf}[{index_str}]", + dtype=dtype, + ) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: Any = None + ) -> None: # type: ignore[override] + if mode is not None: + raise Unsupported("pallas store mode not supported") + out = self.args.output(name) + self.store_buffer_names.add(name) + # Validate contiguous access and get index string + index_str = self._get_contiguous_index_str(index) + # Pallas refs must use [...] assignment to store back to the ref + self.stores.writeline(f"{out}[{index_str}] = {value}") + + def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[override] + """ + Generate the complete Pallas kernel code as a Python string. + + This includes: + - Import statements for JAX/Pallas + - The kernel function that operates on refs + - The main wrapper function that handles PyTorch<->JAX conversions via DLPack + + Args: + name: Optional kernel name (will use placeholder if not provided) + + Returns: + str: Complete Python source code for the Pallas kernel + """ + # Ensure one (1) output for now + live_outs = list(self.args.live_output_buffers()) + if len(live_outs) != 1: + raise Unsupported( + "Pallas backend currently supports single-output elementwise kernels only" + ) + + code = IndentedBuffer() + code.splice( + """ + import torch + import jax + import jax.numpy as jnp + from jax.experimental import pallas as pl + from torch.utils import dlpack as torch_dlpack + """, + strip=True, + ) + + # Define the Pallas kernel: accepts refs, uses broadcasted expressions + arg_defs, _, _, _ = self.args.python_argdefs() + # Order: inputs (in_ptr*), then outputs (out_ptr*), then sizes/workspaces + kernel_params = [a.name for a in arg_defs] + + kernel_name = name or "" + code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):") + with code.indent(): + # Emit compute (CSE) and store lines; they reference *_ptr[...] directly + for line in self.compute._lines: + code.writeline(str(line)) + for line in self.stores._lines: + code.writeline(str(line)) + + # Host entry: convert torch tensors <-> jax, call pallas_call and copy back + main_name = f"{kernel_name}_main" + code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):") + with code.indent(): + # Identify inputs (in_ptr*) and output (out_ptr*) + input_params = [ + p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr")) + ] + output_params = [p for p in kernel_params if p.startswith("out_ptr")] + + if len(output_params) != 1: + raise RuntimeError( + f"Expected exactly 1 output, got {len(output_params)}" + ) + + output_param = output_params[0] + + # Convert inputs to JAX arrays + code.writeline("# Convert Torch -> JAX for inputs") + for inp in input_params: + code.writeline( + f"{inp}_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack({inp}))" + ) + + # Get output spec from PyTorch tensor + code.writeline("# Prepare output spec from PyTorch tensor") + code.writeline("# Map PyTorch dtype to JAX dtype string") + code.writeline("_torch_dtype_to_jax = {") + code.writeline( + " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16," + ) + code.writeline( + " torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8," + ) + code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,") + code.writeline("}") + code.writeline( + f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])" + ) + + # Call pallas + code.writeline("compiled = pl.pallas_call(") + code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),") + code.writeline(" out_shape=out_spec,") + code.writeline(" grid=(1,),") + code.writeline(")") + + jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params]) + code.writeline(f"res = compiled({jax_input_args})") + + # Copy result back + code.writeline("# Copy result back into the provided torch output tensor") + code.writeline( + "res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))" + ) + code.writeline(f"{output_param}.copy_(res_t)") + + return code.getvalue() + + def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: # type: ignore[override] + """Generate the Python code that calls this Pallas kernel.""" + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + + # Generate kernel call: kernel_name.run(arg1, arg2, ...) + # Note: async_compile.pallas loads {name}_main function and wraps it in PallasKernelWrapper + # which exposes a run() method + kernel_call = f"{name}.run({', '.join(map(str, call_args))})" + wrapper.writeline(kernel_call) + + +class PallasScheduling(SIMDScheduling): + kernel_type = PallasKernel # type: ignore[assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + # Start minimal: no special features advertised + return OrderedSet() + + def define_kernel( + self, + src_code: str, + node_schedule: Sequence[BaseSchedulerNode], + kernel: PallasKernel, + ) -> str: # type: ignore[override] + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + return wrapper.src_to_kernel[src_code] + + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"pallas_{kernel_hash}" + else: + kernel_name = f"pallas_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + + # Replace placeholder if any + src_code = src_code.replace("", kernel_name) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.pallas({kernel_name!r}, r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment) + + return kernel_name diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 4cc3f0ef282a..1c1f0f1c9cd2 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -24,6 +24,22 @@ log = logging.getLogger(__name__) +def inline_subgraph_to_ir_nodes( + gm: torch.fx.GraphModule, inputs: list[Any], name: str +) -> Any: + """Inline a subgraph by converting its FX operations to individual IR nodes. + + This converts a subgraph to multiple ComputedBuffer nodes (fusable), + enabling epilogue fusion with subsequent operations. + + Returns: + TensorBox containing the final operation result as individual IR nodes + """ + from torch._inductor.lowering import process_subgraph_nodes + + return process_subgraph_nodes(gm, inputs) + + class SubgraphChoiceCaller(ir.ChoiceCaller): """ Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary @@ -261,7 +277,14 @@ def make_fx_graph( # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs from torch.fx.experimental.proxy_tensor import make_fx - return make_fx(functools.partial(decomp, **decomp_kwargs))(*args) + from ..decomposition import select_decomp_table + + decomposition_table = select_decomp_table() + + return make_fx( + functools.partial(decomp, **decomp_kwargs), + decomposition_table=decomposition_table, + )(*args) # Generate descriptive name for this variant variant_name = self._generate_variant_name(decomp, decomp_kwargs) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index e86753348c6b..41b12d05cd32 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -98,7 +98,7 @@ def _default_custom_combo_kernel_horizontal_partition( ] short_reduction = [n for n in reduction if n not in long_reduction] if long_reduction: - log.warning( + log.debug( "ComboKernels: %d long reduction nodes are separated", len(long_reduction), ) @@ -112,7 +112,7 @@ def _default_custom_combo_kernel_horizontal_partition( ] if large_pointwise: # TODO benchmark the performance when large pointwise nodes combining with others - log.warning( + log.debug( "ComboKernels: %d large pointwise nodes are separated", len(large_pointwise), ) @@ -627,7 +627,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) @@ -699,7 +699,7 @@ def get_block_args(self) -> list[ConstexprArg]: block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix self.block_args = list(block_names.keys()) - return [ConstexprArg(x) for x in block_names.keys()] + return [ConstexprArg(x) for x in block_names] def add_numel_to_args( self, argdefs: list[ArgName], signature: list[Any] diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e629d9c7bdeb..947166cf216c 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,7 +2063,8 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - return f"{pos} if {x} >= 0 else {neg}" + x_cond = self.codegen_sizevar(x) + return f"{pos} if {x_cond} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index afa569ff97da..74a58acb84ff 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -359,6 +359,7 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: def estimate_nccl_collective_runtime_from_fx_node( fx_node: torch.fx.Node, override_size: Optional[int] = None, + # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix. use_nccl_estimator: bool = True, ) -> float: """ diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index 6c7c9a8bd7da..29efcb4a4449 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -18,13 +18,12 @@ from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet -from . import config, ir +from . import config, config_comms, ir from .dependencies import WeakDep if TYPE_CHECKING: from .ir import IRNode, Operation - from .scheduler import SchedulerBuffer from .memory import ( estimate_peak_memory, @@ -155,12 +154,15 @@ class ReorderInfo: Debug info describing how an individual snode was reordered """ - initial_exposed: float = -1 - final_exposed: float = -1 limiting_factor: str = "None" moves: int = 0 grouped: int = 0 grouped_info: str = "" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" @property def improvement(self): @@ -193,7 +195,7 @@ def contains_gemm_like(snode: BaseSchedulerNode) -> bool: return is_gemm_like(snode.node) -def _temp_group_visit_leaves(snode, fn): +def _temp_group_visit_leaves(snode: BaseSchedulerNode, fn): from torch._inductor.scheduler import GroupedSchedulerNode if isinstance(snode, GroupedSchedulerNode) and snode.temp_grouping: @@ -203,6 +205,126 @@ def _temp_group_visit_leaves(snode, fn): fn(snode) +def wait_exposed_communication_time( + snodes_to_wait: list[BaseSchedulerNode], runtimes: dict[BaseSchedulerNode, float] +) -> tuple[float, float, str]: + """ + Calculate exposed communication time for a wait operation by finding its corresponding + collective and accumulating overlapping compute time between them. + + The Wait node must be the last in snodes_to_wait. + Compute time between corresponding Collective and Wait is accumulated. + If there is another pair of Collective and Wait inside, + Only compute before first such Wait' is considered as overlapping. + + Multiple process groups are not modeled so far. + """ + wait_snode = snodes_to_wait[-1] + assert is_wait(wait_snode.node) + assert len(snodes_to_wait) > 1 + idx = len(snodes_to_wait) - 2 + comm_time = 0.0 + comp_time = 0.0 + overlap_info = "" + waits_found = [] + for i in range(idx, -1, -1): + c = snodes_to_wait[i] + if contains_wait(c): + waits_found.append(c) + if contains_collective(c): + if is_corresponding_collective_wait(c, wait_snode): + comm_time = runtimes[c] + overlap_info += f"->C[{c.get_name()}]" + break + + if not contains_async_collective(c): + # Sync Collective + comp_time = 0.0 + continue + else: + for w in waits_found: + if is_corresponding_collective_wait(c, w): + # Similar to Sync Collective + # If after our Collective exist another Collective-Wait, + # All compute after it will not be overlapping + comp_time = 0.0 + continue + + comp_time_before = comp_time + + def accumulate_time(_snode: BaseSchedulerNode) -> None: + nonlocal comp_time + comp_time += runtimes[_snode] + + _temp_group_visit_leaves(c, accumulate_time) + comp_time_after = comp_time + overlap_info += f"+{c.get_name()}[{comp_time_after - comp_time_before}]" + + return comm_time, comp_time, overlap_info + + +def coll_exposed_communication_time( + snodes: list[BaseSchedulerNode], + runtimes: dict[BaseSchedulerNode, float], +) -> tuple[float, float, str]: + """ + Calculate exposed communication time for a collective operation by finding its corresponding + wait and accumulating compute time that can overlap with communication. + + The Collective node must be the first in snodes. + Compute time between corresponding Collective and Wait is accumulated. + If there is another pair of Collective and Wait inside, + Only compute before first such Wait' is considered as overlapping. + + Multiple process groups are not modeled so far. + """ + collective_snode = snodes[0] + comm_time = runtimes[collective_snode] + comp_time = 0.0 + collective_outs: OrderedSet[str] = OrderedSet( + o.get_name() for o in collective_snode.get_outputs() + ) + overlap_info = "" + collectives_found: list[BaseSchedulerNode] = [] + for snode in snodes[1:]: + # We may have some ops without Wait, + # e.g. DTensor torch.ops._dtensor.shard_dim_alltoall + unmet_deps = OrderedSet( + d.name for d in snode.unmet_dependencies if not _is_fake_dep(d) + ) + + if unmet_deps & collective_outs: + overlap_info += f"->W[{snode.get_name()}]" + break + + if contains_collective(snode): + if not contains_async_collective(snode): + break + else: + collectives_found.append(snode) + continue + if contains_wait(snode): + has_wait_for_collectives_found = False + for coll in collectives_found: + if is_corresponding_collective_wait(collective_snode, snode): + has_wait_for_collectives_found = True + break + if has_wait_for_collectives_found: + # Any compute after not overlapping original Collective + break + + comp_time_before = comp_time + + def accumulate_time(_snode: BaseSchedulerNode) -> None: + nonlocal comp_time + comp_time += runtimes[_snode] + + _temp_group_visit_leaves(snode, accumulate_time) + comp_time_after = comp_time + overlap_info += f"+{snode.get_name()}[{comp_time_after - comp_time_before}]" + return comm_time, comp_time, overlap_info + + def _group_name(snode, with_bufs=False) -> str: ret = "" for n in snode.snodes: @@ -258,369 +380,361 @@ def _initialize_double_linked_list( return _prev, _next, _head -def _reorder_communication_preserving_peak_memory_internal( - snodes: list[BaseSchedulerNode], -) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: +def is_corresponding_collective_wait( + collective_snode: BaseSchedulerNode, wait_snode: BaseSchedulerNode +) -> bool: """ - Internal testing helper that also returns debug info. + Check if a wait node corresponds to a given collective node by verifying if the wait + depends on outputs from the collective. + """ + collective_outs = OrderedSet(o.get_name() for o in collective_snode.get_outputs()) + unmet_deps = OrderedSet(d.name for d in wait_snode.unmet_dependencies) + return bool(unmet_deps & collective_outs) + + +def _op_runtime_estimate_mult(snode): + # Apply multipliers for faster experimentation. + # TODO(ivankobzarev): Remove after confirmation that runtime estimations are correct. + if contains_collective(snode): + return config_comms.reorder_sink_runtime_estimations_comm_mult + + return config_comms.reorder_sink_runtime_estimations_non_comm_mult + + +def is_async_collective(snode): + """ + Filtering out ops that contain Collective and Wait inside and considered as Collectives. + See contains_collective function. + If the op contains Wait inside - consider as Synchronous compute. + """ + if python_kernel_name := getattr(snode.node, "python_kernel_name", None): + if "torch.ops._dtensor.shard_dim_alltoall.default" in python_kernel_name: + return False + + return True + + +def contains_async_collective(snode): + return contains_collective(snode, is_async_collective) + + +def _group_nodes_from_linked_list( + head: Optional[BaseSchedulerNode], + tail: Optional[BaseSchedulerNode], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], +) -> list[BaseSchedulerNode]: + """ + Traverse doubly-linked list from head to tail and return nodes as a list. + + Args: + head: Starting node of the segment + tail: Ending node of the segment (inclusive) + next_dict: Dictionary mapping each node to its next node + Returns: - - reordered snodes list - - dict {snode: ReorderInfo} + List of nodes from head to tail (inclusive) """ - has_collectives = False - for snode in snodes: - if contains_collective(snode): - has_collectives = True + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: break - if not has_collectives: - return snodes, {} - - from torch._inductor.scheduler import GroupedSchedulerNode + n = next_dict[n] # type: ignore[index] + return ret - original_snodes_num = len(snodes) - # heuristic to avoid degenerating to quadratic time - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - ( - peak_memory, - _curr_memory, - snodes_allocfree, - buf_to_snode_last_use, - name_to_freeable_input_buf, - ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) - runtimes: dict[BaseSchedulerNode, float] = { - snode: estimate_op_runtime(snode) for snode in snodes - } - # debug stats - stats: dict[BaseSchedulerNode, ReorderInfo] = {} - def exposed_communication_time( - collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode] - ) -> float: - # assumes a linear schedule and computes the overlap of the collective with the remaining nodes - comm_time = estimate_op_runtime(collective_snode) - compute_time = 0.0 - for snode in remaining_snodes: - if contains_collective(snode): - continue - if contains_wait(snode): - # TODO - if the wait is for a collective that started before this collective or on another stream, - # we can ignore it. Otherwise, it's the end of the road for overlap opportunities - break +def _perform_double_linked_list_swap( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list. - def accumulate_time(_snode: BaseSchedulerNode) -> None: - nonlocal compute_time - compute_time += runtimes[_snode] + Transforms: + candidate_prev -> candidate -> group_head...group_tail -> group_tail_next + Into: + candidate_prev -> group_head...group_tail -> candidate -> group_tail_next - _temp_group_visit_leaves(snode, accumulate_time) - return max(0, comm_time - compute_time) + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list - total_moves = 0 + Returns: + New head of the linked list (may change if candidate was the head) + """ + # 0: Update candidate's previous node + candidate_prev = prev_dict[candidate] + if candidate_prev: + next_dict[candidate_prev] = group_head + prev_dict[group_head] = candidate_prev + + # 2: Update group_tail's next node + group_tail_next = next_dict[group_tail] + if group_tail_next: + prev_dict[group_tail_next] = candidate + next_dict[candidate] = group_tail_next + + # 1: Link group_tail to candidate + prev_dict[candidate] = group_tail + next_dict[group_tail] = candidate + + # Update head if candidate was the head + if head == candidate: + return group_head + return head + + +def _calculate_potential_peak_memory_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + curr_memory: dict, +) -> tuple[int, dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (reorder version). - _prev, _next, _head = _initialize_double_linked_list(snodes) + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation memory values for each node. - def _group_nodes( - head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode] - ) -> list[BaseSchedulerNode]: - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] # type: ignore[index] - return ret - - def _perform_double_linked_list_swap(candidate, group_head, group_tail): - # swap (candidate, group_head...group_tail) - # Before: - # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next - # After: - # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next - # 0 - candidate_prev = _prev[candidate] - if candidate_prev: - _next[candidate_prev] = group_head - _prev[group_head] = candidate_prev - - # 2 - group_tail_next = _next[group_tail] - if group_tail_next: - _prev[group_tail_next] = candidate - _next[candidate] = group_tail_next - - # 1 - _prev[candidate] = group_tail - _next[group_tail] = candidate - - nonlocal _head - if _head == candidate: - _head = group_head - - def _calculate_potential_peak_memory( - candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate - ): - # Caching calculations of memory for group nodes and candidate, - # to apply without recalculation after swap. - _post_alloc_update: dict[BaseSchedulerNode, int] = {} - potential_peak: int = 0 - if not group_n_to_bufs_after_swap_dealloc_by_candidate: - # Not accounting for buffers last use change - potential_peak = max( - group_peak_memory - candidate_delta_mem, - _curr_memory[group_tail][1] - - candidate_delta_mem - + candidate_allocfree.size_alloc, - ) - return potential_peak, _post_alloc_update + Args: + candidate: Node being moved + gns: Group nodes + group_tail: Last node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + curr_memory: Current memory state dict - # If candidate will be after group, the starting memory level of group nodes - # changes to the -(candidate.size_alloc - candidate.size_free) - mem_after_reorder_delta: int = -candidate_delta_mem - for gn in gns: - gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta - _post_alloc_update[gn] = gn_post_alloc_mem - potential_peak = max(potential_peak, gn_post_alloc_mem) - - bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None) - if bufs is not None: - for buf in bufs: - # Candidate will deallocate those buffers - mem_after_reorder_delta += buf.mpi_buffer.size_free - - candidate_mem_post_alloc = ( - _curr_memory[group_tail][1] - + mem_after_reorder_delta - + candidate_allocfree.size_alloc + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict) + """ + # Caching calculations of memory for group nodes and candidate, + # to apply without recalculation after swap. + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + potential_peak: int = 0 + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + # Not accounting for buffers last use change + potential_peak = max( + group_peak_memory - candidate_delta_mem, + curr_memory[group_tail][1] + - candidate_delta_mem + + candidate_allocfree.size_alloc, ) - _post_alloc_update[candidate] = candidate_mem_post_alloc - potential_peak = max(potential_peak, candidate_mem_post_alloc) return potential_peak, _post_alloc_update - def _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_by_candidate, - _post_alloc_update, - ): - if not group_n_to_bufs_after_swap_dealloc_by_candidate: - for gn in gns: - cm = _curr_memory[gn] - _curr_memory[gn] = ( - cm[0] - candidate_delta_mem, - cm[1] - candidate_delta_mem, - ) - _candidate_post_alloc_mem = ( - _curr_memory[group_tail][1] + candidate_allocfree.size_alloc - ) - _candidate_post_free_mem = ( - _candidate_post_alloc_mem - candidate_allocfree.size_free - ) - _curr_memory[candidate] = ( - _candidate_post_alloc_mem, - _candidate_post_free_mem, - ) - return + # If candidate will be after group, the starting memory level of group nodes + # changes to the -(candidate.size_alloc - candidate.size_free) + mem_after_reorder_delta: int = -candidate_delta_mem + for gn in gns: + gn_post_alloc_mem = curr_memory[gn][0] + mem_after_reorder_delta + _post_alloc_update[gn] = gn_post_alloc_mem + potential_peak = max(potential_peak, gn_post_alloc_mem) - # Candidate becomes last use of some bufs - for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn) + if bufs is not None: for buf in bufs: - buf_to_snode_last_use[buf] = candidate - - size_free_to_move_to_candidate_sum: int = 0 - for n in gns: - _gn_post_alloc_mem: int = _post_alloc_update[n] - size_free_to_move_to_candidate: int = sum( - buf.mpi_buffer.size_free - for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] - ) - size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate - # group node does not deallocate this after swap - snodes_allocfree[n].size_free -= size_free_to_move_to_candidate - gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free - _curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) - _candidate_post_alloc_mem = _post_alloc_update[candidate] - snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum - candidate_post_free_mem = ( - _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free - ) - _curr_memory[candidate] = ( - _candidate_post_alloc_mem, - candidate_post_free_mem, - ) + # Candidate will deallocate those buffers + mem_after_reorder_delta += buf.mpi_buffer.size_free - debug_num_collectives_to_reorder: Optional[int] = ( - config.reorder_iterative_debug_limit_to_reorder + candidate_mem_post_alloc = ( + curr_memory[group_tail][1] + + mem_after_reorder_delta + + candidate_allocfree.size_alloc ) + _post_alloc_update[candidate] = candidate_mem_post_alloc + potential_peak = max(potential_peak, candidate_mem_post_alloc) + return potential_peak, _post_alloc_update + + +def _update_memory_tracking_after_swap_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + buf_to_snode_last_use: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (reorder version). - num_processed_collectives: int = 0 - curr = _head - debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute - iterative_recompute_error = False - - while _next[curr] is not None: - if iterative_recompute_error: - break - # pyrefly: ignore [bad-argument-type] - if contains_collective(curr): - if debug_num_collectives_to_reorder is not None and ( - num_processed_collectives >= debug_num_collectives_to_reorder - ): - break - num_processed_collectives += 1 + Updates curr_memory, buf_to_snode_last_use, and snodes_allocfree dictionaries + to reflect the new memory state after swapping candidate with group. - info = stats[curr] = ReorderInfo() - info.initial_exposed = info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) + Args: + candidate: Node that was moved + gns: Group nodes + group_tail: Last node of group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + post_alloc_update: Cached post-allocation memory values + curr_memory: Current memory state dict (mutated) + buf_to_snode_last_use: Buffer to last-use node mapping (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + for gn in gns: + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] - candidate_delta_mem, + cm[1] - candidate_delta_mem, ) + _candidate_post_alloc_mem = ( + curr_memory[group_tail][1] + candidate_allocfree.size_alloc + ) + _candidate_post_free_mem = ( + _candidate_post_alloc_mem - candidate_allocfree.size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + _candidate_post_free_mem, + ) + return - candidate = _prev[curr] - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr][0] # post_alloc memory - while candidate is not None: - if contains_collective(candidate): - info.limiting_factor = "collective ordering" - break + # Candidate becomes last use of some bufs + for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + for buf in bufs: + buf_to_snode_last_use[buf] = candidate + + size_free_to_move_to_candidate_sum: int = 0 + for n in gns: + _gn_post_alloc_mem: int = post_alloc_update[n] + size_free_to_move_to_candidate: int = sum( + buf.mpi_buffer.size_free + for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] + ) + size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate + # group node does not deallocate this after swap + snodes_allocfree[n].size_free -= size_free_to_move_to_candidate + gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free + curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) + _candidate_post_alloc_mem = post_alloc_update[candidate] + snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum + candidate_post_free_mem = ( + _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + candidate_post_free_mem, + ) - gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail) - group = GroupedSchedulerNode( - curr.scheduler, - gns, - temp_grouping=True, - ) - # We can have multiple deps with the same name. - # As we ignore WeakDep(is_fake=True) => - # filter them out first to avoid overwriting of real dep. - data_deps = { - d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) - } - - candidate_outs = candidate.get_outputs() - data_dep = None - for o in candidate_outs: - if d := data_deps.get(o.get_name(), None): - data_dep = d - break +def _find_buffers_with_changed_last_use( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping candidate with group. - if data_dep is not None: + When we swap [candidate [group]] to [[group] candidate], some buffers that + were last used by a group node will now be last used by candidate instead. + This affects memory deallocation timing. - def is_groupable( - candidate: BaseSchedulerNode, - ) -> tuple[bool, Optional[str]]: - # preserve ordering - if contains_collective(candidate): - return False, "contains_collective" + Args: + candidate: The node being moved + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes - if contains_gemm_like(candidate): - return False, "contains_gemm_like" - return True, None + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if candidate not in succ_nodes: + continue - is_groupable_result, grouping_reason = is_groupable(candidate) - if is_groupable_result: - group_head = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate][0] - ) - info.grouped += 1 - info.grouped_info = _group_names(gns) - candidate = _prev[candidate] - continue - else: - msg = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" - f"dep on {_group_names(gns)}" - f"\n non_group_reason:{grouping_reason}" - ) - info.limiting_factor = msg - break + if not any(gn == snode_last_use for gn in gns): + continue - candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] - candidate_delta_mem: int = ( - candidate_allocfree.size_alloc - candidate_allocfree.size_free - ) - # candidate and one of group nodes are successors of the same buffer - # and last use of the buffer happen in group nodes. - # This last use deallocates it. - # If we swap [candidate [group]] to [[group] candidate], - # candidate becomes the last use - # and deallocated this buffer instead of group node. - # we need to update size_free accordingly to group_node and candidate, - # and recalculate post_alloc, post_free for them. - # - # Buf that changes its last use snode, - # after swap will be deallocated only by candidate, - # while before it was deallocated by group node. - group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ - BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] - ] = defaultdict(list) - for ( - buf, - snode_last_use, - ) in buf_to_snode_last_use.items(): - succ_nodes = buf.mpi_buffer.succ_nodes - if candidate not in succ_nodes: - continue + group_n_to_bufs_after_swap_dealloc_by_candidate[snode_last_use].append(buf) - if not any(gn == snode_last_use for gn in gns): - continue + return group_n_to_bufs_after_swap_dealloc_by_candidate - group_n_to_bufs_after_swap_dealloc_by_candidate[ - snode_last_use - ].append(buf) - potential_peak, _post_alloc_update = _calculate_potential_peak_memory( - candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate - ) +def _is_node_groupable_for_reorder( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped with collective during reordering. - if potential_peak > peak_memory: - info.limiting_factor = ( - f"peak memory new:{potential_peak} vs base:{peak_memory}" - ) - break - info.moves += 1 - total_moves += 1 + This pass processes collectives left to right, so we avoid grouping with + already-processed collectives based on configuration. - _perform_double_linked_list_swap(candidate, group_head, group_tail) + Args: + candidate: Node to check for groupability - info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) - ) + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # This pass processes collectives left to right, + # Do not group with processed collectives. + # Leaving config for experimentation in 2D + if not config_comms.reorder_iterative_group_with_collectives: + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_collective {candidate.get_name()}", + ) + if not config_comms.reorder_iterative_use_runtime_estimations: + if contains_gemm_like(candidate): + return False, "contains_gemm_like" + return True, None + + +def _format_and_log_reordering_stats( + stats: dict[BaseSchedulerNode, ReorderInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format reordering statistics, log them, and return final node list. - _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_by_candidate, - _post_alloc_update, - ) + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. - if debug_iterative_memory_recompute: - # Compare iteratively recomputed memory data - # with full run of estimate_peak_memory - - from .comms_debug import _debug_iterative_memory_recompute - - iterative_recompute_error = _debug_iterative_memory_recompute( - candidate, - gns, - _group_names(gns), - _group_nodes(_head, None), - name_to_freeable_input_buf, - graph_outputs, - peak_memory, - _curr_memory, - snodes_allocfree, - "reorder_communication_preserving_peak_memory", - group_n_to_bufs_after_swap_dealloc_by_candidate, - ) - if iterative_recompute_error: - break - candidate = _prev[group_head] - curr = _next[curr] # type: ignore[assignment] + Args: + stats: Per-node reordering statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + Returns: + Final reordered list of scheduler nodes + """ node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} total_improvement = sum([improvement[snode] for snode in improvement]) @@ -632,28 +746,35 @@ def is_groupable( ) headers = [ "Collective node", - "initial exposed", - "final exposed", - "improvement", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", "limiting factor", "moves", "grouped", "grouped_info", + "overlap_info", ] rows = [ [ node_summary(snode), - node_info.initial_exposed, - node_info.final_exposed, - node_info.improvement, + node_info.comm_time / 1e3, + node_info.comp_time / 1e3, + node_info.initial_exposed / 1e3, + node_info.final_exposed / 1e3, + node_info.improvement / 1e3, node_info.limiting_factor, node_info.moves, node_info.grouped, node_info.grouped_info, + node_info.overlap_info, ] for snode, node_info in node_stats.items() ] if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] from tabulate import tabulate reorder_log_str += tabulate( @@ -667,7 +788,7 @@ def is_groupable( reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - new_snodes = _group_nodes(_head, None) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -685,6 +806,334 @@ def is_groupable( payload_fn=lambda: reorder_log_str, ) + return new_snodes + + +def _reorder_communication_preserving_peak_memory_internal( + snodes: list[BaseSchedulerNode], +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: + """ + Internal testing helper that also returns debug info. + Returns: + - reordered snodes list + - dict {snode: ReorderInfo} + """ + has_collectives = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + # heuristic to avoid degenerating to quadratic time + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + # debug stats + stats: dict[BaseSchedulerNode, ReorderInfo] = {} + + total_moves = 0 + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + debug_num_collectives_to_reorder: Optional[int] = ( + config_comms.reorder_iterative_debug_limit_to_reorder + ) + + num_processed_collectives: int = 0 + curr: Optional[BaseSchedulerNode] = _head + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + iterative_recompute_error = False + + while curr is not None and _next[curr] is not None: + _next_curr = _next[curr] + if iterative_recompute_error: + break + # pyrefly: ignore [bad-argument-type] + if not contains_async_collective(curr): + curr = _next_curr + continue + + if debug_num_collectives_to_reorder is not None and ( + num_processed_collectives >= debug_num_collectives_to_reorder + ): + break + num_processed_collectives += 1 + + info = stats[curr] = ReorderInfo() + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_waits = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] # post_alloc memory + + while candidate is not None: + if config_comms.reorder_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.reorder_iterative_extra_comm_comp_overlap + * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + if ( + not config_comms.reorder_iterative_unsafe_collectives_reorder + and contains_collective(candidate) + ): + info.limiting_factor = "collective ordering" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + curr.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. + data_deps = { + d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) + } + + candidate_outs = candidate.get_outputs() + data_dep = None + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + + if data_dep is not None: + is_groupable_result, grouping_reason = _is_node_groupable_for_reorder( + candidate + ) + if is_groupable_result: + group_head = candidate + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + if contains_wait(candidate): + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), + runtimes, + ) + group_waits[candidate] = comm_time, comp_time + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _prev[candidate] + continue + else: + msg = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(gns)}" + f"\n non_group_reason:{grouping_reason}" + ) + info.limiting_factor = msg + break + + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_waits) > 0: + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, info.comm_time - info.comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max( + 0, info.comm_time - info.comp_time - c_runtime + ) + exposed_delta = exposed_after - exposed_before + for gw_comm_time, gw_comp_time in group_waits.values(): + gw_exposed_before = max(0, gw_comm_time - gw_comp_time) + gw_exposed_after = max( + 0, gw_comm_time - gw_comp_time + c_runtime + ) + + exposed_delta += gw_exposed_after - gw_exposed_before + + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}," + f" group contains waits, total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gw, ( + gw_comm_time, + gw_comp_time, + ) in group_waits.items(): + group_waits[gw] = ( + gw_comm_time, + gw_comp_time - c_runtime, + ) + else: + # Candidate is async_collective + + # Unsafe collectives reordering + # Cj -> [...group_runtime..., Ci] -> Wj + # Checking that we are not increasing exposed time of Cj + if group_runtime > 0: + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + exposed_delta = exposed_after - exposed_before + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate {candidate.get_name()} is collective," + f" group_runtime:{group_runtime}," + f" exposed_delta:{exposed_delta} c_comm_time:{comm_time} c_comp_time:{comp_time}" + ) + break + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem: int = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # candidate and one of group nodes are successors of the same buffer + # and last use of the buffer happen in group nodes. + # This last use deallocates it. + # If we swap [candidate [group]] to [[group] candidate], + # candidate becomes the last use + # and deallocated this buffer instead of group node. + # we need to update size_free accordingly to group_node and candidate, + # and recalculate post_alloc, post_free for them. + # + # Buf that changes its last use snode, + # after swap will be deallocated only by candidate, + # while before it was deallocated by group node. + group_n_to_bufs_after_swap_dealloc_by_candidate = ( + _find_buffers_with_changed_last_use( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update = ( + _calculate_potential_peak_memory_reorder( + candidate, + gns, + group_tail, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _curr_memory, + ) + ) + + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.reorder_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + info.moves += 1 + total_moves += 1 + + _head = _perform_double_linked_list_swap( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + info.final_exposed = comm_time - comp_time + + _update_memory_tracking_after_swap_reorder( + candidate, + gns, + group_tail, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _post_alloc_update, + _curr_memory, + buf_to_snode_last_use, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + # Compare iteratively recomputed memory data + # with full run of estimate_peak_memory + + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "reorder_communication_preserving_peak_memory", + group_n_to_bufs_after_swap_dealloc_by_candidate, + ) + if iterative_recompute_error: + break + candidate = _prev[group_head] + curr = _next_curr + + new_snodes = _format_and_log_reordering_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + return new_snodes, stats @@ -875,344 +1324,295 @@ class SinkWaitInfo: moves: int = 0 moves_info: str = "" limiting_factor: str = "None" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" + @property + def improvement(self): + return self.initial_exposed - self.final_exposed -def _sink_waits_iterative_internal( - snodes: list[BaseSchedulerNode], -) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode - - original_snodes_num = len(snodes) - if original_snodes_num == 0: - return snodes, {} - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - ( - peak_memory, - _curr_memory, - snodes_allocfree, - buf_to_snode_last_use, - name_to_freeable_input_buf, - ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) - _prev, _next, _head = _initialize_double_linked_list(snodes) +def _is_node_groupable_for_sink_waits( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped during sink_waits pass. - stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} + Sink Waits traverses waits right to left, so we don't group with + processed waits on the right or with async collectives. - def _group_nodes( - head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode] - ) -> list[BaseSchedulerNode]: - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] # type: ignore[index] - return ret + Args: + candidate: Node to check for groupability - def _calculate_potential_peak_memory( - candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate - ): - pre_group_mem = ( - _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # Sink Waits traverse Waits right to left, + # => we do not group with processed Waits on the right. + if contains_wait(candidate): + return False, f"candidate contains wait {candidate.get_name()}" + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_async_collective {candidate.get_name()}", ) - # Stash memory tracing updates to not recompute them after swap - _post_alloc_update: dict[BaseSchedulerNode, int] = {} - _size_free_delta_update: dict[BaseSchedulerNode, int] = {} - - potential_peak = 0 - if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - # Not accounting for buffers liveliness change - potential_peak = max( - group_peak_memory + candidate_delta_mem, - pre_group_mem + candidate_allocfree.size_alloc, + + # pyrefly: ignore[unbound-name] + if not config_comms.sink_iterative_use_runtime_estimations: + # Heuristics pre-use_runtime_estimations: + # TODO(ivankobzarev): Remove them after confirming, + # that using runtime estimations always give better results. + # We do not want to group with collectives to not reorder them forward. + if contains_collective(candidate): + return ( + False, + f"candidate contains collective {candidate.get_name()}", ) - return potential_peak, _post_alloc_update, _size_free_delta_update + if contains_gemm_like(candidate): + return ( + False, + f"candidate contains gemm_like {candidate.get_name()}", + ) + return True, None + + +def _update_memory_tracking_after_swap_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + size_free_delta_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (sink_waits version). + Updates curr_memory and snodes_allocfree dictionaries to reflect the new + memory state after swapping candidate with group. + + Args: + candidate: Node that was moved + gns: Group nodes + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + post_alloc_update: Cached post-allocation memory values + size_free_delta_update: Cached size-free delta values + curr_memory: Current memory state dict (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + group_head = gns[0] + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc - _post_alloc_update[candidate] = candidate_post_alloc - potential_peak = candidate_post_alloc - candidate_size_free_to_move = sum( - buf.mpi_buffer.size_free # type: ignore[attr-defined] - for buf in itertools.chain.from_iterable( - group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() - ) + curr_memory[candidate] = ( + candidate_post_alloc, + candidate_post_alloc - candidate_allocfree.size_free, ) - _size_free_delta_update[candidate] = -candidate_size_free_to_move - delta_mem = candidate_delta_mem + candidate_size_free_to_move for gn in gns: - gn_post_alloc = _curr_memory[gn][0] + delta_mem - _post_alloc_update[gn] = gn_post_alloc - potential_peak = max(potential_peak, gn_post_alloc) - gn_size_free_to_add = 0 - if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] - for buf in bufs: - gn_size_free_to_add += buf.mpi_buffer.size_free - _size_free_delta_update[gn] = gn_size_free_to_add - delta_mem -= gn_size_free_to_add - return potential_peak, _post_alloc_update, _size_free_delta_update - - def _perform_double_linked_list_swap(candidate, group_head, group_tail): - # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next - # 0: - group_head_prev = _prev[group_head] - if group_head_prev: - _next[group_head_prev] = candidate - _prev[candidate] = group_head_prev - - # 2: - candidate_next = _next[candidate] - if candidate_next: - _prev[candidate_next] = group_tail - _next[group_tail] = candidate_next - - # 1: - _prev[group_head] = candidate - _next[candidate] = group_head - nonlocal _head - if group_head == _head: - _head = candidate - - def _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - _post_alloc_update, - _size_free_delta_update, - ): - group_head = gns[0] - pre_group_mem = ( - _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc - ) - if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc - _curr_memory[candidate] = ( - candidate_post_alloc, - candidate_post_alloc - candidate_allocfree.size_free, - ) - for gn in gns: - cm = _curr_memory[gn] - _curr_memory[gn] = ( - cm[0] + candidate_delta_mem, - cm[1] + candidate_delta_mem, - ) - return - - for n in [candidate, *gns]: - post_alloc = _post_alloc_update[n] - snodes_allocfree[n].size_free += _size_free_delta_update[n] - _curr_memory[n] = ( - post_alloc, - post_alloc - snodes_allocfree[n].size_free, + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] + candidate_delta_mem, + cm[1] + candidate_delta_mem, ) + return - curr = snodes[-1] - - processed_waits = OrderedSet() # type: ignore[var-annotated] - debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute - debug_num_sink_waits_to_reorder: Optional[int] = ( - config.sink_waits_iterative_debug_limit_to_sink - ) - - iterative_recompute_error = False - - while _prev[curr] is not None: - if iterative_recompute_error: - break - if ( - debug_num_sink_waits_to_reorder is not None - and len(processed_waits) >= debug_num_sink_waits_to_reorder - ): - break + for n in [candidate, *gns]: + post_alloc = post_alloc_update[n] + snodes_allocfree[n].size_free += size_free_delta_update.get(n, 0) + curr_memory[n] = ( + post_alloc, + post_alloc - snodes_allocfree[n].size_free, + ) - # pyrefly: ignore [bad-argument-type] - if contains_wait(curr) and curr not in processed_waits: - processed_waits.add(curr) - info = stats[curr] = SinkWaitInfo() - candidate = _next[curr] - wait_snode = curr - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr][0] - while candidate is not None: - if iterative_recompute_error: - break - gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail) - group = GroupedSchedulerNode( - wait_snode.scheduler, - gns, - temp_grouping=True, - ) - # We can have multiple deps with the same name. - # As we ignore WeakDep(is_fake=True) => - # filter them out first to avoid overwriting of real dep. - data_deps = { - d.name: d - for d in candidate.unmet_dependencies - if not _is_fake_dep(d) - } - - group_outs = group.get_outputs() - data_dep = None - for o in group_outs: - if d := data_deps.get(o.get_name(), None): - data_dep = d - break - # 1. If we have data_dep - we can not swap => trying to group - # 2. If swap candidate and current node both contain collectives => trying to group - if data_dep is not None or ( - both_contain_comms := ( - contains_collective(group) and contains_collective(candidate) - ) - ): +def _calculate_potential_peak_memory_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_head: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + curr_memory: dict, + snodes_allocfree: dict, +) -> tuple[int, dict[BaseSchedulerNode, int], dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (sink_waits version). - def is_groupable(snode): - # We do not want to group with collectives to not reorder them forward. - if contains_collective(snode): - return ( - False, - f"candidate contains collective {snode.get_name()}", - ) - if contains_gemm_like(snode): - return ( - False, - f"candidate contains gemm_like {snode.get_name()}", - ) - return True, None + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation and size-free delta values. - is_grp, grp_reason = is_groupable(candidate) - if is_grp: - group_tail = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate][0] - ) - info.grouped += 1 - info.grouped_info = _group_names(gns) - candidate = _next[candidate] - continue + Args: + candidate: Node being moved + gns: Group nodes + group_head: First node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + curr_memory: Current memory state dict + snodes_allocfree: Allocation/free info for all nodes - elif (data_dep is None) and both_contain_comms: - info.limiting_factor = ( - f"collective ordering {_group_names(gns)}" - f" with candidate:{candidate.get_name()}" - ) - break - else: - info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" - f"dep on {gns}" - f"\n outs:{[o.get_name() for o in group_outs]}" - f"\n non_group_reason:{grp_reason}" - ) - break - candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] - candidate_delta_mem = ( - candidate_allocfree.size_alloc - candidate_allocfree.size_free - ) - # [group] candidate -> candidate [group] - # Check for buffers with successors in group and candidate last successor - # - # Buf that changes its last use snode, - # It was deallocated by candidate, - # but after swap it will be deallocated by group node. - group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ - BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]] - ] = defaultdict(list) - for ( - buf, - snode_last_use, - ) in buf_to_snode_last_use.items(): - succ_nodes = buf.mpi_buffer.succ_nodes - if snode_last_use != candidate: # noqa: E711 - continue - # candidate is last use of buf - last_succ_gn = None - for gn in gns: - if gn in succ_nodes: - last_succ_gn = gn - if last_succ_gn is None: - continue + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict, size_free_delta_update_dict) + """ + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + # Stash memory tracing updates to not recompute them after swap + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + _size_free_delta_update: dict[BaseSchedulerNode, int] = {} + + potential_peak = 0 + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + # Not accounting for buffers liveliness change + potential_peak = max( + group_peak_memory + candidate_delta_mem, + pre_group_mem + candidate_allocfree.size_alloc, + ) + return potential_peak, _post_alloc_update, _size_free_delta_update - # gn has successors of buf that after potential swap will become - # last use of buf and start deallocating buf instead of candidate - group_n_to_bufs_after_swap_dealloc_instead_of_candidate[ - last_succ_gn - ].append(buf) - - potential_peak, _post_alloc_update, _size_free_delta_update = ( - _calculate_potential_peak_memory( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - ) - ) - if potential_peak > peak_memory: - info.limiting_factor = ( - f"peak memory new:{potential_peak} vs base:{peak_memory}" - ) - break + candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc + _post_alloc_update[candidate] = candidate_post_alloc + potential_peak = candidate_post_alloc + candidate_size_free_to_move = sum( + buf.mpi_buffer.size_free # type: ignore[attr-defined] + for buf in itertools.chain.from_iterable( + group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() + ) + ) + _size_free_delta_update[candidate] = -candidate_size_free_to_move + delta_mem = candidate_delta_mem + candidate_size_free_to_move + for gn in gns: + gn_post_alloc = curr_memory[gn][0] + delta_mem + _post_alloc_update[gn] = gn_post_alloc + potential_peak = max(potential_peak, gn_post_alloc) + gn_size_free_to_add = 0 + if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] + for buf in bufs: + gn_size_free_to_add += buf.mpi_buffer.size_free + _size_free_delta_update[gn] = gn_size_free_to_add + delta_mem -= gn_size_free_to_add + return potential_peak, _post_alloc_update, _size_free_delta_update + + +def _perform_double_linked_list_swap_sink_waits( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list (sink_waits version). - info.moves += 1 - info.moves_info += f"+{candidate.get_name()}" + Transforms (moves candidate to the left): + group_head_prev -> group_head...group_tail -> candidate -> candidate_next + Into: + group_head_prev -> candidate -> group_head...group_tail -> candidate_next - _perform_double_linked_list_swap(candidate, group_head, group_tail) + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list - _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - _post_alloc_update, - _size_free_delta_update, - ) + Returns: + New head of the linked list (may change if group_head was the head) + """ + # 0: Update group_head's previous node + group_head_prev = prev_dict[group_head] + if group_head_prev: + next_dict[group_head_prev] = candidate + prev_dict[candidate] = group_head_prev + + # 2: Update candidate's next node + candidate_next = next_dict[candidate] + if candidate_next: + prev_dict[candidate_next] = group_tail + next_dict[group_tail] = candidate_next + + # 1: Link candidate to group_head + prev_dict[group_head] = candidate + next_dict[candidate] = group_head + + # Update head if group_head was the head + if group_head == head: + return candidate + return head + + +def _format_and_log_sink_waits_stats( + stats: dict[BaseSchedulerNode, SinkWaitInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format sink_waits statistics, log them, and return final node list. - if debug_iterative_memory_recompute: - from .comms_debug import _debug_iterative_memory_recompute - - iterative_recompute_error = _debug_iterative_memory_recompute( - candidate, - gns, - _group_names(gns), - _group_nodes(_head, None), - name_to_freeable_input_buf, - graph_outputs, - peak_memory, - _curr_memory, - snodes_allocfree, - "sink_waits_iterative", - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - ) - if iterative_recompute_error: - break + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. - candidate = _next[group_tail] - curr = _prev[curr] # type: ignore[assignment] + Args: + stats: Per-node sink_waits statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + Returns: + Final reordered list of scheduler nodes + """ headers = [ "Wait node", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", + "limiting factor", "grouped", "grouped_info", "moves", "moves_info", - "limiting factor", + "overlap_info", ] rows = [ [ node_summary(snode), + info.comm_time / 1e3, + info.comp_time / 1e3, + info.initial_exposed / 1e3, + info.final_exposed / 1e3, + info.improvement / 1e3, + info.limiting_factor, info.grouped, info.grouped_info, info.moves, info.moves_info, - info.limiting_factor, + info.overlap_info, ] for snode, info in stats.items() ] log_str = "" if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] from tabulate import tabulate log_str += tabulate( @@ -1224,7 +1624,7 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - new_snodes = _group_nodes(_head, None) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -1239,18 +1639,409 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - return new_snodes, stats + return new_snodes + + +def _find_buffers_with_changed_last_use_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping in sink_waits pass. + + When we swap [group] candidate to candidate [group], some buffers that + were last used by candidate will now be last used by a group node instead. + This is the opposite direction from the reorder version. + Args: + candidate: The node being moved (currently last use) + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes + + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if snode_last_use != candidate: # noqa: E711 + continue + # candidate is last use of buf + last_succ_gn = None + for gn in gns: + if gn in succ_nodes: + last_succ_gn = gn + if last_succ_gn is None: + continue -def sink_waits_iterative( + # gn has successors of buf that after potential swap will become + # last use of buf and start deallocating buf instead of candidate + group_n_to_bufs_after_swap_dealloc_instead_of_candidate[last_succ_gn].append( + buf + ) + + return group_n_to_bufs_after_swap_dealloc_instead_of_candidate + + +def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], -) -> list[BaseSchedulerNode]: +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + + curr: Optional[BaseSchedulerNode] = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + debug_num_sink_waits_to_reorder: Optional[int] = ( + config_comms.sink_waits_iterative_debug_limit_to_sink + ) + + iterative_recompute_error = False + while curr is not None and _prev[curr] is not None: + _prev_curr = _prev[curr] + if iterative_recompute_error: + break + if ( + debug_num_sink_waits_to_reorder is not None + and len(processed_waits) >= debug_num_sink_waits_to_reorder + ): + break + + # pyrefly: ignore [bad-argument-type] + if not (contains_wait(curr) and curr not in processed_waits): + curr = _prev_curr + continue + + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_colls = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] + + while candidate is not None: + if config_comms.sink_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.sink_iterative_extra_comm_comp_overlap * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + wait_snode.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. + data_deps = { + d.name: d for d in candidate.unmet_dependencies if not _is_fake_dep(d) + } + + group_outs = group.get_outputs() + data_dep = None + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + # Conservative sink wait, limiting by space before next collective. + # The global strategy is that bucketing should create space. + # For 2D we can experiment with allowing to sink Wait beyond non current group collective. + # pyrefly: ignore[unbound-name] + if not config_comms.sink_waits_iterative_swap_with_collectives: + if contains_async_collective(candidate): + info.limiting_factor = ( + f"candidate contains_async_collective {candidate.get_name()}" + ) + break + + # 1. If we have data_dep - we can not swap => trying to group + # 2. If swap candidate and current node both contain collectives => trying to group + if data_dep is not None or ( + both_contain_comms := ( + contains_collective(group) and contains_collective(candidate) + ) + ): + _is_groupable, groupable_reason = _is_node_groupable_for_sink_waits( + candidate + ) + if _is_groupable: + group_tail = candidate + if ( + # pyrefly: ignore[unbound-name] + config_comms.sink_iterative_use_runtime_estimations + and contains_collective(candidate) + ): + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + group_colls[candidate] = (comm_time, comp_time) + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _next[candidate] + continue + elif data_dep is None: + if ( + # pyrefly: ignore[unbound-name] + not config_comms.sink_waits_iterative_unsafe_collectives_reorder + and both_contain_comms + ): + info.limiting_factor = ( + f"collective ordering {_group_names(gns)}" + f"\n with candidate:{candidate.get_name()}" + ) + break + else: + info.limiting_factor = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"\n dep on {_group_names(gns)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{groupable_reason}" + ) + break + + # pyrefly: ignore[unbound-name] + if config_comms.sink_iterative_use_runtime_estimations: + if is_wait(candidate.node): + # Corresponding collective is before the group, + # Swap can increase exposed time of corresponding collective + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), runtimes + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + # We do not know how much we can sink more after this swap, + # Just comparing advantage at the moment for now. + if exposed_after > exposed_before: + info.limiting_factor = ( + "candidate is wait," + f" exposed_before:{exposed_before} vs exposed_after:{exposed_after}" + ) + break + + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + # If candidate has sync runtime, + # Waits of gorup_colls are on the right from group. + # Swap can increase their exposed time. + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_colls) > 0: + # Advantage for current Wait to do the Swap + # pyrefly: ignore[no-matching-overload] + exposed_delta = max( + 0, + info.comm_time - info.comp_time, + ) + # pyrefly: ignore[no-matching-overload] + -max(0, info.comm_time - info.comp_time - c_runtime) + for gc, (gc_comm_time, gc_comp_time) in group_colls.items(): + exposed_delta += max(0, gc_comm_time - gc_comp_time) - max( + 0, gc_comm_time - gc_comp_time + c_runtime + ) + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}, group contains collectives," + f" total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gc, ( + gc_comm_time, + gc_comp_time, + ) in group_colls.items(): + group_colls[gc] = ( + gc_comm_time, + gc_comp_time - c_runtime, + ) + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # [group] candidate -> candidate [group] + # Check for buffers with successors in group and candidate last successor + # + # Buf that changes its last use snode, + # It was deallocated by candidate, + # but after swap it will be deallocated by group node. + group_n_to_bufs_after_swap_dealloc_instead_of_candidate = ( + _find_buffers_with_changed_last_use_sink_waits( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update, _size_free_delta_update = ( + _calculate_potential_peak_memory_sink_waits( + candidate, + gns, + group_head, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _curr_memory, + snodes_allocfree, + ) + ) + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.sink_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + + info.moves += 1 + info.moves_info += f"+{candidate.get_name()}" + + _head = _perform_double_linked_list_swap_sink_waits( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + _update_memory_tracking_after_swap_sink_waits( + candidate, + gns, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _post_alloc_update, + _size_free_delta_update, + _curr_memory, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "sink_waits_iterative", + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + ) + if iterative_recompute_error: + break + + candidate = _next[group_tail] + curr = _prev_curr + + new_snodes = _format_and_log_sink_waits_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + + return new_snodes, stats + + +def sink_waits_iterative(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Similarly to reorder_communication_preserving_peak_memory this pass will try to iteratively + push Wait nodes later, recomputing estimated peak memory before each swap, + and preventing peak memory regressions. + + Pass will be applied to every Wait node. If there are immediate dependencies with next node, + pass will try to group them together and on the next step to swap the group with next candidate. + + If _inductor.config_comms.sink_iterative_use_runtime_estimations is set True, + pass will stop reordering of Wait once corresponding Collective is unexposed, + based on runtime estimations. + + inductor.config_comms.sink_iterative_peak_memory_budget allows to tune how much pass + can regress initial peak memory. + E.g.: + sink_iterative_peak_memory_budget == 0.0 - No regression of initial peak memory is allowed + sink_iterative_peak_memory_budget == 0.2 - Pass can improve comm-compute overlap, sacrificing + 20% of initial peak memory value. + + inductor.config_comms.sink_iterative_extra_comm_comp_overlap config allows to more aggressively + sink waits, stopping only when overlap_compute >= (1 + extra_comm_comp_overlap) * comm_time + """ return _sink_waits_iterative_internal(snodes)[0] def estimate_op_runtime(snode: BaseSchedulerNode) -> float: """ - Returns estimated op runtime in nanoseconds (ns) + Returns estimated op runtime in milliseconds (ms) """ if config.estimate_op_runtime == "default": runtime = snode.get_estimated_runtime() @@ -1267,7 +2058,7 @@ def node_summary(snode): if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" - detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}({ins_str})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 037b0e438ada..a4114644026c 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -24,6 +24,7 @@ import torch._thread_safe_fork # noqa: F401 from torch._inductor import config from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.timer import Timer from torch._inductor.compile_worker.tracked_process_pool import ( TrackedProcessPoolExecutor, ) @@ -132,6 +133,7 @@ def __init__( nprocs: int, pickler: Optional[SubprocPickler] = None, kind: SubprocKind = SubprocKind.FORK, + quiesce: bool = False, ) -> None: entry = os.path.join(os.path.dirname(__file__), "__main__.py") self.pickler = pickler or SubprocPickler() @@ -216,6 +218,13 @@ def __init__( "pytorch.wait_counter.subproc_pool.first_job" ).guard() + if quiesce: + self.timer: Optional[Timer] = Timer( + config.quiesce_async_compile_time, self.quiesce + ) + else: + self.timer = None + # Start thread last to ensure all member variables are initialized # before any access. self.read_thread.start() @@ -288,6 +297,8 @@ def _read_thread(self) -> None: with self.futures_lock: if not self.running: return + if self.timer: + self.timer.record_call() if isinstance(result, _SubprocExceptionInfo): # An exception occurred in the submitted job self.pending_futures[job_id].set_exception( @@ -322,6 +333,8 @@ def shutdown(self) -> None: with self.write_lock: if not self.running: return + if self.timer: + self.timer.quit() self.running = False self.running_waitcounter.__exit__() _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index d4b0c0dc9e28..7c495403b3a5 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Callable, Optional, Union +from typing import Optional, Union class Timer: @@ -17,7 +18,7 @@ def __init__( self.background_thread: Optional[Thread] = None self.last_called: Optional[float] = None self.duration = duration - self.sleep_time = 60 + self.sleep_time = duration / 2 self.call = call self.exit = False diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index b78ade758f80..2d9e180db54f 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -379,6 +379,15 @@ def prologue_fusion_enabled() -> bool: # for built-in passes, use string name; for user-defined passes, pass in the function handle # WARNING: Inductor scheduler IR is at prototype stage and subject to change, # hence custom IR passes built on top of it might break in the future. +# +# See aten_distributed_optimizations, it is recommended way for distributed optimizations. +# +# Recommended configuration for reorder_for_compute_comm_overlap_passes: +# [ +# "reorder_communication_preserving_peak_memory", +# "sink_waits_iterative", +# "reorder_communication_preserving_peak_memory", +# ] reorder_for_compute_comm_overlap_passes: list[ Union[ str, @@ -387,11 +396,7 @@ def prologue_fusion_enabled() -> bool: list["torch._inductor.scheduler.BaseSchedulerNode"], ], ] -] = [ - "reorder_compute_for_overlap", - "sink_waits", - "raise_comms", -] +] = [] # Maximum number of positions to advance a given collective, unlimited by default reorder_prefetch_limit: Optional[int] = None @@ -407,16 +412,6 @@ def prologue_fusion_enabled() -> bool: # is zero, which turns off this optimization. size_threshold_for_succ_based_strategy: int = 0 -reorder_iterative_debug_memory_recompute: bool = False -reorder_iterative_debug_limit_to_reorder: Optional[int] = ( - None - if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None - else int(env_str) -) -sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( - # pyrefly: ignore [unbound-name] - None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) -) bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none" # By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used @@ -546,10 +541,6 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] -cutedsl_enable_autotuning: bool = ( - os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" -) - # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False @@ -678,6 +669,17 @@ def use_autoheuristic(name: str) -> bool: == "1" ) + +# When trying to fuse two nodes, one with: +# a[contiguous_writes] = fn(...) +# and another node: +# b[contiguous_writes] = a[discontiguous_reads] +# If b is unary, and we can figure out an inverse formula for +# discontiguous writes, invert b as : +# b[inverse(discontiguous_writes)] = a[contiguous_reads] +# so that the nodes can fuse. for more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9 +loop_index_inversion_in_fusion: bool = True + # If fusing two nodes only save less then score_fusion_memory_threshold memory, # we should not bother fusing the nodes. # @@ -964,6 +966,11 @@ def decide_compile_threads() -> int: default=False, ) +# Time in seconds to wait before quiescing +quiesce_async_compile_time: int = Config( + default=60, +) + # Whether or not to enable statically launching CUDA kernels # compiled by triton (instead of using triton's own launcher) use_static_cuda_launcher: bool = static_cuda_launcher_default() @@ -1949,8 +1956,9 @@ class rocm: # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" -# Backend to use for CUDA codegen either "triton" or "halide" (experimental) -cuda_backend: Literal["triton", "halide"] = "triton" +# Backend to use for CUDA codegen either +# "triton", "halide" (experimental) or "pallas" (experimental) +cuda_backend: Literal["triton", "halide", "pallas"] = "triton" # Backend to use for XPU codegen either "triton" xpu_backend: Literal["triton"] = "triton" diff --git a/torch/_inductor/config_comms.py b/torch/_inductor/config_comms.py index b5dbf424f35b..31f38b867dd5 100644 --- a/torch/_inductor/config_comms.py +++ b/torch/_inductor/config_comms.py @@ -1,4 +1,6 @@ +import os import sys +from typing import Optional from torch.utils._config_module import install_config_module @@ -11,5 +13,59 @@ # decisions on different distributed ranks. runtime_estimations_align_across_all_distributed_ranks: bool = False +reorder_iterative_debug_memory_recompute: bool = False +reorder_iterative_debug_limit_to_reorder: Optional[int] = ( + None + # pyrefly: ignore[unbound-name] + if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None + else int(env_str) +) +sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( + # pyrefly: ignore[unbound-name] + None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) +) + + +# Should be used with config.runtime_estimations_mms_benchmark = True +reorder_iterative_use_runtime_estimations: bool = False +sink_iterative_use_runtime_estimations: bool = False + +# Broadcast runtime estimations doing real Collective operation between all ranks. +# If non-deterministic runtime estimations are used this must be used to make +# all ranks to do identical decisions and prevent global Collectives reordering, +# (that will result un NCCL hangs) +reorder_for_compute_comm_overlap_broadcast_runtime_estimations: bool = False + +# Block of Ratios to workaround imperfection of current runtime estimations +# for collectives and compute for different scenarios. +# Multiplier of collectives estimated durations +reorder_sink_runtime_estimations_comm_mult: float = 2.0 +# Multiplier of compute estimated durations +reorder_sink_runtime_estimations_non_comm_mult: float = 1.0 +# The reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive overlap +reorder_iterative_extra_comm_comp_overlap: float = 0.5 +# The sink waits reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive sink waits +sink_iterative_extra_comm_comp_overlap: float = 0.5 + +# Allow reorder iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +reorder_iterative_peak_memory_budget: float = 0.2 +# Allow sink waits iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +sink_iterative_peak_memory_budget: float = 0.2 + +# Experimental unsafe configuration that allows changing relative collectives order. +# Must be used with runtime_estimations_align_across_all_distributed_ranks = True +reorder_iterative_unsafe_collectives_reorder: bool = True +sink_waits_iterative_unsafe_collectives_reorder: bool = True + +# Allow group and move other collectives during reordering +reorder_iterative_group_with_collectives: bool = False +sink_waits_iterative_swap_with_collectives: bool = False + # adds patch, save_config, etc install_config_module(sys.modules[__name__]) diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 515f628c9938..1c4a394d1eb2 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -430,7 +430,7 @@ def get_isa_from_cpu_capability( "avx2": "avx2", "avx512": "avx512", } - if capability in capability_to_isa_str.keys(): + if capability in capability_to_isa_str: # pyrefly: ignore [index-error] isa_str = capability_to_isa_str[capability] if isa_str == "INVALID_VEC_ISA": diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 668becdded46..50d986d48e6c 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -192,7 +192,7 @@ def check_multiple_devices_or_any_cpu_nodes( ): return None - keys_repr = (repr(key) for key in device_node_mapping.keys()) + keys_repr = (repr(key) for key in device_node_mapping) return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index ab831c96c94b..29f070564349 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,7 +2,8 @@ import logging import operator from collections import defaultdict -from typing import Any, Callable, Literal, TypeAlias +from collections.abc import Callable +from typing import Any, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 8a4de1a60486..44314b912786 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass from functools import partial -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 6b0c2ad2c94a..1e71c350ed7b 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index c6b7c51b948e..e887d4bf62c8 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 70b3a3c355dd..214d3bf02f7f 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, Callable +from typing import Any, TYPE_CHECKING import torch from torch._dynamo.utils import counters @@ -35,6 +35,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable + + if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index a47aa960e58c..f383ab63dc26 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 30768fda9bb7..b511403d4874 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Sequence -from typing import Any, Callable +from collections.abc import Callable, Sequence +from typing import Any import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index f11817e1d4c5..91b4e10bf723 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,7 +5,8 @@ import logging import operator from collections import Counter, defaultdict -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch @@ -51,8 +52,8 @@ decode_device, get_all_devices, get_gpu_type, + has_uses_tagged_as, is_gpu, - is_pointwise_use, OPTIMUS_EXCLUDE_POST_GRAD, ) from ..virtualized import V @@ -1510,8 +1511,10 @@ def should_prefer_unfused_addmm(match): if not is_gpu(inp.meta["val"].device.type): return False - output = match.output_node() - return all(is_pointwise_use(use) for use in output.users) + return has_uses_tagged_as( + match.output_node(), + (torch.Tag.pointwise, torch.Tag.reduction), + ) @register_graph_pattern( diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 52222f3da834..e42e8a113977 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 92e1e6f375f4..0bad4fa7cc63 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,9 +4,8 @@ import operator import os from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Callable -from typing_extensions import TypeAlias +from collections.abc import Callable, Sequence +from typing import Any, TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 2e89ea5ca461..28e7f88d3398 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1590,7 +1590,7 @@ def maybe_propagate( schema_kwargs = {arg.name: arg for arg in schema.arguments} - for key in old_kwargs.keys(): + for key in old_kwargs: old_arg = old_kwargs[key] new_arg = new_kwargs[key] schema_arg = schema_kwargs[key] diff --git a/torch/_inductor/invert_expr_analysis.py b/torch/_inductor/invert_expr_analysis.py new file mode 100644 index 000000000000..816482dba020 --- /dev/null +++ b/torch/_inductor/invert_expr_analysis.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass +from typing import Optional + +import sympy + +from torch._inductor.utils import _IntLike, argsort_sym +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from .virtualized import V + + +def static_eq(a: _IntLike, b: _IntLike) -> bool: + return V.graph.sizevars.statically_known_equals(a, b) + + +@dataclass +class Term: + coefficient: _IntLike + range: Optional[_IntLike] # None for unbounded + original_expr: sympy.Expr + reconstruction_multiplier: _IntLike # The multiplier needed for reconstruction + + +def generate_inverse_formula( + expr: sympy.Expr, var: sympy.Symbol +) -> Optional[sympy.Expr]: + """ + Analyze an expression to see if it matches a specific invertible pattern that we + know how to reverse. + + We're looking for expressions that are sums of terms where each term extracts a + distinct bounded range from the input variable, like: + + y = cβ‚€*aβ‚€ + c₁*a₁ + cβ‚‚*aβ‚‚ + ... + cβ‚™*aβ‚™ + + where each aα΅’ must be one of these specific patterns: + - ModularIndexing(var, divisor, modulo) + - FloorDiv(ModularIndexing(var, 1, modulo), divisor) + - FloorDiv(var, divisor) + - var (the variable itself) + + The key pattern we need is: + - Coefficients are strictly decreasing: cβ‚€ > c₁ > cβ‚‚ > ... > cβ‚™ + - Each coefficient matches the product of ranges of later terms (mixed-radix property) + - Each term extracts a bounded range, creating non-overlapping "slots" + + If we find this pattern, we can generate the reconstruction transformation that + decomposes the variable and rebuilds it using the correct multipliers. + + EXAMPLE: + Input: 100*((p//100)) + 10*((p%100)//10) + (p%10) + + Returns the reconstruction expression: + remainderβ‚€ = p + componentβ‚€ = remainderβ‚€ // 100 # hundreds digit + remainder₁ = remainderβ‚€ % 100 + component₁ = remainder₁ // 10 # tens digit + remainderβ‚‚ = remainder₁ % 10 + componentβ‚‚ = remainderβ‚‚ # ones digit + result = componentβ‚€*100 + component₁*10 + componentβ‚‚*1 + + This decomposes p into its components and rebuilds it using the original + multipliers, which should equal the input expression. + + Args: + expr: Expression to analyze (sum of terms with ModularIndexing, FloorDiv, etc.) + var: The variable being decomposed + + Returns: + None if not invertible, or the reconstruction expression + + References: + Mixed-radix systems: https://en.wikipedia.org/wiki/Mixed_radix + """ + # Step 1: Parse all terms + terms = parse_terms(expr, var) + if not terms: + return None + + # Step 2: Sort by coefficient (descending) + coeffs = [t.coefficient for t in terms] + idxs = reversed(argsort_sym(V.graph.sizevars.shape_env, coeffs)) + terms = [terms[i] for i in idxs] + + # Step 3: Check invertibility conditions + if not check_invertibility(terms): + return None + + return generate_reconstruction_expr(terms, var) + + +def parse_terms(expr: sympy.Expr, var: sympy.Symbol) -> Optional[list[Term]]: + """Parse expression into terms.""" + if not isinstance(expr, sympy.Add): + # Single term + term = parse_single_term(expr, var) + return [term] if term else [] + + terms = [] + for arg in expr.args: + term = parse_single_term(arg, var) + if term: + terms.append(term) + else: + return None # If any term fails to parse, fail completely + + return terms + + +def parse_single_term(term: sympy.Expr, var: sympy.Symbol) -> Optional[Term]: + """Parse a single term and extract coefficient, range, and reconstruction multiplier.""" + # Extract coefficient and expression parts + coefficient, expr_parts = term.as_coeff_mul() + + if len(expr_parts) == 0: + # Pure constant term + return Term( + coefficient=coefficient, + range=1, + original_expr=1, + reconstruction_multiplier=0, + ) + elif len(expr_parts) == 1: + expr = expr_parts[0] + else: + # Multiple non-constant factors, too complex + return None + + # Now determine the range and reconstruction multiplier + range_val, reconstruction_multiplier = analyze_expression_properties(expr, var) + if reconstruction_multiplier is None: + return None + + return Term( + coefficient=coefficient, + range=range_val, + original_expr=expr, + reconstruction_multiplier=reconstruction_multiplier, + ) + + +def analyze_expression_properties( + expr: sympy.Expr, var: sympy.Symbol +) -> tuple[Optional[_IntLike], Optional[_IntLike]]: + """Analyze an expression to determine its range and reconstruction multiplier.""" + # ModularIndexing(var, divisor, modulo) = (var // divisor) % modulo + if isinstance(expr, ModularIndexing): + x, div, mod = expr.args + if static_eq(x, var): + return mod, div # Range is mod, multiplier is div + + # FloorDiv cases + if isinstance(expr, FloorDiv): + base, divisor = expr.args + + # FloorDiv(ModularIndexing(var, 1, mod), div) = (var % mod) // div + if isinstance(base, ModularIndexing): + x, inner_div, mod = base.args + if static_eq(x, var) and static_eq(inner_div, 1): + range_val = FloorDiv(mod, divisor) + return range_val, divisor # Range is mod//div, multiplier is div + + # FloorDiv(var, divisor) = var // divisor (unbounded) + elif static_eq(base, var): + return None, divisor # Unbounded range, multiplier is div + + return None, None + + +def check_invertibility(terms: list[Term]) -> bool: + """Check if the terms represent an invertible transformation.""" + if not terms: + return False + + # Coefficients must be strictly decreasing + coeffs = [t.coefficient for t in terms] + if argsort_sym(V.graph.sizevars.shape_env, coeffs) != list( + reversed(range(len(coeffs))) + ): + return False + + # Check mixed-radix property: each coeff[i] = coeff[i+1] * range[i+1] + expected_coeff = 1 + for term in reversed(terms): + if not static_eq(term.coefficient, expected_coeff): + return False + if term.range is not None: + expected_coeff *= term.range + + return True + + +def generate_reconstruction_expr(terms: list[Term], var: sympy.Symbol) -> sympy.Expr: + y = var + reconstruction = sympy.S.Zero + remainder = y + + for i, term in enumerate(terms): + if i < len(terms) - 1: + component = FloorDiv(remainder, term.coefficient) + remainder = ModularIndexing(remainder, 1, term.coefficient) + else: + # Last term should also divide by its coefficient + component = FloorDiv(remainder, term.coefficient) + + reconstruction += component * term.reconstruction_multiplier + + return reconstruction diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b1a3071cb7ba..53c12d072604 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1534,7 +1534,7 @@ def py_cnst(val: object) -> Union[bool, float, int]: # "all" is desugared to `!any(!val)` } - assert reduction_type in rtypes_to_inits.keys(), ( + assert reduction_type in rtypes_to_inits, ( f"{reduction_type} not supported for zero-dimension tensors!" ) diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 303110a561b5..23878f757cc5 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,9 +2,11 @@ import functools import logging -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch +from torch._inductor import config from torch._inductor.codegen.subgraph import SubgraphTemplate from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox from torch._inductor.lowering import lowerings, validate_ir @@ -157,7 +159,6 @@ def _adapt_user_input_gen_fns( Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes. """ - from torch._inductor import config name_to_index = {name: i for i, name in enumerate(arg_names)} index_based_fns = {} @@ -237,6 +238,7 @@ def autotune_custom_op( This function generates multiple implementation choices for a custom operation and uses Inductor's autotuning system to select the best performing variant at runtime. + After selecting the best choice, applies inline fusion if the winning choice has a graph. Args: name: Unique identifier for the autotuning operation @@ -319,14 +321,34 @@ def autotune_custom_op( ) input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) - return autotune_select_algorithm( + # Run autotuning and get both result and winning choice + selected_result, winning_choice = autotune_select_algorithm( name=name, choices=choices, input_nodes=list(inputs), layout=choices[0].layout, input_gen_fns=input_gen_fns, + return_choice=True, ) + # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) + if winning_choice.gm is not None: + log.debug( + "Inlining winning choice: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes + + return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name) + + log.debug( + "Winning choice does not support inlining: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + return selected_result + def register_custom_op_autotuning( custom_op: torch._library.custom_ops.CustomOpDef, @@ -359,7 +381,7 @@ def my_attention(query, key, value, head_dim=32): "query": lambda fake: torch.randn_like(fake, device='cuda'), "key": lambda fake: torch.randn_like(fake, device='cuda'), "value": lambda fake: torch.randn_like(fake, device='cuda'), - } + }, ) """ from torch._library.custom_ops import CustomOpDef @@ -377,12 +399,12 @@ def my_attention(query, key, value, head_dim=32): raise TypeError(f"configs must be a list or tuple, got {type(configs)}") processed_configs = [] - for config in configs: - if isinstance(config, CustomOpConfig): - processed_configs.append(config) + for cfg in configs: + if isinstance(cfg, CustomOpConfig): + processed_configs.append(cfg) else: raise TypeError( - f"Each config must be a CustomOpConfig object, got {type(config)}" + f"Each config must be a CustomOpConfig object, got {type(cfg)}" ) if not processed_configs: @@ -401,14 +423,12 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: decompositions = [] non_tensor_args = [] - for config in processed_configs: - decomp = config.get_decomposition(default_impl=default_impl) + for cfg in processed_configs: + decomp = cfg.get_decomposition(default_impl=default_impl) decompositions.append(decomp) # Merge config params with runtime kwargs (runtime takes precedence) - merged_kwargs = _merge_config_and_runtime_kwargs( - config.params, runtime_kwargs - ) + merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs) non_tensor_args.append(merged_kwargs) result = autotune_custom_op( diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index c100df84d5a7..0d3721aa730a 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,8 +3,9 @@ import functools import importlib +from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index eb22b95af2af..b95073e769f3 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from functools import partial -from pathlib import Path from typing import Any import torch @@ -14,7 +12,6 @@ from .. import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import load_template log = logging.getLogger(__name__) @@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True - - -_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" -load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 0a44b728a5a9..881c14fd43d0 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-defs import logging -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters -from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -19,25 +18,19 @@ TritonTemplate, ) from ..utils import ( - ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, - use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, - load_kernel_template, persistent_grouped_mm_grid, ) -if ensure_cute_available(): - from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs - log = logging.getLogger(__name__) aten = torch.ops.aten @@ -520,11 +513,6 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) -cutedsl_grouped_mm_template = CuteDSLTemplate( - name="grouped_gemm_cutedsl", - source=load_kernel_template("cutedsl_mm_grouped"), -) - def grouped_mm_args( mat1: TensorBox, @@ -726,44 +714,43 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False - if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -801,22 +788,6 @@ def _tuned_grouped_mm_common( **config.kwargs, ) - if use_blackwell_cutedsl_grouped_mm( - mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result - ): - for config in get_groupgemm_configs(): - kwargs = dict( - ACC_DTYPE="cutlass.Float32", - ) - - cutedsl_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - **asdict(config), - ) - input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja deleted file mode 100644 index 989f297c5f80..000000000000 --- a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja +++ /dev/null @@ -1,333 +0,0 @@ -import functools -from torch._inductor.runtime.runtime_utils import ceildiv -from cutlass.utils import TensorMapUpdateMode -{{gen_defines()}} -# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- -from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( - GroupedGemmKernel, -) - - -# Note about caching: -# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor -# maintains its own local caching system. At this stage, all compile-time -# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel -# name itself ({{kernel_name}}) are permanently baked into the file, so they -# do not need to be included in any cache key. -# -# The caching mechanism is split into two levels: -# -# 1. prep_cache -# Caches the compiled executor for build_group_ptrs_from_bases(). This -# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, -# and can therefore be safely reused across runs with different group -# partitioning (`offs`). -# -# 2. gemm_cache -# Caches the compiled Grouped GEMM executor. Its key extends the prep -# cache key with hardware- and grid-specific parameters: -# (prep_cache_key, max_active_clusters, total_num_clusters). -# This is necessary because different `offs` tensors can change the -# per-group problem sizes and thus alter `total_num_clusters`, which in -# turn changes the grid shape and persistent scheduler configuration. -# Kernels compiled for one grid cannot be safely reused for another. -# -# -# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, -# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, -# despite depending only on the GPU type. We cache this function to mitigate -# redundant recompiles even when shape/stride/dtype cache misses force kernel -# regeneration. A follow-up study will investigate the root cause. - -prep_cache = {} -gemm_cache = {} - - -@functools.lru_cache -def get_hardware_info(): - hw = cutlass.utils.HardwareInfo() - sm_count = hw.get_max_active_clusters(1) - max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) - - return (sm_count, max_active_clusters) - - -def get_prep_cache_key(input_a, input_b, output): - """ - Returns a tuple key for caching the preprocessing kernel executor based on kernel name, - shapes, strides, and dtypes of input/output tensors. - """ - return ( - tuple(input_a.shape), - tuple(input_a.stride()), - input_a.dtype, - tuple(input_b.shape), - tuple(input_b.stride()), - input_b.dtype, - tuple(output.shape), - tuple(output.stride()), - output.dtype, - ) - - -def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): - """ - Returns a tuple key for caching the gemm kernel executor by extending the - prep cache key with hardware- and grid-specific parameters. - """ - return ( - prep_cache_key, - max_active_clusters, - total_num_clusters, - ) - - -@cute.kernel -def build_group_ptrs_from_bases_kernel( - base_A_u64: cutlass.Int64, # device addr of input_a (bytes) - base_B_u64: cutlass.Int64, # device addr of input_b (bytes) - base_C_u64: cutlass.Int64, # device addr of Output (bytes) - offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Int32, # bytes - # -------- STRIDES (in ELEMENTS) -------- - stride_A_m_elems: cutlass.Constexpr, # A.stride(0) - stride_A_k_elems: cutlass.Constexpr, # A.stride(1) - stride_B0_elems: cutlass.Constexpr, # B.stride(0) - stride_Bk_elems: cutlass.Constexpr, # B.stride(1) - stride_Bn_elems: cutlass.Constexpr, # B.stride(2) - stride_C_m_elems: cutlass.Constexpr, # C.stride(0) - stride_C_n_elems: cutlass.Constexpr, # C.stride(1) - # -------- OUTPUTS -------- - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) - out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) - out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] -): - tidx, _, _ = cute.arch.thread_idx() - g = tidx - - m_beg_i32 = 0 - if g > 0: - m_beg_i32 = offs[g - 1] - m_end_i32 = offs[g] - m_g_i32 = m_end_i32 - m_beg_i32 - - a_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) - ) - c_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) - ) - b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) - - # ---- pointers ---- - out_ptrs[g, 0] = base_A_u64 + a_byte_off - out_ptrs[g, 1] = base_B_u64 + b_byte_off - out_ptrs[g, 2] = base_C_u64 + c_byte_off - - # ---- (m, n, k, 1) ---- - out_problem[g, 0] = m_g_i32 - out_problem[g, 1] = N - out_problem[g, 2] = K - out_problem[g, 3] = cutlass.Int32(1) - - # ---- strides ---- - out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) - out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) - out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) - out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) - out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) - out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) - - -@cute.jit -def launch_build_group_ptrs_from_bases( - base_A_u64: cutlass.Int64, - base_B_u64: cutlass.Int64, - base_C_u64: cutlass.Int64, - offs: cute.Tensor, - G: cutlass.Constexpr, - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Constexpr, - stride_A_m_elems: cutlass.Constexpr, - stride_A_k_elems: cutlass.Constexpr, - stride_B0_elems: cutlass.Constexpr, - stride_Bk_elems: cutlass.Constexpr, - stride_Bn_elems: cutlass.Constexpr, - stride_C_m_elems: cutlass.Constexpr, - stride_C_n_elems: cutlass.Constexpr, - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 - out_problem: cute.Tensor, # [G,4] cutlass.Int32 - out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 - stream: cuda.CUstream, -): - build_group_ptrs_from_bases_kernel( - base_A_u64, - base_B_u64, - base_C_u64, - offs, - K, - N, - sizeof_element, - stride_A_m_elems, - stride_A_k_elems, - stride_B0_elems, - stride_Bk_elems, - stride_Bn_elems, - stride_C_m_elems, - stride_C_n_elems, - out_ptrs, - out_problem, - out_strides_abc, - ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) - - -{{def_kernel("input_a", "input_b", "input_a_offs")}} - stream = cuda.CUstream(stream) - - input_b = input_b.transpose(1, 2) - - sumM, K = input_a.shape - G, N, Kb = input_b.shape - - dev = input_a.device - - base_A_u64 = int(input_a.data_ptr()) - base_B_u64 = int(input_b.data_ptr()) - base_C_u64 = int({{get_output()}}.data_ptr()) - - ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) - probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) - strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) - ptrs = from_dlpack(ptrs_t) - probs = from_dlpack(probs_t) - strides = from_dlpack(strides_t) - - prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) - prep_executor = prep_cache.get(prep_cache_key) - - if prep_executor is None: - sizeof_element = int(input_a.element_size()) - sA_m, sA_k = map(int, input_a.stride()) - sB_0, sB_n, sB_k = map(int, input_b.stride()) - sC_m, sC_n = map(int, {{get_output()}}.stride()) - - prep_executor = cute.compile( - launch_build_group_ptrs_from_bases, - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - G=int(G), - K=int(K), - N=int(N), - sizeof_element=sizeof_element, - stride_A_m_elems=sA_m, - stride_A_k_elems=sA_k, - stride_B0_elems=sB_0, - stride_Bk_elems=sB_k, - stride_Bn_elems=sB_n, - stride_C_m_elems=sC_m, - stride_C_n_elems=sC_n, - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - prep_cache[prep_cache_key] = prep_executor - - prep_executor( - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - # --- Tensormap workspace per SM --- - num_tensormap_buffers, max_active_clusters = get_hardware_info() - tensormap_shape = ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ) - tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) - tensormap_workspace = from_dlpack(tensormap_workspace_t) - - # --- Total clusters --- - def compute_total_num_clusters( - problem_sizes_mnkl, - cluster_tile_shape_mn, - ): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) - ) - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - # Compute cluster tile shape - def compute_cluster_tile_shape( - mma_tiler_mn, - cluster_shape_mn, - use_2cta_instrs, - ): - cta_tile_shape_mn = list(mma_tiler_mn) - if use_2cta_instrs: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cluster_tile_shape_mn = compute_cluster_tile_shape( - (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) - ) - - total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) - - gemm_cache_key = get_gemm_cache_key( - prep_cache_key, max_active_clusters, total_num_clusters - ) - gemm_executor = gemm_cache.get(gemm_cache_key) - - if gemm_executor is None: - grouped_gemm = GroupedGemmKernel( - acc_dtype=ACC_DTYPE, - use_2cta_instrs=USE_2_CTA, - mma_tiler_mn=(TILE_M, TILE_N), - cluster_shape_mn=(CLUSTER_M, CLUSTER_N), - tensormap_update_mode=TENSORMAP_UPDATE_MODE, - ) - - gemm_executor = cute.compile( - grouped_gemm, - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - G, - probs, - strides, - ptrs, - total_num_clusters, - tensormap_workspace, - max_active_clusters, - stream, - ) - - gemm_cache[gemm_cache_key] = gemm_executor - - gemm_executor( - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - probs, - strides, - ptrs, - tensormap_workspace, - stream, - ) diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 53ae1d8f63f6..3921aa955a83 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -95,7 +95,6 @@ class LoopBody: """ indexing_exprs: dict[str, sympy.Expr] - indexing_exprs_name: dict[sympy.Expr, str] submodules: dict[str, Any] subblocks: dict[str, LoopBodyBlock] indirect_vars: list[sympy.Symbol] @@ -104,6 +103,9 @@ class LoopBody: memory_usage: dict[MemoryUsageType, list[MemoryEntry]] op_counts: collections.Counter[str] + # defined only temporarily + indexing_exprs_name: dict[sympy.Expr, str] + def __init__( self, fn, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index cc13f7990901..f6ad1028ca12 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -7307,6 +7307,35 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): return list(map(TensorBox.create, result)) # type: ignore[call-overload] +def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): + """Process nodes from a FX graph by executing them through V.graph. + + This is a common pattern for executing a subgraph's nodes: + - Placeholder nodes are mapped to the provided args + - Output nodes return their result + - Other nodes are executed via V.graph.run_node + + """ + output = None + + for i, node in enumerate(graph_module.graph.nodes): + if node.op == "placeholder": + assert node not in V.graph.env + V.graph.env[node] = args[i] + continue + elif node.op == "output": + output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs) + else: + assert node not in V.graph.env + V.graph.env[node] = V.graph.run_node(node) + + if output is None: + raise RuntimeError("No output node found in graph") + + return output + + # Import the control_deps_op HOP for lowering from torch._inductor.fx_passes.control_dependencies import control_deps @@ -7334,21 +7363,11 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args): arg_offset = 2 # first two args (additional_deps, subgraph) assert len(args) + arg_offset == len(original_args) - output = None - operation_len = len(V.graph.operations) assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args) - for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): - if node.op == "placeholder": - assert node not in V.graph.env - V.graph.env[node] = args[i] - continue - elif node.op == "output": - args, kwargs = V.graph.fetch_args_kwargs_from_env(node) - output = torch.fx.Interpreter.output(V.graph, node, args, kwargs) - else: - assert node not in V.graph.env - V.graph.env[node] = V.graph.run_node(node) + + # Process subgraph nodes using the shared helper + output = process_subgraph_nodes(subgraph_fn.graph_module, list(args)) assert output is not None and additional_deps diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 6f58b683ac22..ed223de71c07 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -229,7 +229,7 @@ def assign_memory_planning_info_for_scheduler_buffers( # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) - for buf_name in name_to_buf.keys(): + for buf_name in name_to_buf: name_to_buf[buf_name].mpi_buffer = MemoryPlanningInfoForBuffer( size_alloc=sched_buf_to_size[buf_name][0], size_free=sched_buf_to_size[buf_name][1], diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index d592a8c8c00f..d9d92e363879 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Optional, Union -from typing_extensions import Concatenate, ParamSpec, Self, TypeVar +from typing import Any, Concatenate, Optional, Union +from typing_extensions import ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/config.py b/torch/_inductor/runtime/caching/config.py index 748715d1631a..14e13f937dbb 100644 --- a/torch/_inductor/runtime/caching/config.py +++ b/torch/_inductor/runtime/caching/config.py @@ -1,6 +1,6 @@ import os +from collections.abc import Callable from functools import cache, partial -from typing import Callable import torch from torch._environment import is_fbcode diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 0758e1113401..03d295749367 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import override, TypeAlias +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override from filelock import FileLock @@ -21,6 +21,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index e7e1f1adc362..8e8cd011e2d4 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import Generator, TYPE_CHECKING -from typing_extensions import Protocol, TypeAlias +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol from filelock import FileLock, Timeout @@ -21,6 +21,7 @@ if TYPE_CHECKING: + from collections.abc import Generator from threading import Lock diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 341475ef1d6f..7ea22bdcddf0 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -5,6 +5,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING +from torch.utils._ordered_set import OrderedSet + from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -54,6 +56,7 @@ def __init__( name="unknown", size_hints=None, inductor_meta=None, + frozen_fields=None, ): self.is_mm = is_mm # we will tune num_stages for mm @@ -66,6 +69,9 @@ def __init__( self.name = name self.size_hints = size_hints self.inductor_meta = inductor_meta or {} + self.frozen_fields: OrderedSet[str] = ( + OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet() + ) def get_config_max(self, prefix: str) -> int: max_block = TRITON_MAX_BLOCK[prefix.upper()] @@ -117,7 +123,7 @@ def tunable_fields(self): out.append("num_stages") out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul - return out + return [f for f in out if f not in self.frozen_fields] def value_too_large(self, name: str, val: int) -> bool: block_suffix = "BLOCK" diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index fe6788fb21e9..2e0a0dba9092 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -336,6 +336,7 @@ def __init__( name=self.fn.__name__, size_hints=size_hints, inductor_meta=self.inductor_meta, + frozen_fields=self.get_coordesc_frozen_fields(), ) self.filename = filename @@ -365,6 +366,13 @@ def __init__( # Mode for launch grid calculation self.grid_mode: Literal["python", "cpp"] = "python" + def get_coordesc_frozen_fields(self) -> OrderedSet[str]: + out: OrderedSet[str] = OrderedSet() + if self.inductor_meta.get("RSPLIT_SIZE"): + # We fix XBLOCK for mix order reduction + out.add("XBLOCK") + return out + def is_statically_launchable(self): """ Checks if every compiled kernel is statically launchable, which @@ -1843,6 +1851,8 @@ def make_launcher(self) -> LauncherType: else ( (binary.metadata.num_ctas, *binary.metadata.cluster_dims) if hasattr(binary, "metadata") + and hasattr(binary.metadata, "num_ctas") + and hasattr(binary.metadata, "cluster_dims") else () ) ), @@ -3578,13 +3588,24 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + + # Naive autotuning path for num_warps + if not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index df1d2f729b34..2930a33b465a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3345,7 +3345,10 @@ def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: ) break - if config.loop_ordering_after_fusion: + if ( + config.loop_ordering_after_fusion + or config.loop_index_inversion_in_fusion + ): nodes = self.fuse_nodes_once(nodes, is_reorder_round=True) return nodes @@ -4302,6 +4305,148 @@ def decide_fusion_fail_reason( return str(reasons) + def shared_data_after_inverting_indexing( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Attempts to enable fusion between two nodes by inverting indexing patterns. + + This optimization targets cases where node1 has a contiguous write and + node2 has a contiguous write but discontiguous read. By inverting the + indexing in node2's read and write operations, we can make them compatible + with node1 for potential fusion. + + Args: + node1: First scheduler node (source) + node2: Second scheduler node (target for inversion) + + Returns: + int: Fusion score if successful, 0 if optimization not applicable + """ + + if not config.loop_index_inversion_in_fusion: + return -1 + + if any(n.is_cpu() for n in [node1, node2]): + return -1 + + # Check for shared buffers between nodes + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + common_buffer_names = node1_buffer_names & node2_buffer_names + + if not common_buffer_names: + return -1 + + # only invert if node1 is single unmet dep + node2_unmet_dependencies = OrderedSet( + dep.name for dep in node2.unmet_dependencies + ) + if node2_unmet_dependencies - node1_buffer_names: + return -1 + + if len(node2_unmet_dependencies) > 1: + return -1 + + # Currently only handle single read/write operations + if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1: + return -1 + + node2_read = next(iter(node2.read_writes.reads)) + node2_write = next(iter(node2.read_writes.writes)) + + if not isinstance(node2_read, MemoryDep) or not isinstance( + node2_write, MemoryDep + ): + return -1 + + node1_writes = {dep.name: dep for dep in node1.read_writes.writes} + if node2_read.name not in node1_writes: + return -1 + + node1_write = node1_writes[node2_read.name] + + if not isinstance(node1_write, MemoryDep): + return -1 + + # We are checking for compatibility with the normalized node1 write + # then modifying node2 reads/writes. since the node1 write will be just used + # for compatibility, while node2 will be used in actual modification, just + # normalize node1 not node2. + node1_write = node1_write.normalize() + + if ( + node1_write.index != node2_write.index + and node1_write.size != node2_write.size + ): + return -1 + + if node2_read.size != node2_write.size or len(node2_read.var_names) != 1: + return -1 + + # Verify we have exactly two indexing expressions (one read, one write) + if len(node2._body.indexing_exprs) != 2: # type: ignore[attr-defined] + return -1 + + # No subblocks allowed for this optimization + if node2._body.subblocks: # type: ignore[attr-defined] + return -1 + + assert ( + "index0" in node2._body.indexing_exprs # type: ignore[attr-defined] + and "index1" in node2._body.indexing_exprs # type: ignore[attr-defined] + ) + + # Extract and verify single read expression + node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs()) # type: ignore[attr-defined] + if len(node2_read_exprs) != 1: + return -1 + + read_expr = next(iter(node2_read_exprs)) + + # Determine which index is for reading vs writing + if read_expr == node2._body.indexing_exprs["index0"]: # type: ignore[attr-defined] + read_expr_index = "index0" + write_expr_index = "index1" + else: + assert read_expr == node2._body.indexing_exprs["index1"] # type: ignore[attr-defined] + read_expr_index = "index1" + write_expr_index = "index0" + + from torch._inductor.invert_expr_analysis import generate_inverse_formula + + index_vars = node2._body.vars[0] # type: ignore[attr-defined] + if len(index_vars) != 1: + return -1 + + simplified_terms = [] + for term in sympy.Add.make_args(read_expr): + simplified_terms.append( + V.graph.sizevars.combine_modular_indexing_pairs(term) + ) + simplified_read_expr = sum(simplified_terms) + + inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0]) + + # formula is not invertible + if inverse_formula is None: + return -1 + + # === Apply Inversion === + + # Swap the indexing expressions using the inverse formula + node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[ # type: ignore[attr-defined] + write_expr_index + ] + node2._body.indexing_exprs[write_expr_index] = inverse_formula # type: ignore[attr-defined] + + # Refresh dependencies and calculate fusion score + node2.refresh_dependencies(True, False) # type: ignore[attr-defined] + score = self.score_fusion_memory(node1, node2) + + fusion_log.info("Shared memory after inversion: %d", score) + return score + def shared_data_after_reordering_loop( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> int: @@ -4686,6 +4831,7 @@ def can_fuse( del device2 shared_data_score = self.score_fusion_memory(node1, node2) + if ( can_reorder and shared_data_score < config.score_fusion_memory_threshold @@ -4702,6 +4848,16 @@ def can_fuse( smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size) shared_data_score = self.score_fusion_memory(node1, node2) + if ( + config.loop_index_inversion_in_fusion + and shared_data_score < config.score_fusion_memory_threshold + ): + new_shared_data_score = self.shared_data_after_inverting_indexing( + node1, node2 + ) + if new_shared_data_score >= 0: + shared_data_score = new_shared_data_score + if loop_ordering_log.isEnabledFor(logging.DEBUG): loop_ordering_log.debug( "%s and %s has %s shared data", diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index dc4be650eccb..e1d36d54e844 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2145,6 +2145,8 @@ def __init__( # There is no src hash for ExternKernelChoice in the traditional sense # so we indicate this by returning None self.src_hash = None + # By default GraphModule is None for extern kernels if not set + self.gm = None def to_callable(self): return getattr(extern_kernels, self.name) @@ -2317,6 +2319,7 @@ def __init__( self.choice = choice self.kwargs = kwargs or {} self.has_out_variant = has_out_variant + self.gm = choice.gm def __str__(self) -> str: return f"ExternKernelCaller({self.choice.call_name()})" @@ -2700,6 +2703,7 @@ def __call__( precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, best_config_future=None, + return_choice=False, # TODO: return_choice is temporary and will be refactored soon ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2973,18 +2977,25 @@ def get_timings(hint_override: Optional[int] = None): "Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s", node, ) + if return_choice: + return node, choice return node node = choices[0].output_node() + choice = choices[0] log.debug( "Autotuning returned empty timings, falling back to first choice: %s", node, ) + if return_choice: + return node, choice return node # if we got any timings at all, pick the best of those choice = min(timings, key=timings.__getitem__) node = choice.output_node() log.debug("Autotuning selected choice: %s", node) + if return_choice: + return node, choice return node def make_precompile_fn( @@ -3719,9 +3730,7 @@ def get_choice_info(choice): M, K = input_nodes[-2].get_size()[:2] N = input_nodes[-1].get_size()[-1] - out_dict = { - str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] - } + out_dict = {str((M, K, N)): [get_choice_info(choice) for choice in timings]} append_to_log(mm_filename, out_dict) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py deleted file mode 100644 index db337b9d8a27..000000000000 --- a/torch/_inductor/template_heuristics/cutedsl.py +++ /dev/null @@ -1,141 +0,0 @@ -from dataclasses import dataclass -from enum import auto, Enum -from itertools import product - -import torch._inductor.config as config - - -class TensorMapUpdateMode(Enum): - """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" - - SMEM = auto() - GMEM = auto() - - -@dataclass(frozen=True) -class CuTeGemmConfig: - TILE_M: int = 128 - TILE_N: int = 192 - CLUSTER_M: int = 2 - CLUSTER_N: int = 1 - USE_2_CTA: bool = False - TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM - - -def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - For information regarding valid config sets, see: - https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py - """ - - # Tile_n is always the same regardless of 2cta - tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] - - # Valid clusters - clusters_no_2cta = [ - (1, 1), - (1, 2), - (1, 4), - (1, 8), - (1, 16), - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - clusters_2cta = [ - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - - configs: list[CuTeGemmConfig] = [] - - for use_2cta, cluster_set, tile_m_range in [ - (False, clusters_no_2cta, [64, 128]), - (True, clusters_2cta, [128, 256]), - ]: - for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( - [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], - tile_m_range, - tile_n_vals, - cluster_set, - ): - configs.append( - CuTeGemmConfig( - tile_m, - tile_n, - cluster_m, - cluster_n, - USE_2_CTA=use_2cta, - TENSORMAP_UPDATE_MODE=tensormap_update_mode, - ) - ) - - return configs - - -def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - """ - - config_tuples = [ - (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), - (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), - (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), - (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), - (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), - (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - ] - - return [CuTeGemmConfig(*args) for args in config_tuples] - - -def get_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - - Note: CuTeDSL autotuning is still experimental β€” enabling it may trigger kernel launch failures - or unstable results. By default, autotuning is disabled and we return only - a single baseline config. - """ - if ( - config.cutedsl_enable_autotuning - and config.max_autotune_gemm_search_space == "EXHAUSTIVE" - ): - return get_exhaustive_groupgemm_configs() - elif config.cutedsl_enable_autotuning: - return get_default_groupgemm_configs() - else: - return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 61616d81c287..8cbbf5073d5e 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -1946,6 +1946,29 @@ def _valid(self, kernel_inputs: KernelInputs) -> bool: return False return True + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter out bad configs for specific hardware. + On AMD MI350X (GFX 9.5+), skip configs with BLOCK_K<=64 due to lack of corresponding MFMA instructions. + """ + + def should_skip_mi350x_config(config: BaseConfig) -> bool: + """Skip config if BLOCK_K<=64 on MI350X (GFX 9.5+)""" + try: + return ( + config.block_k <= 64 + and torch.version.hip is not None + and torch.cuda.get_device_capability() >= (9, 5) + ) + except RuntimeError: + # If no HIP GPUs are available, we can't check device capability + # so we don't skip any configs + return False + + filtered_configs = [c for c in configs if not should_skip_mi350x_config(c)] + return super()._filter_configs(filtered_configs) + # Scaled TMA-specific mixin for scaled MM templates with TMA class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 0c9305dc721d..5b394b9ea991 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -165,7 +165,7 @@ def find_coalesced_var( variables[v] = get_hint(v) zero_index = sympy_subs(index, variables) - for v in var_ranges.keys(): + for v in var_ranges: variables[v] = 1 try: new_val = sympy_subs(index, variables) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 13938f6ec1e5..9579dbb3536e 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -549,6 +549,70 @@ def is_pointwise_use( return torch.Tag.pointwise in target.tags or is_pointwise_fn(target) +class LogicalConnective(enum.Enum): + OR = enum.auto() + AND = enum.auto() + + +def has_uses( + target: Node, + use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False, + use_aggregate_type: LogicalConnective = LogicalConnective.OR, +) -> bool: + """ + Given a target, explore the uses of `target` by applying `use_selector_fn` + on them, and then aggregate these booleans with the `use_aggregate_type` + logical connective. + + Uses in view ops will follow the views uses. + """ + + def get_use_aggregate_fn( + use_aggregate_type: LogicalConnective, + ) -> Callable[[Iterator[Any]], bool]: + match use_aggregate_type: + case LogicalConnective.AND: + return all + case LogicalConnective.OR: + return any + case _: + return any + + use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type) + + def has_uses_impl(use: Node) -> bool: + if use.op != "call_function": + return False + if not ( + isinstance(use.target, torch._ops.OpOverload) + or use.target is operator.getitem + ): + return False + + target = cast(torch._ops.OpOverload, use.target) + # Process getitem and view + if target is operator.getitem or is_view(target): + return use_aggregate_fn(has_uses_impl(user) for user in use.users) + + return use_selector_fn(target) + + return use_aggregate_fn(has_uses_impl(user) for user in target.users) + + +def has_uses_tagged_as( + target: Node, + use_tags: Collection[torch.Tag], + use_aggregate_type: LogicalConnective = LogicalConnective.OR, +) -> bool: + """ + Is there a use with given tags? + """ + + return has_uses( + target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type + ) + + def gen_gm_and_inputs( target: Any, args: list[Any], kwargs: dict[str, Any] ) -> tuple[GraphModule, list[torch.Tensor]]: @@ -1911,77 +1975,6 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() -@functools.lru_cache(maxsize=1) -def ensure_cute_available() -> bool: - """Check if CuTeDSL is importable; cache the result for reuse. - - Call ensure_cute_available.cache_clear() after installing CuTeDSL - in the same interpreter to retry the import. - """ - try: - return importlib.util.find_spec("cutlass.cute") is not None - except ImportError: - return False - - -def use_blackwell_cutedsl_grouped_mm( - mat_a: Any, - mat_b: Any, - layout: Layout, - a_is_2d: bool, - b_is_2d: bool, - offs: Optional[Any], - bias: Optional[Any], - scale_result: Optional[Any], -) -> bool: - """ - Returns True if we can use the blackwell kernel for grouped mm. - Required conditions: - 1. CuTeDSL is available - 2. We are on a blackwell arch - 3. The dtype is bf16 - 4. Max autotune or max autotune gemm is enabled - 6. A, B, and the output are 16B aligned - 7. We are not using dynamic shapes - 8. A is 2d - 9. B is 3d - 10. Offsets are provided - 11. Bias and Scale are not provided - """ - if not ensure_cute_available(): - return False - - from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch - - if not is_gpu(layout.device.type) and is_datacenter_blackwell_arch(): - return False - - layout_dtypes = [torch.bfloat16] - if not _use_template_for_gpu(layout, layout_dtypes): - return False - - if not (config.max_autotune or config.max_autotune_gemm): - return False - - # Checks for 16B ptr and stride alignment - if not can_use_tma(mat_a, mat_b, output_layout=layout): - return False - - if any(is_dynamic(x) for x in [mat_a, mat_b]): - return False - - if not a_is_2d or b_is_2d: - return False - - if offs is None: - return False - - if bias is not None or scale_result is not None: - return False - - return True - - def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V @@ -2238,9 +2231,21 @@ def use_cpp_bmm_template( assert isinstance(mat1.layout, Layout) - return ( - use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) - and mat1.layout.is_contiguous() + # In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous. + # But the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. + # So here we specifically check for contiguity within the 2D matrix of each batch. + mat1_size = mat1.layout.size + mat1_stride = mat1.layout.stride + mat1_each_batch_is_contiguous = ( + _use_template_for_cpu(layout) + and mat1.get_dtype() == torch.float32 + and (len(mat1_size) == 3) + and (len(mat1_stride) == 3) + and (mat1_stride[1] == mat1_size[2]) + and (mat1_stride[2] == 1) + ) + return use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) and ( + mat1.layout.is_contiguous() or mat1_each_batch_is_contiguous ) @@ -2808,13 +2813,16 @@ def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool: return type(node) is ir._WaitKernel -def contains_collective(snode: BaseSchedulerNode) -> bool: +def contains_collective( + snode: BaseSchedulerNode, + filter_fn: Optional[Callable[[BaseSchedulerNode], bool]] = None, +) -> bool: from torch._inductor.scheduler import GroupedSchedulerNode if isinstance(snode, GroupedSchedulerNode): return any(contains_collective(x) for x in snode.snodes) - return is_collective(snode.node) + return is_collective(snode.node) and (filter_fn is None or filter_fn(snode)) def contains_wait(snode: BaseSchedulerNode) -> bool: diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 62bd70f65a51..cb3cfd1d6029 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -291,7 +291,7 @@ def parse_return(annotation, error_fn): origin = typing.get_origin(annotation) if origin is not tuple: - if annotation not in SUPPORTED_RETURN_TYPES.keys(): + if annotation not in SUPPORTED_RETURN_TYPES: error_fn( f"Return has unsupported type {annotation}. " f"The valid types are: {SUPPORTED_RETURN_TYPES}." diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index f84b77e630bf..fe0492ff19c1 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1021,6 +1021,10 @@ def meta_linalg_eig(input: Tensor): ) values = input.new_empty(input.shape[:-1], dtype=complex_dtype) vectors = input.new_empty(input.shape, dtype=complex_dtype) + is_cuda = device_hint(input) == "cuda" + vectors.as_strided_( + input.shape, make_contiguous_strides_for(input.shape, row_major=is_cuda) + ) return values, vectors diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index a429d28f30cc..134f7617b758 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -248,7 +248,7 @@ def sctype_from_string(s): """Normalize a string value: a type 'name' or a typecode or a width alias.""" if s in _names: return _names[s] - if s in _name_aliases.keys(): + if s in _name_aliases: return _name_aliases[s] if s in _typecodes: return _typecodes[s] diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index f192a39dd029..e3f383675401 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -49,7 +49,7 @@ class Flags: def __init__(self, flag_to_value: dict): - assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check + assert all(k in FLAGS for k in flag_to_value) # sanity check self._flag_to_value = flag_to_value def __getattr__(self, attr: str): @@ -59,7 +59,7 @@ def __getattr__(self, attr: str): raise AttributeError(f"No flag attribute '{attr}'") def __getitem__(self, key): - if key in SHORTHAND_TO_FLAGS.keys(): + if key in SHORTHAND_TO_FLAGS: key = SHORTHAND_TO_FLAGS[key] if key in FLAGS: try: @@ -76,7 +76,7 @@ def __setattr__(self, attr, value): super().__setattr__(attr, value) def __setitem__(self, key, value): - if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys(): + if key in FLAGS or key in SHORTHAND_TO_FLAGS: raise NotImplementedError("Modifying flags is not implemented") else: raise KeyError(f"No flag key '{key}'") diff --git a/torch/ao/ns/fx/pattern_utils.py b/torch/ao/ns/fx/pattern_utils.py index c4d231e713b2..d10fdd39da90 100644 --- a/torch/ao/ns/fx/pattern_utils.py +++ b/torch/ao/ns/fx/pattern_utils.py @@ -72,7 +72,7 @@ def get_reversed_fusions() -> list[tuple[NSFusionType, int]]: all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) default_base_op_idx = 0 - for quant_pattern in all_quant_patterns.keys(): + for quant_pattern in all_quant_patterns: # TODO: this is a temporary hack to flatten the patterns from quantization so # that it works with the ns matcher function, maybe we should use `_is_match` # in torch.ao.quantization.fx.match_utils to match the patterns diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 14764c77cc60..59f6a46fe135 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -196,7 +196,7 @@ def prepare(self, model, config): # check that whatever was put into local_args agrees with what was obtained # from tensor_fqn - for key in info_from_tensor_fqn.keys(): + for key in info_from_tensor_fqn: if key in local_args: if not ( info_from_tensor_fqn[key] == local_args[key] diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index a78dd307fc6d..e4ff327f285a 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -270,7 +270,7 @@ def converged(curr_modules, prev_modules, threshold=1e-4): summed_norms = torch.tensor(0.0) if None in prev_modules.values(): return False - for name in curr_modules.keys(): + for name in curr_modules: curr_weight = get_module_weight(curr_modules[name]) prev_weight = get_module_weight(prev_modules[name]) diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index ab44cfa09197..25672e7e6ced 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -678,7 +678,7 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf torch.nn.BatchNorm2d: nni.BNReLU2d, torch.nn.BatchNorm3d: nni.BNReLU3d, } - for bn in bn_to_fused_bn.keys(): + for bn in bn_to_fused_bn: fused_bn = bn_to_fused_bn[bn] # bn module + relu module fusion config bn_configs.append( diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index b8809c1c6087..6c8c32b992ed 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -350,7 +350,7 @@ def get_op_node_and_weight_eq_obs( # Find the op node that comes directly after the input equalization observer op_node = None - for user in input_eq_obs_node.users.keys(): + for user in input_eq_obs_node.users: if node_supports_equalization(user, modules): op_node = user break diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 993a6c41f176..0a48bbbaaee9 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -743,7 +743,7 @@ def generate_detector_report( # Populates the string based report with the information from module_dynamic_static_info # Compiles the complete report by appending relevant formatted strings - for module_fqn in module_dynamic_static_info.keys(): + for module_fqn in module_dynamic_static_info: # there is at least 1 module for suggestion modules_added = True module_info = module_dynamic_static_info[module_fqn] diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 08ae102f69f4..06936e5327bc 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -683,7 +683,7 @@ def _maybe_get_observer_for_node( If the node is observed, return the observer instance. Otherwise, return None. """ - for maybe_obs_node in node.users.keys(): + for maybe_obs_node in node.users: if maybe_obs_node.op == "call_module": maybe_obs = modules[str(maybe_obs_node.target)] if _is_activation_post_process(maybe_obs): diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 0c05e6499901..8351dbedd07d 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -950,7 +950,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # we should remove this # removing this means we insert one observer for each use, even if they # have the same dtype, we can have an extra pass that removes the extra observers - for maybe_obs_node in arg.users.keys(): + for maybe_obs_node in arg.users: if maybe_obs_node.op == "call_module": maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] if ( @@ -1440,7 +1440,7 @@ def _maybe_make_input_output_share_observers( setattr(named_modules[parent_name], name, obs_mod_to_use) # set the output observer node to use that module - for output_obs_node in node.users.keys(): + for output_obs_node in node.users: if not _is_activation_post_process_node(output_obs_node, named_modules): raise AssertionError( "output_obs_node must be an activation post process node" diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 74f90505ea2a..951ca66703f4 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -206,7 +206,7 @@ def _check_is_valid_config_dict( `config_dict`: dictionary whose keys we want to check """ - for k in config_dict.keys(): + for k in config_dict: if k not in allowed_keys: raise ValueError( "Expected " @@ -250,7 +250,7 @@ def _compare_prepare_convert_qconfig_mappings( _MODULE_NAME_REGEX_DICT_KEY, ] for i in range(len(prepare_dicts)): - for name in prepare_dicts[i].keys(): + for name in prepare_dicts[i]: if name not in convert_dicts[i]: raise AssertionError( f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare" diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 3e2afaaa1d9f..9f76f2a328df 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -442,7 +442,7 @@ def maybe_get_next_module( target_functional_type: Functional type that we want to check """ - for user in node.users.keys(): + for user in node.users: if ( user.op == "call_module" and target_module_type is not None diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index aab4c435c872..8e768592826e 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -66,7 +66,7 @@ def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: continue if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: return n - for k in n.users.keys(): + for k in n.users: queue.append(k) return None diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 6eac69a96ba4..c15e7878eb2b 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -217,7 +217,7 @@ def _get_edge_or_node_to_group_id( # means the observer of key should be shared with observer with value, by default it will # be shared with itself shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { - k: k for k in edge_or_node_to_qspec.keys() + k: k for k in edge_or_node_to_qspec } for edge_or_node, qspec in edge_or_node_to_qspec.items(): if isinstance(edge_or_node, torch.fx.Node): @@ -282,7 +282,7 @@ def _get_edge_or_node_to_group_id( # now that we get the sharing relations between all edges and nodes, we can assign group ids cur_group_id = 0 edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} - for edge_or_node in shared_with_map.keys(): + for edge_or_node in shared_with_map: root = _find_root_edge_or_node(edge_or_node, shared_with_map) if root not in edge_or_node_to_group_id: edge_or_node_to_group_id[root] = cur_group_id @@ -391,7 +391,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # instead of inserting new observers we will have: # conv1 -> obs1 -> existing_obs -> conv2 # \ -> conv3 - for maybe_obs_node in arg.users.keys(): + for maybe_obs_node in arg.users: if not _is_activation_post_process_node(maybe_obs_node, named_modules): continue maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 10111d4ab8a2..2bfce5d858cc 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -187,7 +187,7 @@ def _get_default_qconfig_mapping_with_default_qconfig( else: qconfig_mapping = get_default_qconfig_mapping(backend) qconfig_mapping.set_global(default_qconfig) - for pattern in qconfig_mapping.object_type_qconfigs.keys(): + for pattern in qconfig_mapping.object_type_qconfigs: if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: qconfig_mapping.set_object_type(pattern, default_qconfig) return qconfig_mapping diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 79f8db1a792f..ec4caab1edcd 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -68,7 +68,7 @@ def fuse_conv_bn_jit(model, inplace=False): def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): _check_is_script_module(model) _check_forward_method(model) - if not all(isinstance(x, str) for x in qconfig_dict.keys()): + if not all(isinstance(x, str) for x in qconfig_dict): raise ValueError("qconfig_dict should only contain names(str) as keys.") scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) model = fuse_conv_bn_jit(model, inplace) @@ -90,7 +90,7 @@ def _prepare_ondevice_jit( quant_type=QuantType.STATIC, ): _check_is_script_module(model) - if not all(isinstance(x, str) for x in qconfig_dict.keys()): + if not all(isinstance(x, str) for x in qconfig_dict): raise ValueError("qconfig_dict should only contain names(str) as keys.") scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) method_graph = model._c._get_method(method_name).graph diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index b10163d4b1e5..816f48fd6267 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -1361,9 +1361,7 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): elif ( node.target is torch.ops.aten.flatten.using_ints and len(node.users) > 0 - and not any( - user.target in quantizable_ops for user in node.users.keys() - ) + and not any(user.target in quantizable_ops for user in node.users) ): # Recipe of flatten: check if any users of flatten node are quantizable ops or not return diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index fa43af270117..9e2a7b5046de 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -52,26 +52,7 @@ "MemRecordsAcc", ] -try: - # Available in Python >= 3.2 - from contextlib import ContextDecorator as _ContextDecorator -except ImportError: - import functools - - class _ContextDecorator: # type: ignore[no-redef] - def __enter__(self): - raise NotImplementedError - - def __exit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError - - def __call__(self, func): - @functools.wraps(func) - def wrapped(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapped +from contextlib import ContextDecorator # global python state - whether profiler is currently enabled @@ -744,8 +725,7 @@ def createFunctionEventForMemoryEvents(evt): return all_function_events -# pyrefly: ignore [invalid-inheritance] -class record_function(_ContextDecorator): +class record_function(ContextDecorator): """Context manager/function decorator that adds a label to a code block/function when running autograd profiler. Label will only appear if CPU activity tracing is enabled. diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index 9f60295655dd..5dd26c088137 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -296,9 +296,9 @@ def _get_record_key(record): f"Expected CPU and CUDA memory allocation handles to match, " f"but got {num_open_handles_cpu} CPU and {num_open_handles_cuda} CUDA" ) - for handle in cpu_memory_allocs.keys(): + for handle in cpu_memory_allocs: cpu_memory_allocs[handle] += record.cpu_memory_usage() - for handle in cuda_memory_allocs.keys(): + for handle in cuda_memory_allocs: cuda_memory_allocs[handle] += record.cuda_memory_usage() if num_open_handles_cpu == 0: # output event as a top-level memory event diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e6..a61aee321fcf 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. + """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines: list[str] = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/compiler/config.py b/torch/compiler/config.py index e7578a57f2c0..e507ddc18052 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -35,6 +35,7 @@ "enable_cpp_symbolic_shape_guards", "wrap_top_frame", "reorderable_logging_functions", + "force_disable_caches", ] diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 42d701298b0d..b3cb07ac1cf9 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -79,6 +79,12 @@ Tensor toNonOptPrimal(const std::optional& t) { return Tensor(); } +void update_wrapped_number(Tensor& input, Tensor& output) { + if (input.unsafeGetTensorImpl()->is_wrapped_number()) { + output.unsafeGetTensorImpl()->set_wrapped_number(true); + } +} + void copy_range(variable_list& out, IndexRange range, const Tensor& t) { TORCH_CHECK(range.second <= out.size()); TORCH_CHECK( diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 4dc0425d426e..ee0f919c4401 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -43,6 +43,7 @@ inline std::optional wrap_opt_if(const Tensor& t, const bool cond) { TORCH_API Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction); TORCH_API bool any_variable_defined(const variable_list& variables); +TORCH_API void update_wrapped_number(Tensor& input, Tensor& output); TORCH_API void copy_range( variable_list& out, IndexRange range, diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 946a8d5f1d36..837ba93d1cc2 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -51,14 +51,101 @@ using namespace at; using namespace torch; using namespace torch::autograd; -std::pair parseIValuesToPyArgsKwargs( - const c10::OperatorHandle& op, - const std::vector& arguments) { - TORCH_CHECK( - PyGILState_Check(), - "GIL must be held before you call parseIValuesToPyArgsKwargs"); - const auto& schema = op.schema(); - py::dict kwargs; +namespace { +class OperatorArgsKwargsView { + public: + OperatorArgsKwargsView( + const c10::OperatorHandle& op, + const std::vector& arguments); + using args_iterator = const c10::IValue*; + + args_iterator args_begin() const { + return arguments_.data(); + } + + args_iterator args_end() const { + return arguments_.data() + positional_default_start_; + } + + auto num_positional_args() const { + return positional_default_start_; + } + + auto kwarg_start_index() const { + return first_non_default_kwarg_; + } + + struct kwargs_iterator { + kwargs_iterator() = default; + kwargs_iterator(const OperatorArgsKwargsView* parent, size_t current) + : parent_(parent), current_(current) {} + + kwargs_iterator(const kwargs_iterator&) = default; + kwargs_iterator& operator=(const kwargs_iterator&) = default; + + kwargs_iterator& operator++() { + do { + current_++; + } while (current_ < parent_->arguments_.size() && + parent_->is_default(current_)); + return *this; + } + + kwargs_iterator operator++(int) { + auto copy = *this; + ++(*this); + return copy; + } + + const c10::IValue& operator*() const { + return parent_->arguments_[current_]; + } + + const c10::IValue* operator->() const { + return &operator*(); + } + + int64_t underlying_index() const { + return current_; + } + + bool operator==(const kwargs_iterator& rhs) const { + return parent_ == rhs.parent_ && current_ == rhs.current_; + } + + bool operator!=(const kwargs_iterator& rhs) { + return !(*this == rhs); + } + + private: + const OperatorArgsKwargsView* parent_ = nullptr; + size_t current_ = 0; + }; + + kwargs_iterator kwargs_begin() const { + return kwargs_iterator(this, first_non_default_kwarg_); + } + + kwargs_iterator kwargs_end() const { + return kwargs_iterator(this, arguments_.size()); + } + + private: + bool is_default(size_t idx) const { + const auto& arg = op_.schema().arguments()[idx]; + if (!arg.default_value().has_value()) { + return false; + } + const auto& default_ivalue = *arg.default_value(); + const auto& ivalue = arguments_[idx]; + if (default_ivalue != ivalue) { + return false; + } + return true; + } + + const c10::OperatorHandle& op_; + c10::ArrayRef arguments_; // About all the pointers: // // f(int x, int y = 0, *, int z = 0) @@ -66,45 +153,63 @@ std::pair parseIValuesToPyArgsKwargs( // ^- kwarg_only_start // ^- positional_default_start // ^- 0 + int64_t positional_default_start_; + int64_t first_non_default_kwarg_; +}; +OperatorArgsKwargsView::OperatorArgsKwargsView( + const c10::OperatorHandle& op, + const std::vector& arguments) + : op_(op), arguments_(arguments) { // Find the split point between kwarg-only and regular. Since most functions // don't have kwarg-only arguments, it is more efficient to scan from the // right (but ideally, this would just be precomputed in FunctionSchema // itself). (NB: minus one in the loop is because we're testing if the // *next* argument is kwarg-only before we advance the starting index) - int64_t kwarg_only_start = static_cast(arguments.size()); + const int64_t signed_arguments_size = static_cast(arguments.size()); + int64_t kwarg_only_start = signed_arguments_size; for (; kwarg_only_start > 0; kwarg_only_start--) { - const auto& arg = schema.arguments()[kwarg_only_start - 1]; + const auto& arg = op.schema().arguments()[kwarg_only_start - 1]; if (!arg.kwarg_only()) { break; } } // Find the first positional argument that isn't defaulted - auto is_default = [&](size_t idx) -> bool { - const auto& arg = schema.arguments()[idx]; - if (!arg.default_value().has_value()) { - return false; - } - const auto& default_ivalue = *arg.default_value(); - const auto& ivalue = arguments[idx]; - if (default_ivalue != ivalue) { - return false; + positional_default_start_ = kwarg_only_start; + for (; positional_default_start_ > 0; positional_default_start_--) { + if (!is_default(positional_default_start_ - 1)) { + break; } - return true; - }; + } - int64_t positional_default_start = kwarg_only_start; - for (; positional_default_start > 0; positional_default_start--) { - if (!is_default(positional_default_start - 1)) { + // kwargs_iterator will skip default kwargs when incremented, but we + // need to skip any initial run of default kwargs ourselves. + first_non_default_kwarg_ = kwarg_only_start; + for (; first_non_default_kwarg_ < signed_arguments_size; + ++first_non_default_kwarg_) { + if (!is_default(first_non_default_kwarg_)) { break; } } +} +} // namespace - auto args = - py::reinterpret_steal(PyTuple_New(positional_default_start)); +std::pair parseIValuesToPyArgsKwargs( + const c10::OperatorHandle& op, + const std::vector& arguments) { + TORCH_CHECK( + PyGILState_Check(), + "GIL must be held before you call parseIValuesToPyArgsKwargs"); + const auto& schema = op.schema(); + py::dict kwargs; - auto schemaAwareToPyObject = [&](size_t idx) -> py::object { + OperatorArgsKwargsView args_kwargs(op, arguments); + auto args = py::reinterpret_steal( + PyTuple_New(args_kwargs.num_positional_args())); + + auto schemaAwareToPyObject = + [&schema](size_t idx, const c10::IValue& argument) -> py::object { const auto& arg = schema.arguments()[idx]; auto match = [&](c10::TypeKind kind) { const auto& t = arg.real_type(); @@ -116,38 +221,42 @@ std::pair parseIValuesToPyArgsKwargs( } return false; }; - if (arguments[idx].isNone()) { + if (argument.isNone()) { return py::none(); } else if (match(c10::ScalarTypeType::Kind)) { - auto* obj = - getTHPDtype(static_cast(arguments[idx].toInt())); + auto* obj = getTHPDtype(static_cast(argument.toInt())); return py::reinterpret_borrow( reinterpret_cast(obj)); } else if (match(c10::LayoutType::Kind)) { - auto* obj = - getTHPLayout(static_cast(arguments[idx].toInt())); + auto* obj = getTHPLayout(static_cast(argument.toInt())); return py::reinterpret_borrow( reinterpret_cast(obj)); } else if (match(c10::MemoryFormatType::Kind)) { - return py::cast(static_cast(arguments[idx].toInt())); + return py::cast(static_cast(argument.toInt())); } else { - return torch::jit::toPyObject(arguments[idx]); + return torch::jit::toPyObject(argument); } }; // Populate positional arguments - for (const auto idx : c10::irange(positional_default_start)) { + size_t idx = 0; + for (auto argument_it = args_kwargs.args_begin(); + argument_it != args_kwargs.args_end(); + ++argument_it) { PyTuple_SET_ITEM( - args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr()); + args.ptr(), + idx, + schemaAwareToPyObject(idx, *argument_it).release().ptr()); + idx++; } // Populate keyword arguments - for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) { - // But don't populate default keyword arguments - if (is_default(idx)) - continue; - const auto& arg = schema.arguments()[idx]; - kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx); + for (auto argument_it = args_kwargs.kwargs_begin(); + argument_it != args_kwargs.kwargs_end(); + ++argument_it) { + const auto& arg = schema.arguments()[argument_it.underlying_index()]; + kwargs[py::cast(arg.name())] = + schemaAwareToPyObject(argument_it.underlying_index(), *argument_it); } return std::make_pair(std::move(args), std::move(kwargs)); } diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 6ffa1529a4de..72e35e3fc9dd 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder { return false; } + virtual bool supportsShrinking() const { + return false; + } + + // Shrink the backend by excluding specified ranks. Backends that support + // communicator shrinking should override this and return a new backend + // instance representing the shrunken group. Backends may use opts_override + // to supply backend-specific options for the new group. + virtual c10::intrusive_ptr shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/ = 0, + const c10::intrusive_ptr& /*opts_override*/ = nullptr) { + TORCH_CHECK( + false, + c10::str("Backend ", getBackendName(), " does not support shrink")); + } + virtual void setTimeout(std::chrono::milliseconds timeout) { TORCH_CHECK( false, diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index 23b8893c54f2..bdb4ad045ff2 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -108,12 +108,14 @@ struct FlightRecorder { capture_cpp_stack_ = getCvarBool( {"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false); enabled_ = max_entries_ > 0; + reset_epoch_start_idx_[0] = 0; } struct Entry { size_t id_; // incremented id in the trace buffer // used to figure out where in the circular entries // buffer this entry will be located to // update state information + size_t reset_epoch_; // epoch when this entry was created size_t pg_id_; std::tuple pg_name_; // @@ -183,11 +185,34 @@ struct FlightRecorder { size_t max_entries_ = 0; size_t next_ = 0; size_t id_ = 0; + size_t reset_epoch_ = 0; + std::unordered_map + reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts std::map> all_pg_status_; std::map, std::vector> pg_name_to_ranks_; std::string comm_lib_version_; + struct TraceIdentifier { + std::optional id; + std::optional reset_epoch; + }; + + TraceIdentifier recordWithResetEnabled( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P); + std::optional record( size_t pg_id, const std::tuple& pg_name, @@ -213,8 +238,16 @@ struct FlightRecorder { std::vector dump_entries(); - // Returns the entry with the given id, if it exists. Otherwise, returns - // std::nullopt. + // Returns the index in entries_ for the given id and reset_epoch. + // Caller must hold mutex_lock before calling this method. + size_t getIdxFromId(size_t id, size_t reset_epoch) const; + + // Returns the entry with the given id and reset_epoch, if it exists. + // Otherwise, returns std::nullopt. + TORCH_API std::optional getEntry( + std::optional id, + std::optional reset_epoch); + TORCH_API std::optional getEntry(std::optional id); /* @@ -227,6 +260,11 @@ struct FlightRecorder { never hang. (timing must also be enabled for compute_duration - see TORCH_NCCL_ENABLE_TIMING). */ + TORCH_API void retire_id( + std::optional id, + std::optional reset_epoch, + bool compute_duration = true); + TORCH_API void retire_id( std::optional id, bool compute_duration = true); diff --git a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp index 8813c9515846..88205c171941 100644 --- a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp @@ -53,8 +53,41 @@ std::optional FlightRecorder::record( std::chrono::milliseconds timeout_ms, std::shared_ptr pg_status, bool isP2P) { + auto result = recordWithResetEnabled( + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + inputs, + outputs, + start, + end, + timeout_ms, + std::move(pg_status), + isP2P); + return result.id; +} + +template +typename FlightRecorder::TraceIdentifier FlightRecorder:: + recordWithResetEnabled( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P) { if (!enabled_) { - return std::nullopt; + return TraceIdentifier{std::nullopt, std::nullopt}; } if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { // Current pg_status is not in FR. @@ -64,8 +97,13 @@ std::optional FlightRecorder::record( torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); std::lock_guard guard(mutex_); + TORCH_CHECK( + reset_epoch_start_idx_.find(reset_epoch_) != + reset_epoch_start_idx_.end()); + auto te = Entry{ id_, + reset_epoch_, pg_id, pg_name, collective_seq_id, @@ -104,15 +142,20 @@ std::optional FlightRecorder::record( te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); } + const auto next = next_++; + if (entries_.size() < max_entries_) { entries_.emplace_back(std::move(te)); } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } + entries_[next] = std::move(te); } - return id_++; + + if (next_ == max_entries_) { + next_ = 0; + } + + const auto id = id_++; + return TraceIdentifier{id, reset_epoch_}; } template @@ -163,15 +206,20 @@ std::vector::Entry> FlightRecorder< std::vector result; { std::lock_guard guard(mutex_); - result.reserve(entries_.size()); - result.insert( - result.end(), + // Filter entries during insertion - only keep entries from current epoch + auto filter = [this](const Entry& e) { + return e.reset_epoch_ == reset_epoch_; + }; + std::copy_if( entries_.begin() + static_cast(next_), - entries_.end()); - result.insert( - result.end(), + entries_.end(), + std::back_inserter(result), + filter); + std::copy_if( entries_.begin(), - entries_.begin() + static_cast(next_)); + entries_.begin() + static_cast(next_), + std::back_inserter(result), + filter); } // query any remaining events for (auto& r : result) { @@ -182,28 +230,47 @@ std::vector::Entry> FlightRecorder< } template -// Returns the entry with the given id, if it exists. Otherwise, returns -// std::nullopt. +// Returns the index in entries_ for the given id and reset_epoch. +// Caller must hold mutex_lock before calling this method. +size_t FlightRecorder::getIdxFromId(size_t id, size_t reset_epoch) + const { + // Look up the starting idx for the given reset epoch + auto it = reset_epoch_start_idx_.find(reset_epoch); + TORCH_CHECK(it != reset_epoch_start_idx_.end()); + // Calculate idx based on where the epoch started + return (it->second + id) % max_entries_; +} + +template +// Returns the entry with the given id and reset_epoch, if it exists. Otherwise, +// returns std::nullopt. std::optional::Entry> FlightRecorder< - EventType>::getEntry(std::optional id) { - if (!enabled_ || !id) { + EventType>:: + getEntry(std::optional id, std::optional reset_epoch) { + if (!enabled_ || !id || !reset_epoch) { return std::nullopt; } std::unique_lock guard(mutex_); - Entry entry = entries_.at(*id % max_entries_); - if (entry.id_ == *id) { + Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch)); + if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) { return entry; - } else { - return std::nullopt; } + return std::nullopt; +} + +template +std::optional::Entry> FlightRecorder< + EventType>::getEntry(std::optional id) { + return getEntry(id, 0); } template void FlightRecorder::retire_id( std::optional id, + std::optional reset_epoch, bool compute_duration) { - if (!enabled_ || !id) { + if (!enabled_ || !id || !reset_epoch) { return; } @@ -214,8 +281,8 @@ void FlightRecorder::retire_id( std::unique_lock guard(mutex_); - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { + Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch)); + if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) { update_state(*entry); if (compute_duration) { @@ -237,8 +304,8 @@ void FlightRecorder::retire_id( guard.lock(); // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { + entry = &entries_.at(getIdxFromId(*id, *reset_epoch)); + if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) { LOG(INFO) << "retire_id abandoned for id " << *id << ", event was overwritten while waiting to compute duration."; return; @@ -249,12 +316,23 @@ void FlightRecorder::retire_id( } } +template +void FlightRecorder::retire_id( + std::optional id, + bool compute_duration) { + retire_id(id, 0, compute_duration); +} + template void FlightRecorder::reset_all() { std::lock_guard guard(mutex_); - next_ = 0; - id_ = 0; - entries_.clear(); + if (!entries_.empty()) { + // Soft delete: increment epoch to mark all existing entries as old + // Store where the new epoch starts in the circular buffer + reset_epoch_++; + reset_epoch_start_idx_[reset_epoch_] = next_; + id_ = 0; + } } template diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 8074cc98a04f..a41f654b9ae2 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -259,6 +259,65 @@ std::shared_ptr NCCLComm::split( } #endif +#ifdef NCCL_HAS_COMM_SHRINK +std::shared_ptr NCCLComm::shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags) { + // Preconditions are validated in ProcessGroupNCCL::shrink + + LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr() + << " excluding " << ranks_to_exclude.size() << " ranks"; + + at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_); + auto comm = std::make_shared(); + + // This call will block until the source communicator is initialized + auto sourceComm = source->getNcclComm(); + + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommShrink( + sourceComm, + ranks_to_exclude.data(), + ranks_to_exclude.size(), + reinterpret_cast(&(comm->ncclComm_)), + config, + shrinkFlags), + source->getNcclCommFailureReason()); + + // Wait for the child communicator to be ready + source->waitReady(true); + comm->initialized_ = true; + + // NCCL automatically assigns rank during shrink - query it efficiently + int assigned_rank; + try { + C10D_NCCL_CHECK( + ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt); + comm->rank_ = assigned_rank; + } catch (const std::exception& e) { + // Fallback: if ncclCommUserRank fails, we can't determine the rank + LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what(); + throw; + } + + // Child comm should be on the same device as parent comm + comm->deviceIndex_ = source->deviceIndex_; + if (config != nullptr) { + comm->nonBlocking_ = config->blocking == 0; + } else { + // Inherit parent behavior if no config provided + comm->nonBlocking_ = source->nonBlocking_; + } + + LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm " + << comm->repr() << " with NCCL-assigned rank " << assigned_rank; + + return comm; +} +#endif + void NCCLComm::finalize() { LockType lock(mutex_); if (aborted_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index fdd50f69ef3d..142633b82374 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -90,6 +90,10 @@ static_assert( #define NCCL_HAS_NVLS_CTAS #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_COMM_SHRINK +#endif + // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -294,6 +298,14 @@ class NCCLComm { ncclConfig_t& config); #endif // NCCL_HAS_COMM_SPLIT +#ifdef NCCL_HAS_COMM_SHRINK + static std::shared_ptr shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags = 0); +#endif // NCCL_HAS_COMM_SHRINK + #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index a9612ce75973..c1d28b2787cd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) { // TODO: We need to have numel of tensors for gloo as well. pgStatus_->lastCompletedNumelIn = 0; pgStatus_->lastCompletedNumelOut = 0; - FlightRecorder::get()->retire_id(work->trace_id_, false); + FlightRecorder::get()->retire_id( + work->trace_id_, work->trace_reset_epoch_, false); lock.lock(); workInProgress_[workerIndex].reset(); } @@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { pgStatus_->lastEnqueuedNumelOut = 0; // using c10d::FlightRecorder; // TODO: We need to have a way to use c10::Event inside gloo as well. - work->trace_id_ = FlightRecorder::get()->record( + auto traceId = FlightRecorder::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), collectiveCounter_, @@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { work->getTimeout(), pgStatus_, false); + work->trace_id_ = traceId.id; + work->trace_reset_epoch_ = traceId.reset_epoch; workQueue_.push_back(std::move(work)); lock.unlock(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index b2cc6993528b..1a0b7c41b385 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend { // unique id used to tell the trace buffer that this // work has completed std::optional trace_id_; + std::optional trace_reset_epoch_; std::shared_ptr context_; const std::chrono::milliseconds timeout_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index fd7f0b424651..29ccc115cc94 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp( } // Get a key string from device -inline std::string getKeyFromDevice(at::Device& device) { +inline std::string getKeyFromDevice(const at::Device& device) { return std::to_string(device.index()); } @@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), + trace_reset_epoch_(w.trace_reset_epoch_), distDebugLevel_(w.distDebugLevel_) { exception_ = w.exception_; } @@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout( // Print the traceback of the collective at call time std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const { // First step we get the corresponding record entry from FR, based on work's - // trace_id_ + // trace_id_ and trace_reset_epoch_ std::optional entry = - FlightRecorderCUDA::get()->getEntry(trace_id_); + FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_); if (entry.has_value()) { auto entryVal = entry.value(); // Get stack trace from FR entry, in string format @@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_; pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_; - FlightRecorderCUDA::get()->retire_id(work.trace_id_, true); + FlightRecorderCUDA::get()->retire_id( + work.trace_id_, work.trace_reset_epoch_, true); if (pg_->onCompletionHook_) { // Move Work object to completedWorkList_ to be consumed by the hook // thread @@ -3360,7 +3362,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( // these objects to the Work because it has implications for keeping those // tensors alive longer and adds overhead when copying Work objects // between threads - r->trace_id_ = FlightRecorderCUDA::get()->record( + auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -3374,6 +3376,8 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( options_->timeout, pgStatus_, isP2P); + r->trace_id_ = traceId.id; + r->trace_reset_epoch_ = traceId.reset_epoch; } return r; } @@ -3593,6 +3597,7 @@ float ProcessGroupNCCL::endTimeEstimate() { #ifdef NCCL_SIM_INFO_INITIALIZER ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER; C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt); + --ncclActiveGroupCounter_; return simInfo.estimatedTime; #else TORCH_CHECK( @@ -3676,7 +3681,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // later in endCoalescing we record a 'coalesced' Work which has // timing/state updates via watchdog thread, but lacks op metadata such as // input/output sizes and profilingTitle per-op in the group. - FlightRecorderCUDA::get()->record( + FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -4168,7 +4173,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // TODO(whc) because we don't pass output {tensor} to initWork, we tell // initWork to not record, and then we manually call record passing all the // information it wants. - work->trace_id_ = FlightRecorderCUDA::get()->record( + auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -4182,6 +4187,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( options_->timeout, pgStatus_, /*isP2P=*/true); + work->trace_id_ = traceId.id; + work->trace_reset_epoch_ = traceId.reset_epoch; } // Only check for NaN for send ops, for recv ops `tensor` can be a random @@ -5842,6 +5849,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor( return tensor; } +#ifdef NCCL_HAS_COMM_SHRINK +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& ranks_to_exclude, + int shrink_flags, + const c10::intrusive_ptr& opts_override) { + // Runtime version check with better error message + auto runtime_version = torch::cuda::nccl::version(); + TORCH_CHECK( + runtime_version >= NCCL_VERSION(2, 27, 0), + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. " + "Found version: ", + runtime_version); + + // Early validation with detailed error messages + TORCH_CHECK_VALUE( + !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty"); + TORCH_CHECK_VALUE( + static_cast(ranks_to_exclude.size()) < size_, + "Cannot exclude all ranks (", + ranks_to_exclude.size(), + " >= ", + size_, + ")"); + + // Validate ranks and convert to int efficiently + std::vector int_ranks_to_exclude; + int_ranks_to_exclude.reserve(ranks_to_exclude.size()); + for (int64_t rank : ranks_to_exclude) { + TORCH_CHECK_VALUE( + rank >= 0 && rank < size_, + "Invalid rank ", + rank, + " for group size ", + size_); + int_ranks_to_exclude.push_back(static_cast(rank)); + } + + // Get primary communicator with better error context + auto primary_device_index = guessDeviceId(); + auto primary_device = at::Device(at::kCUDA, primary_device_index); + const auto primary_key = getKeyFromDevice(primary_device); + + std::shared_ptr primary_comm = getNCCLComm(primary_key); + TORCH_CHECK( + primary_comm, + "Primary NCCL communicator for device ", + primary_device, + " (key: ", + primary_key, + ") is not initialized"); + + // Cache device index before shrink operation + at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex(); + + ncclConfig_t* config = nullptr; + // Default to inheriting from parent options + bool high_priority_stream = options_->is_high_priority_stream; + if (opts_override) { + auto nccl_opts = + c10::static_intrusive_pointer_cast( + opts_override); + config = &nccl_opts->config; + // If user provided override options, honor is_high_priority_stream as well + high_priority_stream = nccl_opts->is_high_priority_stream; + } + + std::shared_ptr shrunk_comm = NCCLComm::shrink( + primary_comm.get(), + int_ranks_to_exclude, + (config != nullptr ? config : &options_->config), + shrink_flags); + + // Calculate new size and get NCCL-assigned rank + int new_size = size_ - static_cast(ranks_to_exclude.size()); + int new_rank = shrunk_comm->rank_; + + // Create new ProcessGroupNCCL with optimized options cloning + auto new_store = store_->clone(); + auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream); + new_opts->timeout = options_->timeout; + if (config != nullptr) { + new_opts->config = *config; + } else { + new_opts->config = options_->config; + } + + auto new_pg = c10::make_intrusive( + new_store, new_rank, new_size, new_opts); + + // Set up the new process group with optimized device setup + new_pg->initializeDeviceStateForComm( + at::Device(at::kCUDA, parent_device_index), shrunk_comm); + + return c10::static_intrusive_pointer_cast(new_pg); +} + +#else // !NCCL_HAS_COMM_SHRINK +// Backend interface override: raise consistent error when shrink is +// unsupported. +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/, + const c10::intrusive_ptr& /*opts_override*/) { + TORCH_CHECK( + false, + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, " + "but PyTorch was built with an older version or without NCCL shrink support."); +} + +#endif // NCCL_HAS_COMM_SHRINK + +void ProcessGroupNCCL::initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm) { + const auto key = getKeyFromDevice(device); + std::unique_lock lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto stream = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + devNCCLCommMap_[key] = comm; + ncclStreams_.emplace(key, stream); + ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming)); + usedDeviceIdxs_.insert(device.index()); + + if (shouldAllCommunicatorsRegisterAllTensors()) { + std::lock_guard map_lock(ncclCommMemPoolMapMutex); + ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{}); + } +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 286eab14d1a8..d8f324dbd8ed 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // unique id used to tell the trace buffer that this // work has completed std::optional trace_id_; + std::optional trace_reset_epoch_; DebugLevel distDebugLevel_; friend class ProcessGroupNCCL; }; @@ -997,6 +998,21 @@ class TORCH_API ProcessGroupNCCL : public Backend { ErrorType getError() override; + bool supportsShrinking() const override { +#ifdef NCCL_HAS_COMM_SHRINK + return true; +#else + return false; +#endif + } + + // Backend-style shrink override that returns a Backend instance. + c10::intrusive_ptr shrink( + const std::vector& ranks_to_exclude, + int shrink_flags = 0, + const c10::intrusive_ptr& opts_override = + nullptr) override; + std::shared_ptr getMemAllocator() override; // Allocate tensor from communication-optimized memory pool @@ -1065,6 +1081,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { int p2pRank = 0, bool isSendRecvSelf = false); + // Initialize device-specific state (comm, stream, event, bookkeeping) for a + // given communicator on this process group instance. + void initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm); + // Wrapper method which can be overridden for tests. virtual std::exception_ptr checkForNCCLErrors( std::shared_ptr& ncclComm); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index a6c6c6f8c474..91bb3469e3e8 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #ifdef USE_C10D_GLOO #include @@ -2734,12 +2735,23 @@ The hook must have the following signature: "supports_time_estimate", &::c10d::Backend::supportsTimeEstimation, "(test whether the backend supports collective time estimation)") + .def_property_readonly( + "supports_shrinking", + &::c10d::Backend::supportsShrinking, + "(test whether the backend supports communicator shrinking)") .def( "set_timeout", &::c10d::Backend::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") + .def( + "shrink", + &::c10d::Backend::shrink, + py::arg("ranks_to_exclude"), + py::arg("shrink_flags") = 0, + py::arg("opts_override") = nullptr, + py::call_guard()) .def( "broadcast", &::c10d::Backend::broadcast, @@ -3876,6 +3888,33 @@ such as `dist.all_reduce(tensor, async_op=True)`. .def("wait", &::c10d::FakeWork::wait, py::arg("timeout") = kNoTimeout) .def("getFuture", &::c10d::FakeWork::getFuture); + auto pythonCallbackWork = + intrusive_ptr_no_gil_destructor_class_<::c10d::PythonCallbackWork>( + module, "PythonCallbackWork", work) + .def(py::init(), py::arg("callback")) + .def( + "wait", + &::c10d::PythonCallbackWork::wait, + py::arg("timeout") = kNoTimeout, + R"( + Waits until the callback completes. Blocking operation. + The callback is invoked with the timeout parameter and should return a boolean. + Throws if the callback completes with an exception. + Returns the boolean value returned by the callback. + )") + .def( + "get_future", + [](::c10d::PythonCallbackWork& work) + -> std::shared_ptr { + return std::make_shared( + work.getFuture()); + }, + R"( + Returns: + A ``torch.futures.Future`` object which is associated with the completion of + the ``PythonCallbackWork``. + )"); + py::class_(module, "DDPLoggingData") .def(py::init<>()) .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map) diff --git a/torch/csrc/distributed/c10d/python_callback_work.cpp b/torch/csrc/distributed/c10d/python_callback_work.cpp new file mode 100644 index 000000000000..47bef1831a48 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_callback_work.cpp @@ -0,0 +1,64 @@ +#include +#include + +namespace c10d { + +PythonCallbackWork::PythonCallbackWork(py::function callback) + : callback_(std::move(callback)) { + // Create a future that will be marked as complete when wait() is called + future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); +} + +// NOLINTNEXTLINE(bugprone-exception-escape) +PythonCallbackWork::~PythonCallbackWork() { + py::gil_scoped_acquire ag; + callback_.dec_ref(); + // Explicitly set callback_ to nullptr to prevent py::object's dtor + // to decref on the PyObject again. + // See Note [Destructing py::object] in python_ivalue.h + callback_.ptr() = nullptr; +} + +bool PythonCallbackWork::wait(std::chrono::milliseconds timeout) { + py::gil_scoped_acquire ag; + + try { + // Call the Python callback with timeout + py::object result = callback_(timeout); + + // Extract the boolean result + bool success = result.cast(); + + // Mark the work as completed if successful + if (success) { + finish(); + // Mark the future as complete with an empty list + if (!future_->completed()) { + future_->markCompleted(c10::IValue(c10::List())); + } + } + + return success; + } catch (py::error_already_set& e) { + // Capture the Python exception and store it + finish(std::current_exception()); + if (!future_->completed()) { + future_->setErrorIfNeeded(std::current_exception()); + } + throw; + } catch (const std::exception& e) { + // Capture any C++ exception and store it + finish(std::current_exception()); + if (!future_->completed()) { + future_->setErrorIfNeeded(std::current_exception()); + } + throw; + } +} + +c10::intrusive_ptr PythonCallbackWork::getFuture() { + return future_; +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/python_callback_work.hpp b/torch/csrc/distributed/c10d/python_callback_work.hpp new file mode 100644 index 000000000000..48966e785ad6 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_callback_work.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace c10d { + +// PythonCallbackWork is a subclass of Work that wraps a Python callback +// function that implements wait(). This allows asynchronous work to +// be integrated with Python code, enabling custom completion logic or +// post-processing in Python. +class PythonCallbackWork : public Work { + public: + explicit PythonCallbackWork(py::function callback); + + ~PythonCallbackWork() override; + + bool wait(std::chrono::milliseconds timeout) override; + + c10::intrusive_ptr getFuture() override; + + private: + py::function callback_; + c10::intrusive_ptr future_; +}; + +} // namespace c10d diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index b9dccb456fd6..8dc316b98e63 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -13,6 +13,11 @@ #define _PyCode_SetExtra PyUnstable_Code_SetExtra #endif +namespace { +// Short-term fix for: https://github.com/pytorch/pytorch/issues/166926 +bool use_lru = true; +} // namespace + Py_ssize_t extra_index = -1; CacheEntry* ExtraState::get_first_entry() { @@ -190,7 +195,9 @@ void lookup( ++index; } if (found) { - extra_state->move_to_front(found); + if (use_lru) { + extra_state->move_to_front(found); + } *maybe_cached_code = found->code.ptr(); *trace_annotation = found->trace_annotation.c_str(); return; @@ -202,8 +209,14 @@ CacheEntry* create_cache_entry( ExtraState* extra_state, PyObject* guarded_code, PyObject* backend) { - extra_state->cache_entry_list.emplace_front(guarded_code, backend); - auto new_iter = extra_state->cache_entry_list.begin(); + std::list::iterator new_iter; + if (use_lru) { + extra_state->cache_entry_list.emplace_front(guarded_code, backend); + new_iter = extra_state->cache_entry_list.begin(); + } else { + extra_state->cache_entry_list.emplace_back(guarded_code, backend); + new_iter = std::prev(extra_state->cache_entry_list.end()); + } new_iter->_owner = extra_state; new_iter->_owner_loc = new_iter; // Set guard_manager references to extra_state and CacheEntry @@ -269,6 +282,14 @@ void _load_precompile_entry( extra->precompile_entries.push_back(std::move(entry)); } +void _set_lru_cache(py::object boolean) { + if (py::cast(boolean)) { + use_lru = true; + } else { + use_lru = false; + } +} + py::list _debug_get_precompile_entries(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 1630ac90b21d..bc62e93bf3f1 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -203,5 +203,6 @@ void _load_precompile_entry( py::object guard_manager, py::object dynamo_code); py::list _debug_get_precompile_entries(const py::handle& code_obj); +void _set_lru_cache(py::object boolean); #endif diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index f1590e19d49c..790ff9acff3a 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -254,6 +254,7 @@ void initDynamoBindings(PyObject* torch) { m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries); + m.def("_set_lru_cache", &_set_lru_cache); py::bind_vector>(m, "VectorUInt8"); init_THPCaches(); if (THP_PyOpcode_Caches != nullptr) { diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index d60a6a099008..9f7c2756d0d7 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -587,7 +587,9 @@ py::object toPyObject(IValue ivalue) { } else if (ivalue.isTensor()) { auto tensor = std::move(ivalue).toTensor(); if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { - TORCH_INTERNAL_ASSERT(tensor.device().is_cpu()); + TORCH_INTERNAL_ASSERT( + tensor.device().is_cpu() || + (tensor._is_zerotensor() && tensor.dim() == 0)); auto py_tensor = py::cast(tensor); if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) { return py_tensor.attr("_wrapped_number"); @@ -595,17 +597,27 @@ py::object toPyObject(IValue ivalue) { auto scalar_type = tensor.scalar_type(); switch (scalar_type) { case at::ScalarType::Bool: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(false) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::Long: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(int64_t(0)) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::UInt64: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(uint64_t(0)) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::Double: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(0.0) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::ComplexDouble: // TODO: https://github.com/pytorch/pytorch/issues/77134 - return py::cast(static_cast>( - *tensor.const_data_ptr>())); + return (tensor._is_zerotensor()) + ? py::cast(std::complex(0.0, 0.0)) + : py::cast(static_cast>( + *tensor.const_data_ptr>())); default: TORCH_CHECK( false, diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 5c2959e69ae0..d5fbba9fbbfd 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -5,13 +5,13 @@ #include #include #include -#include #include #include #include #include #include +#include HIDDEN_NAMESPACE_BEGIN(torch, stable) @@ -68,7 +68,7 @@ inline torch::stable::Tensor narrow( // only dtype information. inline torch::stable::Tensor new_empty( const torch::stable::Tensor& self, - std::vector size, + torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt) { int32_t device_type; TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); @@ -107,7 +107,7 @@ inline torch::stable::Tensor new_empty( // only dtype information. inline torch::stable::Tensor new_zeros( const torch::stable::Tensor& self, - std::vector size, + torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt) { int32_t device_type; TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); @@ -144,12 +144,10 @@ inline torch::stable::Tensor new_zeros( // We expect this to be the stable version of the pad.default op. // pad.default takes in a SymInt[] as the pad argument however pad is typed as -// use std::vector because -// (1) IntArrayRef is not yet header-only -// (2) SymInt is not yet header-only +// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only. inline torch::stable::Tensor pad( const torch::stable::Tensor& self, - std::vector pad, + torch::headeronly::IntHeaderOnlyArrayRef pad, const std::string& mode = "constant", double value = 0.0) { AtenTensorHandle ret0 = nullptr; @@ -181,11 +179,10 @@ inline torch::stable::Tensor amax( // This function is an overload to compute the maximum value along each slice of // `self` reducing over all the dimensions in the vector `dims`. The // amax.default op takes in a SymInt[] as the dims argument, however dims is -// typed as use std::vector here because (1) IntArrayRef is not yet -// header-only (2) SymInt is not yet header-only +// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only inline torch::stable::Tensor amax( const torch::stable::Tensor& self, - std::vector dims, + torch::headeronly::IntHeaderOnlyArrayRef dims, bool keepdim = false) { AtenTensorHandle ret = nullptr; TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax( diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index f35ed50d99be..8004e91b77f8 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -31,10 +31,8 @@ template struct FromImpl { static StableIValue call( T val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); @@ -75,10 +73,8 @@ template <> struct FromImpl { static StableIValue call( ScalarType val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { switch (val) { case ScalarType::Byte: return from(aoti_torch_dtype_uint8()); @@ -133,10 +129,8 @@ template <> struct FromImpl { static StableIValue call( std::nullopt_t val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { return from(nullptr); } }; @@ -190,10 +184,8 @@ template <> struct FromImpl { static StableIValue call( const torch::stable::Tensor& val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { AtenTensorHandle new_ath; TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath)); return from(new_ath); @@ -209,10 +201,8 @@ template struct ToImpl { static T call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { static_assert(std::is_trivially_copyable_v); // T may not have a default constructor. (For example, it might be // c10::Device.) However, std::memcpy implicitly creates a T at the @@ -249,10 +239,8 @@ template <> struct ToImpl { static ScalarType call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { int32_t shim_scalartype = to(val); if (shim_scalartype == aoti_torch_dtype_uint8()) { return ScalarType::Byte; @@ -309,10 +297,8 @@ template <> struct ToImpl { static std::nullopt_t call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { // val should be equivalent to from(nullptr) return std::nullopt; } @@ -350,10 +336,8 @@ template <> struct ToImpl { static torch::stable::Tensor call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { return torch::stable::Tensor(to(val)); } }; diff --git a/torch/csrc/stable/tensor_struct.h b/torch/csrc/stable/tensor_struct.h index 88cc167e5977..0d44ffd07517 100644 --- a/torch/csrc/stable/tensor_struct.h +++ b/torch/csrc/stable/tensor_struct.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ HIDDEN_NAMESPACE_BEGIN(torch, stable) using accelerator::DeviceIndex; +using torch::headeronly::IntHeaderOnlyArrayRef; using torch::headeronly::ScalarType; // The torch::stable::Tensor class is a highlevel C++ wrapper around @@ -93,6 +95,32 @@ class Tensor { return numel; } + // note: this API is, for all intents and purposes, the same as the one in + // TensorBase.h: it returns a borrowed reference of the dimension sizes of + // a Tensor. + // + // The only difference is that it returns a header-only IntHeaderOnlyArrayRef, + // which has slightly less functionality than a regular IntArrayRef. See + // [HeaderOnlyArrayRef vs ArrayRef note] for more details. + IntHeaderOnlyArrayRef sizes() const { + int64_t* sizes; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes)); + return IntHeaderOnlyArrayRef(sizes, dim()); + } + + // note: this API is, for all intents and purposes, the same as the one in + // TensorBase.h: it returns a borrowed reference of the strides of a + // Tensor. + // + // The only difference is that it returns a header-only IntHeaderOnlyArrayRef, + // which has slightly less functionality than a regular IntArrayRef. See + // [HeaderOnlyArrayRef vs ArrayRef note] for more details. + IntHeaderOnlyArrayRef strides() const { + int64_t* strides; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides)); + return IntHeaderOnlyArrayRef(strides, dim()); + } + // note: this is a subset of the original TensorBase API. It takes no // arguments whereas the original API takes in a kwarg of memory format. // Here, we assume the default contiguous memory format. diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index dff869742df5..23d297b6d95e 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None): "nvidia-ml-py does not seem to be installed or it can't be imported." # pyrefly: ignore [invalid-inheritance] ) from _PYNVML_ERR - # pyrefly: ignore [import-error] + # pyrefly: ignore [import-error,missing-module-attribute] from pynvml import NVMLError_DriverNotLoaded try: diff --git a/torch/cuda/_device_limits.py b/torch/cuda/_device_limits.py index 808d748c8f6e..60aeedc8053a 100644 --- a/torch/cuda/_device_limits.py +++ b/torch/cuda/_device_limits.py @@ -53,7 +53,7 @@ def get_fma_per_cycle_per_sm_cuda_cores(self, data_type: dtype) -> int: else: dict_key = "unknown" - if dict_key not in hardcoded_device_values.keys(): + if dict_key not in hardcoded_device_values: raise RuntimeError( f"No data for sm_{self.compute_capability} and {data_type}." ) @@ -96,7 +96,7 @@ def get_fma_per_cycle_per_sm_tensor_cores(self, data_type: dtype) -> int: else: dict_key = "unknown" - if dict_key not in hardcoded_device_values.keys(): + if dict_key not in hardcoded_device_values: raise RuntimeError( f"No data for sm_{self.compute_capability} and {data_type}." ) diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 2dfd5f947949..a1decc20cc9a 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -4,12 +4,14 @@ import collections import contextlib import ctypes +import os import pickle +import re import sys import warnings from inspect import signature -from typing import Any, Literal, Optional, TYPE_CHECKING -from typing_extensions import deprecated +from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict +from typing_extensions import deprecated, NotRequired import torch from torch import _C @@ -29,6 +31,60 @@ from torch.types import Device +# Type definitions for memory profiler +class _Frame(TypedDict): + """Frame information from memory profiler snapshots.""" + + filename: str + line: int + name: str + # Fields added by FX augmentation (optional) + fx_node_op: NotRequired[str] + fx_node_name: NotRequired[str] + fx_node_target: NotRequired[str] + fx_original_trace: NotRequired[str] + + +class _Block(TypedDict): + """Memory block information.""" + + size: int + requested_size: int + address: int + state: str + frames: list[_Frame] + + +class _Segment(TypedDict): + """Memory segment information.""" + + address: int + total_size: int + stream: int + segment_type: str + allocated_size: int + active_size: int + blocks: list[_Block] + + +class _TraceEntry(TypedDict): + """Memory trace entry information.""" + + action: str + addr: NotRequired[int] + frames: list[_Frame] + size: int + stream: int + device_free: NotRequired[int] + + +class _Snapshot(TypedDict): + """Memory snapshot structure.""" + + segments: list[_Segment] + device_traces: NotRequired[list[list[_TraceEntry]]] + + __all__ = [ "caching_allocator_alloc", "caching_allocator_delete", @@ -772,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str: import pynvml # type: ignore[import] except ModuleNotFoundError: return "pynvml module not found, please install nvidia-ml-py" - # pyrefly: ignore [import-error] + # pyrefly: ignore [import-error,missing-module-attribute] from pynvml import NVMLError_DriverNotLoaded try: @@ -964,7 +1020,120 @@ def _record_memory_history_impl( _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined] -def _snapshot(device: "Device" = None): +def _augment_frames(frames: list[_Frame]) -> int: + """ + Augment a list of frames with FX debug information. + + Args: + frames: List of frame dictionaries to augment + + Returns: + The count of frames that were augmented. + """ + from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX + + # Regex pattern to match FX generated files + _FX_GENERATED_PATTERN = re.compile( + rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$" + ) + + count = 0 + if not frames: + return count + + for frame in frames: + if "filename" in frame and "line" in frame: + filename = frame["filename"] + lineno = frame["line"] + + # Check if this looks like an FX generated file + if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)): + continue + + # Look up metadata from the global registry + from torch.fx.traceback import _FX_METADATA_REGISTRY + + metadata = _FX_METADATA_REGISTRY.get(filename) + if metadata is None: + continue + + lineno_map = metadata.get("lineno_map", {}) + node_metadata = metadata.get("node_metadata", {}) + prologue_start = metadata.get("prologue_start", 0) + + # Get the node index for this line + node_idx = lineno_map.get(lineno - prologue_start) + + if node_idx is not None and node_idx in node_metadata: + node_info = node_metadata[node_idx] + original_trace = node_info.get("stack_trace") + node_op = node_info.get("op") + node_name = node_info.get("name") + node_target = node_info.get("target") + + # Always add node metadata + frame["fx_node_op"] = node_op + frame["fx_node_name"] = node_name + frame["fx_node_target"] = str(node_target) + + # Add original trace if available + if original_trace: + frame["fx_original_trace"] = original_trace + + count += 1 + + return count + + +def _augment_memory_snapshot_stack_traces( + snapshot: str | _Snapshot, +) -> _Snapshot: + """ + Augment a memory snapshot with original source stack traces from FX metadata. + + IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY) + that is populated during graph module compilation. It must be called in the same + Python process where the FX graphs were compiled. It cannot be used to augment + snapshots loaded from disk in a different process. + + Args: + snapshot: Either a memory snapshot dict or path to a snapshot pickle file + + Returns: + The augmented snapshot dictionary with fx_node_op, fx_node_name, + fx_original_trace, and fx_node_info fields added to frames + """ + + snapshot_dict: _Snapshot + if isinstance(snapshot, str): + # Load the memory snapshot + with open(snapshot, "rb") as f: + snapshot_dict = cast(_Snapshot, pickle.load(f)) + else: + snapshot_dict = snapshot + + # Process stack traces in the snapshot + augmented_count = 0 + + # Process blocks in segments (for regular allocations) + if "segments" in snapshot_dict: + for segment in snapshot_dict["segments"]: + if "blocks" in segment: + for block in segment["blocks"]: + if "frames" in block: + augmented_count += _augment_frames(block["frames"]) + + # Process device traces (for memory history) + if "device_traces" in snapshot_dict: + for trace_list in snapshot_dict["device_traces"]: + for trace_entry in trace_list: + if isinstance(trace_entry, dict) and "frames" in trace_entry: + augmented_count += _augment_frames(trace_entry["frames"]) + + return snapshot_dict + + +def _snapshot(device: "Device" = None, augment_with_fx_traces=False): """Save a snapshot of CUDA memory state at the time it was called. The state is represented as a dictionary with the following structure. @@ -1012,6 +1181,11 @@ class Frame(TypedDict): filename: str line: int name: str + # Optional FX debug fields (present when augment_with_fx_traces=True + # and the frame corresponds to FX-generated code) + fx_node_op: str # FX node operation type (e.g., 'call_function', 'output') + fx_node_name: str # FX node name (e.g., 'linear', 'relu_1') + fx_original_trace: str # Original model source code stack trace class TraceEntry(TypedDict): @@ -1041,13 +1215,23 @@ class TraceEntry(TypedDict): device_free: int # only present for OOM, the amount of # memory cuda still reports to be free + Args: + device: Device to capture snapshot for. If None, captures for current device. + augment_with_fx_traces: If True, augment stack trace frames with FX debug information + that maps generated FX code back to original model source code. + This adds fx_node_op, fx_node_name, fx_original_trace, and + fx_node_info fields to Frame objects. Default: False. + Returns: The Snapshot dictionary object """ - return _C._cuda_memorySnapshot(None) + s = _C._cuda_memorySnapshot(None) + if augment_with_fx_traces: + s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type] + return s -def _dump_snapshot(filename="dump_snapshot.pickle"): +def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False): """ Save a pickled version of the `torch.memory._snapshot()` dictionary to a file. @@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"): Args: filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle". + augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information + before dumping. This maps generated FX code stack traces + back to original model source code. Defaults to False. + verbose (bool, optional): If True and augment_with_fx_traces is True, print verbose debug output + during augmentation. Defaults to False. """ - s = _snapshot() + s = _snapshot(augment_with_fx_traces=augment_with_fx_traces) + with open(filename, "wb") as f: pickle.dump(s, f) diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index ea9707b2e1e8..c186694df94e 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -64,6 +64,7 @@ np = None # type: ignore[assignment] import torch +import torch.distributed as dist from torch import Size, SymBool, SymInt, Tensor from torch._C import DispatchKey, DispatchKeySet, ScriptObject from torch._export.wrappers import mark_subclass_constructor_exportable_experimental @@ -921,6 +922,22 @@ def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor: # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor({r: cb(r) for r in self.ranks}) + def tensor_map( + self, tensor: LocalTensor, cb: Callable[[int, Tensor], Tensor | None] + ) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. + """ + + with self.disable(): + results = {} + for r in self.ranks: + if r in tensor._local_tensors: + m = cb(r, tensor._local_tensors[r]) + if m is not None: + results[r] = m + return LocalTensor(results) + def _patch_device_mesh(self) -> None: assert self._old_get_coordinate is None self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] @@ -1049,3 +1066,120 @@ def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: """ lm = local_tensor_mode() return lm.disable() if lm is not None else contextlib.nullcontext() + + +import threading +from queue import Queue + + +_LOCAL_RUNNER_MODE: "LocalRunnerMode | None" = None + + +class LocalRunnerMode: + """ + A class for running multiple SPMD functions concurrently, however at any point + in time only one function can be running. The main use case for the local runner + mode is to enable SPMD functions to be able to use send and recv to communicate + with each other. Without local runner mode send and recv are not supported. + """ + + runner_context = threading.local() + + def __init__( + self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None] + ): + if isinstance(ranks, int): + ranks = frozenset(range(ranks)) + self._ranks = ranks + self._fn = fn + self._run_lock = threading.Lock() + self._run_id = -1 + self._run_cond = threading.Condition(self._run_lock) + + self._recv_objects: dict[int, dict[int, Queue]] = { + dst: {src: Queue() for src in ranks} for dst in ranks + } + self._runners = [ + threading.Thread(target=self._run, args=(i,), name="LocalRunnerMode") + for i in range(concurrency) + ] + + def __enter__(self) -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is None, "LocalRunnerMode is already running" + _LOCAL_RUNNER_MODE = self + + for r in self._runners: + r.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + for r in self._runners: + r.join() + global _LOCAL_RUNNER_MODE + _LOCAL_RUNNER_MODE = None + + def _run(self, id: int) -> None: + LocalRunnerMode.runner_context.id = id + # Only one thread can run at a time, hence must acquire the lock + try: + self._acquire_run_lock() + self._fn(id) + finally: + self._release_run_lock() + + def _acquire_run_lock(self) -> None: + self._run_lock.acquire() + self._run_id = LocalRunnerMode.runner_context.id + + def _release_run_lock(self) -> None: + self._run_id = -1 + self._run_lock.release() + + def _assert_holds_run_lock(self) -> None: + assert self._run_id == LocalRunnerMode.runner_context.id, ( + "Calling thread does not hold the run lock" + ) + + def _get_recv_object(self, src: int, dst: int) -> object | None: + peers = [src] if src != -1 else list(self._ranks) + recv_objects = self._recv_objects[dst] + + for p in peers: + if not recv_objects[p].empty(): + return recv_objects[p].get() + + return None + + def _signal_send(self, src: int, dst: int, obj: object) -> None: + assert obj is not None, "Cannot signal None" + self._assert_holds_run_lock() + # Only a single thread a time executes so it is safe to mutate + # read objects queue (executing thread is already holding the lock) + self._recv_objects[dst][src].put(obj) + # Signal directly condition variable since the calling thread is already + # holding the lock + self._run_cond.notify_all() + + def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: + self._assert_holds_run_lock() + # Wait for the object to be available + while True: + obj = self._get_recv_object(src, dst) + if obj is not None: + post(obj) + # Note that we are not releasing the lock here, since the thread + # will continue to run and therefore must hold the lock + return + self._run_cond.wait() + + @staticmethod + def current() -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" + return _LOCAL_RUNNER_MODE diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 30b99931f251..0b63330dfafc 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -1,13 +1,15 @@ import functools import math import operator -from collections.abc import Sequence +from collections.abc import Callable, Sequence +from datetime import timedelta import torch from torch._C import ScriptObject -from torch._C._distributed_c10d import FakeWork +from torch._C._distributed_c10d import FakeWork, PythonCallbackWork from torch.distributed._mesh_layout import _MeshLayout from torch.distributed.distributed_c10d import ( + _check_op, _get_default_group, _resolve_process_group, ProcessGroup, @@ -765,10 +767,19 @@ def _local_send( # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int dst, int tag) -> __torch__.torch.classes.c10d.Work"; - raise NotImplementedError( - "LocalTensor does not support MPMD operations like send. " - "Use SPMD collective operations instead." - ) + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + src = int(tensor.__src_rank__) + + LocalRunnerMode.current()._signal_send(src, dst, tensor._local_tensors[src]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so def _local_recv_( @@ -779,11 +790,26 @@ def _local_recv_( ) -> ScriptObject: # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int src, int tag) -> __torch__.torch.classes.c10d.Work"; + from . import LocalRunnerMode, LocalTensor - raise NotImplementedError( - "LocalTensor does not support MPMD operations like recv. " - "Use SPMD collective operations instead." - ) + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + dst = int(tensor.__src_rank__) + + def _recv_and_store(timeout: timedelta) -> bool: + def _wait_and_store(obj: object) -> None: + assert isinstance(obj, torch.Tensor), "Expected to receive a Tensor" + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + tensor._local_tensors[dst] = obj + + LocalRunnerMode.current()._wait_recv(src, dst, _wait_and_store) + return True + + work = PythonCallbackWork(_recv_and_store) + work_so = Work.boxed(work) + return work_so def _local_recv_any_source_( @@ -792,7 +818,60 @@ def _local_recv_any_source_( # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int tag) -> __torch__.torch.classes.c10d.Work"; - raise NotImplementedError( - "LocalTensor does not support MPMD operations like recv_any_source. " - "Use SPMD collective operations instead." + return _local_recv_(tensors, process_group_so, -1, tag) + + +def _attach_rank(tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Attaches rank as an attribute to given tensor so that the send or recv implementation + knows which rank initiates the operation (note under local tensor mode ). + """ + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + + tensor.__src_rank__ = rank # type: ignore[attr-defined] + return tensor + + +def local_p2p_op( + dst: torch.SymInt, + tensor: torch.Tensor, + op: Callable[[torch.Tensor, int], Work | None], +) -> Work | None | list[Work | None]: + """ + Runs a point-to-point (P2P) operation for all combinations of source and destination ranks. + """ + _check_op(op) + + from . import LocalIntNode + + assert isinstance(dst.node, LocalIntNode), ( + "Expected 'dst' to be a LocalIntNode where the value is the destination rank and key is the source rank" ) + + w = [] + for s, d in dst.node._local_ints.items(): + tensor = _attach_rank(tensor, s) + w.append(op(tensor, d)) + return w + + +def wait_all(work: Work | None | list[Work | None]) -> None: + """ + Waits for all work objects in the input to complete. + + A single Work object, None, or a list of Work objects (possibly containing None). + If None, does nothing. If a single Work, waits for it to complete. If a list, waits + for each non-None Work in the list to complete. + """ + + if work is None: + return + if isinstance(work, Work): + work = [work] + for w in work: + if w is None: + continue + w.wait() diff --git a/torch/distributed/_serialization.py b/torch/distributed/_serialization.py index c13ba46ba575..8f7043453be7 100644 --- a/torch/distributed/_serialization.py +++ b/torch/distributed/_serialization.py @@ -145,7 +145,7 @@ def _streaming_load( if pickle_module is None: pickle_module = pickle - if "encoding" not in pickle_load_args.keys(): + if "encoding" not in pickle_load_args: pickle_load_args["encoding"] = "utf-8" zip_file = _PseudoZipFile() diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 9db89d038658..9d70ab7c7400 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -257,7 +257,7 @@ def _process_output_file( ) # Process each input safetensors file - for safetensors_file in input_files_data.keys(): + for safetensors_file in input_files_data: file_metadata = input_files_data[safetensors_file].metadata input_metadata_size = input_files_data[safetensors_file].metadata_size diff --git a/torch/distributed/checkpoint/quantized_hf_storage.py b/torch/distributed/checkpoint/quantized_hf_storage.py index 2cb189d515a8..36f4ddf937fe 100644 --- a/torch/distributed/checkpoint/quantized_hf_storage.py +++ b/torch/distributed/checkpoint/quantized_hf_storage.py @@ -82,7 +82,7 @@ def _build_weight_scale_mapping(self, weight_map: dict[str, str]): # Store the complete weight map for file location lookups self._weight_map = weight_map - for tensor_name in weight_map.keys(): + for tensor_name in weight_map: if tensor_name.endswith(".weight_scale_inv"): weight_name = tensor_name.replace(".weight_scale_inv", ".weight") if weight_name in weight_map: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 16d988a79103..54a29c0bb358 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -443,7 +443,7 @@ def _verify_state_dict( f"or load but optim state_dict is empty. {optim_state_dict}" ) - for key in model_state_dict.keys(): + for key in model_state_dict: if _FLAT_PARAM in key: raise RuntimeError( f"{key} contains {_FLAT_PARAM}. This can happen if the model " @@ -1007,7 +1007,14 @@ def _split_optim_state_dict( raise AssertionError(f"Expected list, got {type(params)}") params.append(fqn) if param.requires_grad: - state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + if fqn in cast(DictValueType, optim_state_dict[_STATE]): + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + elif info.strict: + raise RuntimeError( + f"Missing optimizer state for parameter '{fqn}' in checkpoint. " + "The parameter requires gradients but has no saved optimizer state. " + "To load anyway, use StateDictOptions(strict=False)." + ) for loaded_param_group in cast( ListDictValueType, optim_state_dict[_PG] ): diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bc79408a32ff..415cbacc177a 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3,6 +3,7 @@ import collections.abc import contextlib +import copy import ctypes import hashlib import io @@ -130,6 +131,7 @@ "reduce_scatter_tensor", "get_node_local_rank", "split_group", + "shrink_group", ] _MPI_AVAILABLE = True @@ -5211,7 +5213,9 @@ def split_group( if pg_options is None: # default pg_options same as the parent process group - pg_options = parent_backend.options + # A deep copy is needed because if the option will be modified inside split + # and if we split parent pg multiple times, we will run into device out of bound error. + pg_options = copy.deepcopy(parent_backend.options) # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, # which may just pass their timeout value (or None) @@ -5753,3 +5757,521 @@ def _get_process_group_name(pg: ProcessGroup) -> str: def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + + +# Shrink flags for process group backends +SHRINK_DEFAULT = 0x00 +SHRINK_ABORT = 0x01 + + +@_time_logger +def shrink_group( + ranks_to_exclude: list[int], + group: Optional[ProcessGroup] = None, + shrink_flags: int = SHRINK_DEFAULT, + pg_options: Optional[Any] = None, +) -> ProcessGroup: + """ + Shrinks a process group by excluding specified ranks. + + Creates and returns a new, smaller process group comprising only the ranks + from the original group that were not in the ``ranks_to_exclude`` list. + + Args: + ranks_to_exclude (List[int]): A list of ranks from the original + ``group`` to exclude from the new group. + group (ProcessGroup, optional): The process group to shrink. If ``None``, + the default process group is used. Defaults to ``None``. + shrink_flags (int, optional): Flags to control the shrinking behavior. + Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``. + ``SHRINK_ABORT`` will attempt to terminate ongoing operations + in the parent communicator before shrinking. + Defaults to ``SHRINK_DEFAULT``. + pg_options (ProcessGroupOptions, optional): Backend-specific options to apply + to the shrunken process group. If provided, the backend will use + these options when creating the new group. If omitted, the new group + inherits defaults from the parent. + + Returns: + ProcessGroup: a new group comprised of the remaining ranks. If the + default group was shrunk, the returned group becomes the new default group. + + Raises: + TypeError: if the group’s backend does not support shrinking. + ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds, + duplicates, or excludes all ranks). + RuntimeError: if an excluded rank calls this function or the backend + fails the operation. + + Notes: + - Only non-excluded ranks should call this function; excluded ranks + must not participate in the shrink operation. + - Shrinking the default group destroys all other process groups since + rank reassignment makes them inconsistent. + """ + # Step 1: Validate input parameters with comprehensive error checking + _validate_shrink_inputs(ranks_to_exclude, shrink_flags) + + # Step 2: Get target group and essential properties + target_group_info = _prepare_shrink_target_group(group) + + # Step 3: Validate backend requirements and availability + backend_impl = _validate_shrink_backend_requirements(target_group_info) + + # Step 4: Validate ranks against group and check for duplicates + excluded_ranks_set = _validate_and_process_excluded_ranks( + ranks_to_exclude, target_group_info + ) + + # Step 5: Execute the actual shrink operation (backend-specific) + new_backend = backend_impl.shrink( + sorted(excluded_ranks_set), + shrink_flags, + pg_options if pg_options is not None else None, + ) + + # Step 6: Handle cleanup and creation of new process group + target_group_info["pg_options_override"] = pg_options + return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend) + + +def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None: + """Validate input parameters for shrink_group.""" + if not isinstance(ranks_to_exclude, list): + raise TypeError( + f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. " + f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5." + ) + + if not ranks_to_exclude: + raise ValueError( + "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least " + "one rank to exclude. Example: [failed_rank_id]" + ) + + # Validate shrink_flags with clear explanation of valid values + valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT] + if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags: + raise ValueError( + f"Invalid shrink_flags value: {shrink_flags}. Must be one of: " + f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). " + f"Use SHRINK_ABORT to abort ongoing operations before shrinking." + ) + + +def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict: + """Prepare and validate the target group for shrinking.""" + target_pg = group if group is not None else _get_default_group() + + # Cache frequently accessed properties to avoid repeated calls + group_size = int(target_pg.size()) + group_info = { + "process_group": target_pg, + "is_default_group": (target_pg == _get_default_group()), + "group_size": group_size, + "current_rank": target_pg.rank(), + "group_name": _get_process_group_name(target_pg), + } + + # Validate that we have a valid process group + if group_size <= 1: + raise ValueError( + f"Cannot shrink a process group with size {group_size}. " + f"Group must have at least 2 ranks to support shrinking." + ) + + return group_info + + +def _validate_shrink_backend_requirements(group_info: dict) -> Any: + """Return the backend implementation for the target group or raise if unsupported.""" + target_pg = group_info["process_group"] + group_name = group_info["group_name"] + + # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present, + # otherwise try CUDA then fall back to CPU. + try: + preferred_device = getattr(target_pg, "bound_device_id", None) + if preferred_device is not None: + backend_impl = target_pg._get_backend(preferred_device) + else: + # Try CUDA first if available, else CPU + try: + backend_impl = target_pg._get_backend(torch.device("cuda")) + except Exception: + backend_impl = target_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + f"Cannot access device backend for process group '{group_name}'. " + f"Ensure the process group was initialized with a compatible device backend and devices are available." + ) from e + + try: + supports = bool(backend_impl.supports_shrinking) + except Exception: + supports = False + if not supports: + raise TypeError( + f"Process group backend for '{group_name}' does not support shrinking operations." + ) + + return backend_impl + + +def _validate_and_process_excluded_ranks( + ranks_to_exclude: list[int], group_info: dict +) -> set: + """Validate excluded ranks and convert to set for efficient operations.""" + group_size = group_info["group_size"] + current_rank = group_info["current_rank"] + + # Use set for O(1) duplicate detection and membership testing + excluded_ranks_set = set() + + # Validate each rank with detailed error messages + for i, rank in enumerate(ranks_to_exclude): + if not isinstance(rank, int): + raise TypeError( + f"All elements in ranks_to_exclude must be integers. " + f"Element at index {i} is {type(rank).__name__}: {rank}" + ) + + if not (0 <= rank < group_size): + raise ValueError( + f"Rank {rank} at index {i} is out of bounds for group size {group_size}. " + f"Valid ranks are in range [0, {group_size - 1}]." + ) + + if rank in excluded_ranks_set: + raise ValueError( + f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. " + f"Each rank can only be excluded once." + ) + + excluded_ranks_set.add(rank) + + # Ensure we don't exclude all ranks + if len(excluded_ranks_set) >= group_size: + raise ValueError( + f"Cannot exclude all {group_size} ranks from process group. " + f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks." + ) + + # Critical check: current rank should not be in excluded list + if current_rank in excluded_ranks_set: + raise RuntimeError( + f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). " + f"Only non-excluded ranks should participate in the shrinking operation. " + f"Excluded ranks should terminate their processes instead." + ) + + return excluded_ranks_set + + +def _finalize_shrunk_group( + group_info: dict, excluded_ranks_set: set, new_backend +) -> ProcessGroup: + """Clean up old group and create new shrunk process group.""" + target_pg = group_info["process_group"] + is_default_group = group_info["is_default_group"] + + # Handle default group dependencies - destroy other groups first + if is_default_group: + _destroy_all_other_groups(exclude_group=target_pg) + + # Gather original group metadata before cleanup + original_group_metadata = _extract_group_metadata(target_pg) + + # Calculate remaining ranks efficiently + original_ranks = get_process_group_ranks(target_pg) + remaining_ranks = [ + rank for rank in original_ranks if rank not in excluded_ranks_set + ] + + # Clean up the original group + _cleanup_original_group(target_pg, is_default_group) + + # Create and configure the new process group + new_pg = _create_shrunk_process_group( + new_backend, remaining_ranks, original_group_metadata, is_default_group + ) + + # Register the new group in global state + if is_default_group: + _update_default_pg(new_pg) + + # Update global state with new group information + rank_mapping = { + global_rank: group_rank + for group_rank, global_rank in enumerate(remaining_ranks) + } + _update_process_group_global_state( + pg=new_pg, + backend_name=original_group_metadata["backend_name"], + store=original_group_metadata["store"], + group_name=original_group_metadata["new_group_name"], + backend_config=original_group_metadata["backend_config"], + rank_mapping=rank_mapping, + ) + + return new_pg + + +def _extract_group_metadata(target_pg: ProcessGroup) -> dict: + """Extract metadata from the original group before cleanup.""" + original_backend_name, original_store = _world.pg_map[target_pg] + original_backend_config = _world.pg_backend_config.get(target_pg, "") + original_group_name = _get_process_group_name(target_pg) + + # Extract device binding information before cleanup to avoid accessing destroyed group + bound_device_id = None + if hasattr(target_pg, "bound_device_id"): + bound_device_id = target_pg.bound_device_id + + # Generate new group name for the shrunk group; hash for uniqueness across backends + remaining_ranks = list(get_process_group_ranks(target_pg)) + new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True) + + return { + "backend_name": original_backend_name, + "store": original_store, + "backend_config": original_backend_config, + "original_group_name": original_group_name, + "new_group_name": new_group_name, + "bound_device_id": bound_device_id, # Safe to access after cleanup + } + + +def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None: + """Clean up the original process group safely.""" + try: + destroy_process_group(target_pg) + except Exception: + group_type = "default" if is_default_group else "non-default" + logger.warning( + "Failed to destroy %s group during shrinking", group_type, exc_info=True + ) + + # Ensure global state cleanup even if destroy_process_group fails + _cleanup_process_group_global_state(target_pg) + + +def _create_shrunk_process_group( + new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool +) -> ProcessGroup: + """Create and configure the new shrunk process group.""" + # Create new group properties + new_group_rank = new_backend.rank() + new_group_size = new_backend.size() + group_name = metadata["new_group_name"] + + # Generate descriptive group description + if is_default_group: + group_desc = "default:shrunken" + else: + group_desc = f"{metadata['original_group_name']}:shrunk" + + # Create process group with new communicator (clone the parent store like split does) + prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone()) + new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size) + + # Configure backend using the device type of the new backend's bound device if available, + # otherwise derive from the original group's bound device or fall back to CPU. + backend_device = metadata.get("bound_device_id") + if backend_device is None: + # Default to CPU if no bound device is present + backend_device = torch.device("cpu") + + # Choose backend enum based on device type + if backend_device.type == "cuda": + backend_type = ProcessGroup.BackendType.NCCL + else: + backend_type = ProcessGroup.BackendType.GLOO + + new_pg._register_backend(backend_device, backend_type, new_backend) + new_pg._set_default_backend(backend_type) + + # Inherit device binding from original group if it was bound + bound_device_id = metadata.get("bound_device_id") + if bound_device_id is not None: + new_pg.bound_device_id = bound_device_id + + # Set group metadata + new_pg._set_group_name(group_name) + new_pg._set_group_desc(group_desc) + + # Persist backend configuration overrides (if provided via shrink_group) + backend_config_override = metadata.get("backend_config") + if backend_config_override is not None: + # Store for introspection/debugging and potential backend hooks + _world.pg_backend_config[new_pg] = backend_config_override + + return new_pg + + +def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None: + """ + Destroy all process groups except the excluded group and clean up all global state. + + This is necessary when shrinking the default group because global ranks + are reassigned by NCCL, making all existing process groups inconsistent. + + Note: Uses abort for non-collective cleanup since excluded ranks may not + participate in collective operations. Backend cleanup is handled independently per group. + + Args: + exclude_group (ProcessGroup, optional): Process group to exclude from destruction. + If None, destroys all process groups. + """ + # Get list of groups to destroy (avoid modifying dict while iterating) + groups_to_destroy = [] + for pg in list(_world.pg_group_ranks.keys()): + if exclude_group is not None and pg == exclude_group: + continue + groups_to_destroy.append(pg) + + # Warn user about automatic destruction + if groups_to_destroy: + group_names = [_get_process_group_name(pg) for pg in groups_to_destroy] + logger.warning( + "Shrinking default group will destroy %d other process groups: %s. " + "This is necessary because shrinking the default group reassigns global ranks, " + "making existing groups inconsistent.", + len(groups_to_destroy), + ", ".join(group_names), + ) + + # Destroy each group and clean up global state + for pg in groups_to_destroy: + try: + # First call abort_process_group which handles the C++ cleanup non-collectively + _abort_process_group(pg) + except Exception: + # Log but don't fail - some groups might already be destroyed + logger.warning( + "Failed to abort process group %s", + _get_process_group_name(pg), + exc_info=True, + ) + + # Ensure all global state is cleaned up even if _abort_process_group fails + # or doesn't clean up everything + _cleanup_process_group_global_state(pg) + + +def _cleanup_process_group_global_state(pg: ProcessGroup) -> None: + """ + Clean up all global state associated with a process group. + + This function ensures complete cleanup of process group state from all + global dictionaries and registries, even if destroy_process_group fails + or doesn't clean up everything. This is critical when destroying multiple + groups to prevent inconsistent state. + + The cleanup removes the process group from: + - _world.pg_map (backend and store mapping) + - _world.pg_names (group name mapping) + - _world.pg_group_ranks (rank mappings) + - _world.pg_backend_config (backend configuration) + - _world.tags_to_pg and _world.pg_to_tag (tag mappings) + - _world.pg_coalesce_state (coalescing state) + - C++ internal registries via _unregister_process_group + + Args: + pg (ProcessGroup): The process group to clean up. + """ + try: + # Clean up main process group mappings + _world.pg_map.pop(pg, None) + _world.pg_group_ranks.pop(pg, None) + _world.pg_backend_config.pop(pg, None) + + # Clean up process group name mapping + group_name = _world.pg_names.pop(pg, None) + + # Clean up tag mappings + pg_tag = _world.pg_to_tag.pop(pg, None) + if pg_tag is not None and pg_tag in _world.tags_to_pg: + try: + _world.tags_to_pg[pg_tag].remove(pg) + # Remove the tag entry if list is empty + if not _world.tags_to_pg[pg_tag]: + _world.tags_to_pg.pop(pg_tag, None) + except (ValueError, KeyError): + # Process group was already removed from the list + pass + + # Clean up any registered process group names using C++ unregister function + if group_name is not None: + try: + _unregister_process_group(group_name) + except Exception: + # Process group name might not be registered or already unregistered + pass + + # Clean up coalesce state if present + _world.pg_coalesce_state.pop(pg, None) + + except Exception: + # Log cleanup failures but don't propagate - we want to continue with other cleanups + logger.warning( + "Failed to fully clean up global state for process group", exc_info=True + ) + + +def _update_process_group_global_state( + pg: ProcessGroup, + backend_name: str, + store: Store, + group_name: str, + backend_config: str, + rank_mapping: Optional[dict[int, int]] = None, + pg_tag: Optional[str] = None, + user_tag: Optional[str] = None, +) -> None: + """ + Update all global state dictionaries for a process group. + + This helper function consolidates the common pattern of updating multiple + global state dictionaries when creating or modifying process groups. + + Args: + pg (ProcessGroup): The process group to update state for. + backend_name (str): Backend name for pg_map. + store (Store): Store instance for pg_map. + group_name (str): Group name for pg_names and registration. + backend_config (str): Backend configuration string. + rank_mapping (Dict[int, int], optional): Global rank to group rank mapping. + If None, skips updating pg_group_ranks. + pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}". + user_tag (str, optional): User-provided tag for special tag handling. + If provided, creates "user:{user_tag}" tag and also adds to default "". + """ + # Update main process group mappings + _world.pg_map[pg] = (backend_name, store) + _world.pg_names[pg] = group_name + _world.pg_backend_config[pg] = backend_config + + # Register the process group name + _register_process_group(group_name, pg) + + # Update rank mapping if provided + if rank_mapping is not None: + _world.pg_group_ranks[pg] = rank_mapping + + # Handle tag management + if pg_tag is None: + pg_tag = f"ptd:{group_name}" + + if user_tag is not None: + # Special handling for user-provided tags + # Add to default "" tag first + _world.tags_to_pg.setdefault("", []).append(pg) + # Then create user-specific tag + user_pg_tag = f"user:{user_tag}" + _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg) + _world.pg_to_tag[pg] = user_pg_tag + else: + # Standard process group tag + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 7ad35115cd34..a34ec1408be5 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,9 +10,10 @@ import logging import os import time +from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union +from typing import Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: @@ -129,7 +130,7 @@ def __init__( self._log_line_prefixes = log_line_prefixes self._log_line_filter = log_line_filter self._finished_events: dict[int, Event] = { - local_rank: Event() for local_rank in log_files.keys() + local_rank: Event() for local_rank in log_files } self._futs: list[Future] = [] self._interval_sec = interval_sec diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 60e3f37a9991..96657eeea410 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1549,7 +1549,7 @@ def _allgather_orig_param_states( fsdp_state._device_handle.memory_summary(), ) - output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states} dtype, state_buffers = _convert_all_state_info( fsdp_param_info, gathered_state_info, input_states, output_states diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 39da483fe002..44569427f8db 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -1637,7 +1637,7 @@ def _step_microbatches( # the stages in the pipeline_order all_prev_ranks: set[int] = set() all_next_ranks: set[int] = set() - for stage_index in stage_index_to_stage.keys(): + for stage_index in stage_index_to_stage: # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) if stage_index > 0: all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) @@ -2033,12 +2033,6 @@ def _perform_action(action: _Action) -> None: is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage - logger.debug( - "_PipelineScheduleRuntime running time_step %d, action %s", - time_step, - action, - ) - # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be # safe to use instead. @@ -2191,6 +2185,11 @@ def _perform_action(action: _Action) -> None: # count either full_backward or backward_weight together, to determine when to sync DP grads self.backward_counter.clear() for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) try: with record_function(_get_profiler_function_name(action)): if action.computation_type in self._comp_type_to_function_map: @@ -3176,7 +3175,7 @@ def get_schedule_class(schedule_name: str): "ZBVZeroBubble": ScheduleZBVZeroBubble, "DualPipeV": ScheduleDualPipeV, } - lowercase_keys = {k.lower(): k for k in schedule_map.keys()} + lowercase_keys = {k.lower(): k for k in schedule_map} lowercase_schedule_name = schedule_name.lower() if lowercase_schedule_name not in lowercase_keys: raise ValueError( diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 084fa62706e0..53b759e993c0 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -618,7 +618,7 @@ def common_pointwise_strategy( return pointwise_strategy -for op in linear_pointwise_ops.keys(): +for op in linear_pointwise_ops: register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( linear_pointwise_strategy ) diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 2444467a3595..f238739ddd5c 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -10,6 +10,7 @@ _enable_context_parallel_dispatcher, _is_causal_behavior, _RotateMethod, + _templated_ring_attention, context_parallel, context_parallel_unshard, set_rotate_method, @@ -22,6 +23,7 @@ ) +# TODO(fegin): add deprecation message once the final interfaces are concluded. __all__ = [ "_CausalBehavior", "_context_parallel_shard", @@ -31,6 +33,7 @@ "_enable_context_parallel_dispatcher", "_is_causal_behavior", "_RotateMethod", + "_templated_ring_attention", "context_parallel", "context_parallel_unshard", "set_rotate_method", diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 1e1f1f409857..a9a018468cef 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1333,7 +1333,7 @@ def refine_dynamic_shapes_from_suggested_fixes( roots.add(c.root.__name__) # type: ignore[attr-defined] # check keys are existing dims or new roots - for k in shape_fixes.keys(): + for k in shape_fixes: assert k in name_to_dim or k in roots # cache so we don't produce multiple derived dim objects diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index d07d235e5132..e01cab57775c 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1871,7 +1871,7 @@ def round_magic_impl(self, ndigits=None): setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl) -for method in magic_methods.keys(): # type: ignore[assignment] +for method in magic_methods: # type: ignore[assignment] if method in only_bool_magic_methods: _make_user_magic(method, SymBool) continue diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index aeccdfbe000d..693d25aea613 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,6 +547,7 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. # @@ -602,7 +603,23 @@ def rebind_unbacked( if u1.node.hint is not None: continue - raw_u1 = u1.node.expr + # unbacked symbols bindings might be replaced to other backed or + # unbacked replacements. + # + # Example: + # u = x.item() + # torch._check(u == 5) + # + # The safest approach is to retrieve raw_u1 from u1.node._expr + # and perform the rebinding on the original unbacked symbol, + # even if it’s no longer directly referenced. + # + # In other words, we should always rebind the original symbol + # before any replacements are applied. + # u0 -> u0 == s1 + raw_u1 = u1.node._expr + + # TODO Do we still need this logic below? # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index bab662e0655a..efafb146179a 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -61,7 +61,7 @@ def substitute_solution_one_type(mapping, t): Apply the most general unifier to a type """ if isinstance(t, Var): - if t in mapping.keys(): + if t in mapping: return mapping[t] else: return t @@ -69,7 +69,7 @@ def substitute_solution_one_type(mapping, t): elif isinstance(t, TensorType): new_type = [] for typ in t.__args__: - if typ in mapping.keys(): + if typ in mapping: new_type.append(mapping[typ]) else: new_type.append(typ) @@ -102,7 +102,7 @@ def substitute_all_types(graph, mapping): flag = False for k in mapping: old_mapping_val = mapping[k] - if mapping[k] in mapping.keys(): + if mapping[k] in mapping: new_key = mapping[k] mapping[k] = mapping[new_key] if old_mapping_val != mapping[k]: diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fc6f4c5b2702..d8cfa42472b4 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -226,8 +226,10 @@ class PythonCode: # Values in global scope during execution of `src_def`. globals: dict[str, Any] # Optional mapping from the forward function's line number to - # node index. + # node index. Line number starts at the prologue (i.e. forward()). _lineno_map: Optional[dict[int, Optional[int]]] + # The line number of prologue in fn_code + _prologue_start: int = 0 def _format_target(base: str, target: str) -> str: @@ -441,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -645,6 +648,15 @@ def emit_node(node: Node): if verbose: # override annotation with more detailed information + try: + from torch.distributed.tensor._api import DTensor, DTensorSpec + + dtensorspec_format_shard_order_str = ( + DTensorSpec.format_shard_order_str + ) + except ModuleNotFoundError: + DTensor = None # type: ignore[assignment,misc] + dtensorspec_format_shard_order_str = None from torch.fx.experimental.proxy_tensor import py_sym_types from torch.fx.passes.shape_prop import TensorMetadata @@ -675,6 +687,16 @@ def _tensor_annotation(t: torch.Tensor) -> str: core = _tensor_annotation(meta_val) if is_plain: maybe_type_annotation = f': "{core}"' + elif type(meta_val) is DTensor: + assert dtensorspec_format_shard_order_str is not None + dtensor_meta = dtensorspec_format_shard_order_str( + meta_val._spec.placements, # type: ignore[attr-defined] + meta_val._spec.shard_order, # type: ignore[attr-defined] + ) + cls = meta_val.__class__.__name__ + maybe_type_annotation = ( + f': "{cls}({core}, {dim_green(dtensor_meta)})"' + ) else: cls = meta_val.__class__.__name__ maybe_type_annotation = f': "{cls}({core})"' @@ -796,6 +818,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -805,8 +831,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -854,7 +894,14 @@ def _tensor_annotation(t: torch.Tensor) -> str: {prologue} {code}""" - return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + # The +4 accounts for the empty lines before prologue in fn_code + prologue_start = wrap_stmts.count("\n") + 4 + return PythonCode( + fn_code, + globals_, + _lineno_map=lineno_map, + _prologue_start=prologue_start, + ) # Ideally, we'd like to refactor all of the pytree logic into this codegen @@ -1098,7 +1145,7 @@ def find_nodes(self, *, op: str, target: Optional["Target"] = None): return [*self.table[(op, None)].keys()] # op is call_method, get_attr, call_module - return [node for node in self.table[(op, None)].keys() if node.target == target] + return [node for node in self.table[(op, None)] if node.target == target] @compatibility(is_backward_compatible=True) @@ -1751,6 +1798,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1818,6 +1866,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1830,6 +1879,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1840,6 +1890,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 159926bc8ba4..8360c96630d6 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs +import base64 import contextlib import copy +import hashlib import itertools import linecache import os @@ -36,6 +38,7 @@ ] _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" +FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_" # Normal exec loses the source code, however we can work with @@ -61,7 +64,13 @@ def cache(self, src: str, globals: dict[str, Any], co_fields=None): key = self._get_key() if co_fields: - key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + if "co_filename" in co_fields: + # If only co_filename is provided, use it directly as the key + if "co_firstlineno" not in co_fields or "co_name" not in co_fields: + key = co_fields["co_filename"] + else: + # Full co_fields with all three components + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" self.eval_cache[key] = src # Don't mutate globals so that this loader is only used @@ -353,6 +362,36 @@ def _print_readable( return output +def _metadata_hash(code: str, node_metadata: dict) -> str: + """ + Create a content-addressed hash from code and metadata. + + Args: + code: The source code string + lineno_map: Mapping from line numbers to node indices + node_metadata: Metadata for each node + + Returns: + A 51-character base32-encoded hash + """ + import json + + # Create a deterministic string representation of all components + # We use JSON to ensure consistent serialization + hash_data = { + "code": code, + "node_metadata": node_metadata, + } + hashing_str = json.dumps(hash_data).encode("utf-8") + + # [:51] to strip off the "Q====" suffix common to every hash value. + return ( + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51] + .decode("utf-8") + .lower() + ) + + class _WrappedCall: def __init__(self, cls, cls_call): self.cls = cls @@ -822,12 +861,60 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map + self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + + if dynamo_config.enrich_profiler_metadata: + # Generate metadata and register for profiler augmentation + node_metadata: dict[int, dict[str, Any]] = {} + for i, node in enumerate(self._graph.nodes): + node_metadata[i] = { + "name": node.name, + "op": node.op, + "target": str(node.target), + "stack_trace": node.meta.get("stack_trace", None), + } + + # Generate a content-addressed filename based on hash of code and metadata + # This ensures the same code+metadata always generates the same filename + hash_value = _metadata_hash(self._code, node_metadata) + file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + filename = f"{file_stem}.py" + + # Only include co_filename to use it directly as the cache key + co_fields = { + "co_filename": filename, + } + + # Store metadata in global in-memory registry + metadata = { + "lineno_map": python_code._lineno_map, + "prologue_start": python_code._prologue_start, + "node_metadata": node_metadata, + } + + # Register metadata in the global registry + from torch.fx.traceback import _register_fx_metadata + + _register_fx_metadata(filename, metadata) + + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index a3114a14a657..5b40e8a66147 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs import inspect +import logging from contextlib import contextmanager from typing import Any, Optional, TYPE_CHECKING, Union import torch import torch.fx.traceback as fx_traceback -from torch._logging import trace_structured +from torch._logging import LazyString, trace_structured from torch.hub import tqdm from . import config @@ -21,10 +22,35 @@ if TYPE_CHECKING: from collections.abc import Iterator +log = logging.getLogger(__name__) __all__ = ["Interpreter", "Transformer"] +def _format_fx_node(n): + """ + Format a torch.fx.Node into a human-readable string for debug logging. + + Args: + n (torch.fx.Node): The FX node being executed. + + Returns: + str: A formatted string describing the node operation, including its + name, target, positional arguments, and keyword arguments. + """ + module_prefix = getattr(n.target, "__module__", "") + module_prefix = f"{module_prefix}." if module_prefix else "" + + # Handle positional and keyword arguments + args = ", ".join(map(str, n.args)) + kwargs = ", ".join(f"{k}={v}" for k, v in n.kwargs.items()) + joined = ", ".join(filter(None, [args, kwargs])) + + return ( + f"{n.name} = {module_prefix}{getattr(n.target, '__name__', n.target)}({joined})" + ) + + @compatibility(is_backward_compatible=True) class Interpreter: """ @@ -220,11 +246,23 @@ def boxed_run(self, args_list): calling convention, where you pass a list of arguments, which will be cleared by the interpreter. This ensures that input tensors are promptly deallocated. """ - args_iter = iter(args_list) - env = {} - for n in self.graph.nodes: - if n.op == "placeholder": - env[n] = next(args_iter) + # Collect placeholder nodes first + placeholder_nodes = [n for n in self.graph.nodes if n.op == "placeholder"] + + # Check argument count + if len(args_list) != len(placeholder_nodes): + detail = ( + "extra arguments" + if len(args_list) > len(placeholder_nodes) + else "missing arguments" + ) + raise RuntimeError( + f"Interpreter.boxed_run expected {len(placeholder_nodes)} arguments for placeholders " + f"but received {len(args_list)} ({detail})" + ) + + # Assign arguments to placeholders + env = dict(zip(placeholder_nodes, args_list)) args_list.clear() return self.run(initial_env=env) @@ -249,6 +287,7 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ + log.debug("run_node %s", LazyString(lambda: _format_fx_node(n))) with self._set_current_node(n): args, kwargs = self.fetch_args_kwargs_from_env(n) assert isinstance(args, tuple) diff --git a/torch/fx/node.py b/torch/fx/node.py index 1d72a75a6ccf..272676a4e3a9 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -496,7 +496,7 @@ def insert_arg(self, idx: int, arg: Argument) -> None: _new_input_nodes: dict[Node, None] = {} _fx_map_arg(arg, _new_input_nodes.setdefault) - for new_use in _new_input_nodes.keys(): + for new_use in _new_input_nodes: if new_use not in self._input_nodes: self._input_nodes.setdefault(new_use) new_use.users.setdefault(self) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 58aa80106282..1d3b0b33e7bc 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -373,11 +373,9 @@ def has_new_untracked_symbols(): shape_env, node.meta.get("unbacked_bindings", {}) ) - assert resolved_unbacked_bindings is not None - def has_new_unbacked_bindings(): - # pyrefly: ignore [missing-attribute] - for key in resolved_unbacked_bindings.keys(): + assert resolved_unbacked_bindings is not None + for key in resolved_unbacked_bindings: if key not in expr_to_proxy: return True return False diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 6cf708a61906..8d90f9d55cfd 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -204,7 +204,7 @@ def to_dict(self): Create dict dump on all events. """ ret: dict[str, list[str]] = {} - for name in self.node_events.keys(): + for name in self.node_events: ret[name] = [] for idx in self.node_events.get(name, []): event = self.events[idx] @@ -218,7 +218,7 @@ def print_all(self, writer=None): """ if not writer: writer = self.writer - for name in self.node_events.keys(): + for name in self.node_events: writer(f"Node: {name}:") self.print_node(name, recursive=False, tab=" ", writer=writer) diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 043c65e6b77d..82259b8a36ab 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -113,7 +113,7 @@ def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: # get_attr nodes won't be output nodes continue - for user in node.users.keys(): + for user in node.users: if user not in nodes: output_nodes.add(node) @@ -157,7 +157,7 @@ def check_subgraphs_connected( """ for node in reversed(subgraph1.nodes): - for user in node.users.keys(): + for user in node.users: if user in subgraph2.nodes: return True return False diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a143119cd78b..25fb81a5aa01 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -38,6 +38,28 @@ current_replay_node: Optional[Node] = None should_preserve_node_meta = False +# ============================================================================= +# FX Metadata Registry for Memory Profiler +# ============================================================================= +# Global in-memory registry for FX metadata +# Maps module_name -> metadata dict containing lineno_map and node_metadata +_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {} + + +def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None: + """ + Register FX metadata in the global in-memory registry. + + This is called automatically during graph module compilation to store metadata + for later use by memory profiler augmentation. + + Args: + module_name: The module identifier (content-addressed filename) + metadata: Metadata dict containing lineno_map, node_metadata, and source_code + """ + # TODO: add logging to tlparse + _FX_METADATA_REGISTRY[module_name] = metadata + @compatibility(is_backward_compatible=False) class NodeSourceAction(Enum): diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 70165a7493e5..c0cd5d9a2c68 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -42,6 +42,9 @@ fp16_ieee_to_fp32_value # fp32_from_bits called from fp16_ieee_to_fp32_value # fp32_to_bits called from fp16_ieee_from_fp32_value +# torch/headeronly/util/HeaderOnlyArrayRef.h +HeaderOnlyArrayRef + # c10/util/complex.h, torch/headeronly/util/complex.h complex diff --git a/torch/headeronly/util/HeaderOnlyArrayRef.h b/torch/headeronly/util/HeaderOnlyArrayRef.h new file mode 100644 index 000000000000..2387578ab8f5 --- /dev/null +++ b/torch/headeronly/util/HeaderOnlyArrayRef.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// HeaderOnlyArrayRef - A subset of ArrayRef that is implemented only +/// in headers. This will be a base class from which ArrayRef inherits, so that +/// we can keep much of the implementation shared. +/// +/// [HeaderOnlyArrayRef vs ArrayRef note] +/// As HeaderOnlyArrayRef is a subset of ArrayRef, it has slightly less +/// functionality than ArrayRef. We document the minor differences below: +/// 1. ArrayRef has an extra convenience constructor for SmallVector. +/// 2. ArrayRef uses TORCH_CHECK. HeaderOnlyArrayRef uses header-only +/// STD_TORCH_CHECK, which will output a std::runtime_error vs a +/// c10::Error. Consequently, you should use ArrayRef when possible +/// and HeaderOnlyArrayRef only when necessary to support headeronly code. +/// In all other aspects, HeaderOnlyArrayRef is identical to ArrayRef, with the +/// positive benefit of being header-only and thus independent of libtorch.so. +template +class HeaderOnlyArrayRef { + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + protected: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_type Length; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty HeaderOnlyArrayRef. + /* implicit */ constexpr HeaderOnlyArrayRef() : Data(nullptr), Length(0) {} + + /// Construct a HeaderOnlyArrayRef from a single element. + // TODO Make this explicit + constexpr HeaderOnlyArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct a HeaderOnlyArrayRef from a pointer and length. + constexpr HeaderOnlyArrayRef(const T* data, size_t length) + : Data(data), Length(length) {} + + /// Construct a HeaderOnlyArrayRef from a range. + constexpr HeaderOnlyArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) {} + + template < + typename Container, + typename U = decltype(std::declval().data()), + typename = std::enable_if_t< + (std::is_same_v || std::is_same_v)>> + /* implicit */ HeaderOnlyArrayRef(const Container& container) + : Data(container.data()), Length(container.size()) {} + + /// Construct a HeaderOnlyArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. + template + /* implicit */ HeaderOnlyArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same_v, + "HeaderOnlyArrayRef cannot be constructed from a std::vector bitfield."); + } + + /// Construct a HeaderOnlyArrayRef from a std::array + template + /* implicit */ constexpr HeaderOnlyArrayRef(const std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct a HeaderOnlyArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-arrays*) + /* implicit */ constexpr HeaderOnlyArrayRef(const T (&Arr)[N]) + : Data(Arr), Length(N) {} + + /// Construct a HeaderOnlyArrayRef from a std::initializer_list. + /* implicit */ constexpr HeaderOnlyArrayRef( + const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return this->Data; + } + constexpr iterator end() const { + return this->Data + this->Length; + } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return this->Data; + } + constexpr const_iterator cend() const { + return this->Data + this->Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// Check if all elements in the array satisfy the given expression + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return this->Length == 0; + } + + constexpr const T* data() const { + return this->Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return this->Length; + } + + /// front - Get the first element. + constexpr const T& front() const { + STD_TORCH_CHECK( + !this->empty(), + "HeaderOnlyArrayRef: attempted to access front() of empty list"); + return this->Data[0]; + } + + /// back - Get the last element. + constexpr const T& back() const { + STD_TORCH_CHECK( + !this->empty(), + "HeaderOnlyArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(HeaderOnlyArrayRef RHS) const { + return this->Length == RHS.Length && + std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + constexpr HeaderOnlyArrayRef slice(size_t N, size_t M) const { + STD_TORCH_CHECK( + N + M <= this->size(), + "HeaderOnlyArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + this->size()); + return HeaderOnlyArrayRef(this->data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + constexpr HeaderOnlyArrayRef slice(size_t N) const { + STD_TORCH_CHECK( + N <= this->size(), + "HeaderOnlyArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return this->Data[Index]; + } + + /// Vector compatibility + constexpr const T& at(size_t Index) const { + STD_TORCH_CHECK( + Index < this->Length, + "HeaderOnlyArrayRef: invalid index Index = ", + Index, + "; Length = ", + this->Length); + return this->Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, HeaderOnlyArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, HeaderOnlyArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { + return std::vector(this->Data, this->Data + this->Length); + } + + /// @} +}; + +} // namespace c10 + +namespace torch::headeronly { +using c10::HeaderOnlyArrayRef; +using IntHeaderOnlyArrayRef = HeaderOnlyArrayRef; +} // namespace torch::headeronly diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 3a2b3ef8b600..343871b1f94a 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -574,7 +574,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): def init_fn(script_module): # Initialize the ScriptModule: # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. - for name in concrete_type.get_attributes().keys(): + for name in concrete_type.get_attributes(): orig_value = getattr(nn_module, name) orig_value = ( orig_value.value diff --git a/torch/jit/_script.py b/torch/jit/_script.py index a8bb3ba9bd8f..46e6f4753410 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -856,7 +856,7 @@ def __setattr__(self, attr, value): self._c.setattr(attr, value) elif ( hasattr(self, "_concrete_type") - and attr in self._concrete_type.get_constants().keys() + and attr in self._concrete_type.get_constants() ): # TODO: we don't have _concrete_type set after load(), and in general we lose constant information. # We should encode constants as class type attributes (or something) so it persists across save/load. diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index a84a5b681d63..69c324ab726e 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -143,7 +143,7 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None: name, arg_type = named_arg_type.split(": ") is_optional = arg_type.endswith("?") normalized_arg_type = arg_type[:-1] if is_optional else arg_type - if normalized_arg_type not in arg_type_check_fns.keys(): + if normalized_arg_type not in arg_type_check_fns: raise AssertionError(f"Unknown arg type: {normalized_arg_type}") if i >= len(args): diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 5e6e0fa5fae3..a115d32c6e2c 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -90,7 +90,7 @@ def _cur_sdpa_kernel_backends(with_priority: bool = False): return backends -def _sdpa_kernel(backends: Iterable, set_priority: bool = False): +def _sdpa_kernel(backends: Iterable, set_priority: bool = False) -> None: for name, val in _backend_names.items(): enabled = getattr(SDPBackend, val) in backends getattr(torch._C, f"_set_sdp_use_{name}")(enabled) diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index a91045b92c13..86f7c29f5313 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -40,7 +40,7 @@ def _validate_sdpa_input( dropout_p=0.0, is_causal=False, scale=None, -): +) -> None: if query.dtype != key.dtype or query.dtype != value.dtype: raise ValueError( f"Expected query, key, and value to have the same dtype, " diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 551a57e6963e..0cb256ad36f7 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -117,7 +117,7 @@ class CausalBias(torch.Tensor): .. warning:: This class is a prototype and subject to change. """ - def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int): + def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None: """ Initializes the CausalBias instance with a specified variant and sequence lengths. @@ -296,7 +296,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return cls._dispatch(*args, **kwargs) return super().__torch_function__(func, types, args, kwargs) - def __repr__(self): # type:ignore[override] + def __repr__(self) -> str: # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/nn/attention/experimental/_paged_attention.py b/torch/nn/attention/experimental/_paged_attention.py index 70eadcdadfaa..2e0ded6063ae 100644 --- a/torch/nn/attention/experimental/_paged_attention.py +++ b/torch/nn/attention/experimental/_paged_attention.py @@ -40,7 +40,7 @@ def __init__( page_size: int, max_batch_size: int, device: str = "cuda", - ): + ) -> None: # number of pages self.n_pages = n_pages diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index b79b86a29afb..be49549e5740 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -550,7 +550,7 @@ def __init__( full_q_indices: Optional[Tensor], BLOCK_SIZE: tuple[int, int], mask_mod: _mask_mod_signature, - ): + ) -> None: if kv_indices.dim() < 2: raise RuntimeError("BlockMask must have at least 2 dimensions") assert kv_num_blocks is not None, "kv_num_blocks must be provided" @@ -682,7 +682,7 @@ def shape(self): *batch_dims, _, _ = self.kv_indices.shape return tuple(batch_dims) + self.seq_lengths - def __str__(self): + def __str__(self) -> str: s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n" mask_str = self.to_string().strip() s += mask_str @@ -760,7 +760,7 @@ def causal_mask(b, h, q_idx, kv_idx): compute_q_blocks=self.q_indices is not None, ) - def __repr__(self): + def __repr__(self) -> str: def shape_or_none(x: Optional[torch.Tensor]): return x.shape if x is not None else None @@ -864,7 +864,7 @@ def create_block_vis(*batch_idx): vis = ", ".join(reversed(descriptors)) + "\n" - def summarize_section(section): + def summarize_section(section) -> str: percentage = section.float().mean().item() if percentage == 1: return "β–ˆ" @@ -1289,7 +1289,7 @@ def _apply_kernel_options( return kernel_options -def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): +def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor) -> None: if query.size(-1) != key.size(-1): raise ValueError( f"Expect query and key/value to have the same embedding dimension " @@ -1297,7 +1297,7 @@ def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): ) -def _validate_device(query: Tensor, key: Tensor, value: Tensor): +def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None: """TODO: Remove once non cuda/cpu devices support is added We only need to check query since we have already that q,k,v are on the same device """ diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py index 8564153ece23..c56e923a8438 100644 --- a/torch/nn/backends/thnn.py +++ b/torch/nn/backends/thnn.py @@ -2,5 +2,5 @@ # this is for historical pickle deserialization, it is not used otherwise -def _get_thnn_function_backend(): +def _get_thnn_function_backend() -> None: pass diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index e447284ad82b..b4ffd188cd39 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -14,7 +14,7 @@ class OrderedDictWrapper: so using properties does not work. """ - def __init__(self, cpp_module, attr): + def __init__(self, cpp_module, attr) -> None: self.cpp_module = cpp_module self.attr = attr @@ -37,10 +37,10 @@ def values(self): def __iter__(self): return self.cpp_dict.__iter__() - def __len__(self): + def __len__(self) -> int: return self.cpp_dict.__len__() - def __contains__(self, key): + def __contains__(self, key) -> bool: return self.cpp_dict.__contains__(key) def __getitem__(self, key): @@ -50,7 +50,7 @@ def __getitem__(self, key): class ModuleWrapper(nn.Module): """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.""" - def __init__(self, cpp_module): + def __init__(self, cpp_module) -> None: # Assign before the super class constructor so ``self.training`` can be # assigned to in the super class constructor. self.cpp_module = cpp_module @@ -83,8 +83,8 @@ def training(self): return self.cpp_module.training @training.setter - def training(self, mode): + def training(self, mode) -> None: self.cpp_module.train(mode) - def __repr__(self): + def __repr__(self) -> str: return self.cpp_module.__repr__() diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f7e3d2f262de..33bf35a1d852 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -2521,7 +2521,7 @@ def _load_from_state_dict( unexpected_keys.append(extra_state_key) if strict: - for key in state_dict.keys(): + for key in state_dict: if key.startswith(prefix) and key != extra_state_key: input_name = key[len(prefix) :].split(".", 1) # Must be Module if it have attributes @@ -3040,7 +3040,7 @@ def _replicate_for_data_parallel(self): return replica - def compile(self, *args, **kwargs): + def compile(self, *args, **kwargs) -> None: """ Compile this Module's forward using :func:`torch.compile`. diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 9a0f4973d31b..9aaa9b4a92e6 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -30,7 +30,7 @@ def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None: device_ids = [_get_device_index(x, True) for x in device_ids] dev_props = _get_devices_properties(device_ids) - def warn_imbalance(get_prop): + def warn_imbalance(get_prop) -> bool: values = [get_prop(props) for props in dev_props] min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c03c85f48fc3..64e9d8c2d80f 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -18,7 +18,7 @@ # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. - def __instancecheck__(self, instance): + def __instancecheck__(self, instance) -> bool: if self is Parameter: if isinstance(instance, torch.Tensor) and getattr( instance, "_is_param", False @@ -82,7 +82,7 @@ def __deepcopy__(self, memo): return result # pyrefly: ignore [bad-override] - def __repr__(self): + def __repr__(self) -> str: return "Parameter containing:\n" + super().__repr__() def __reduce_ex__(self, proto): @@ -125,7 +125,7 @@ class UninitializedTensorMixin: torch._has_compatible_shallow_copy_type, ] - def materialize(self, shape, device=None, dtype=None): + def materialize(self, shape, device=None, dtype=None) -> None: r"""Create a Parameter or Tensor with the same properties of the uninitialized one. Given a shape, it materializes a parameter in the same device @@ -163,7 +163,7 @@ def share_memory_(self): "`module.share_memory()`." ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" def __reduce_ex__(self, proto): @@ -235,7 +235,7 @@ def __deepcopy__(self, memo): # Metaclass to combine _TensorMeta and the instance check override for Buffer. class _BufferMeta(torch._C._TensorMeta): # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. - def __instancecheck__(self, instance): + def __instancecheck__(self, instance) -> bool: if self is Buffer: if isinstance(instance, torch.Tensor) and getattr( instance, "_is_buffer", False diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index a17821c2b16c..3d1cddb7e8b8 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -25,7 +25,7 @@ class Buffer(Tensor): data: Tensor = ..., requires_grad: bool = ..., persistent: bool = ..., - ): ... + ) -> None: ... class UninitializedBuffer(Tensor): persistent: bool @@ -34,7 +34,7 @@ class UninitializedBuffer(Tensor): data: Tensor = ..., requires_grad: bool = ..., persistent: bool = ..., - ): ... + ) -> None: ... def materialize( self, shape: tuple[int, ...], diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index cfb1d99ac30e..58ef67e06148 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -37,10 +37,10 @@ # all of the RNN decomps run linear with the batch dimension second, even if batch_first was set @contextmanager def batch_second(args, kwargs): - def set_batch_second(ew): + def set_batch_second(ew) -> None: ew.set_batch_first(False) - def reset_batch_first(ew): + def reset_batch_first(ew) -> None: ew.set_batch_first(True) tree_map_only(ExpandedWeight, set_batch_second, args) @@ -55,10 +55,10 @@ def reset_batch_first(ew): # to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch @contextmanager def allow_smaller_batches(args, kwargs): - def allow(ew): + def allow(ew) -> None: ew.set_allow_smaller_batches(True) - def reset(ew): + def reset(ew) -> None: ew.set_allow_smaller_batches(False) tree_map_only(ExpandedWeight, allow, args) @@ -102,7 +102,7 @@ def decorator(autograd_func): # # Needs to be a tensor subclass to allow reparameterization class ExpandedWeight(torch.Tensor): - def __init__(self, orig_weight, batch_size, loss_reduction): + def __init__(self, orig_weight, batch_size, loss_reduction) -> None: self.batch_size = batch_size self.batch_first = True self.allow_smaller_batches = False @@ -179,8 +179,8 @@ def data_ptr(self): def get_device(self): return self.orig_weight.get_device() - def set_allow_smaller_batches(self, is_allow_smaller_batches): + def set_allow_smaller_batches(self, is_allow_smaller_batches) -> None: self.allow_smaller_batches = is_allow_smaller_batches - def set_batch_first(self, is_batch_first=True): + def set_batch_first(self, is_batch_first=True) -> None: self.batch_first = is_batch_first diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index ec6d55305fb4..eacd717873ec 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -123,7 +123,7 @@ def maybe_scale_by_batch_size(grad_sample, expanded_weight): return grad_sample -def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn) -> None: unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) if isinstance(maybe_expanded_weight, ExpandedWeight): grad_sample_contribution = maybe_scale_by_batch_size( diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 7706be61e39f..59044b72b96c 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -388,7 +388,7 @@ def _weight_norm_compat_hook( missing_keys, unexpected_keys, error_msgs, - ): + ) -> None: g_key = f"{prefix}{name}_g" v_key = f"{prefix}{name}_v" if g_key in state_dict and v_key in state_dict: diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 88eeb3aaf50c..b9a1140e43f7 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -72,7 +72,7 @@ def cached(): _cache = {} -def _register_parameter_or_buffer(module, name, X): +def _register_parameter_or_buffer(module, name, X) -> None: if isinstance(X, Parameter): module.register_parameter(name, X) else: diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 3c1a80008595..827bf19ed4be 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -231,7 +231,7 @@ def prune(self, t, default_mask=None, importance_scores=None): default_mask = default_mask if default_mask is not None else torch.ones_like(t) return t * self.compute_mask(importance_scores, default_mask=default_mask) - def remove(self, module): + def remove(self, module) -> None: r"""Remove the pruning reparameterization from a module. The pruned parameter named ``name`` remains permanently pruned, @@ -269,7 +269,7 @@ class PruningContainer(BasePruningMethod): them. """ - def __init__(self, *args): + def __init__(self, *args) -> None: self._pruning_methods: tuple[BasePruningMethod, ...] = () if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name @@ -284,7 +284,7 @@ def __init__(self, *args): for method in args: self.add_pruning_method(method) - def add_pruning_method(self, method): + def add_pruning_method(self, method) -> None: r"""Add a child pruning ``method`` to the container. Args: @@ -303,7 +303,7 @@ def add_pruning_method(self, method): # if all checks passed, add to _pruning_methods tuple self._pruning_methods += (method,) # type: ignore[operator] - def __len__(self): + def __len__(self) -> int: return len(self._pruning_methods) def __iter__(self): @@ -449,7 +449,7 @@ class RandomUnstructured(BasePruningMethod): PRUNING_TYPE = "unstructured" - def __init__(self, amount): + def __init__(self, amount) -> None: # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount @@ -506,7 +506,7 @@ class L1Unstructured(BasePruningMethod): PRUNING_TYPE = "unstructured" - def __init__(self, amount): + def __init__(self, amount) -> None: # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount @@ -574,7 +574,7 @@ class RandomStructured(BasePruningMethod): PRUNING_TYPE = "structured" - def __init__(self, amount, dim=-1): + def __init__(self, amount, dim=-1) -> None: # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount @@ -682,7 +682,7 @@ class LnStructured(BasePruningMethod): PRUNING_TYPE = "structured" - def __init__(self, amount, n, dim=-1): + def __init__(self, amount, n, dim=-1) -> None: # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount @@ -799,7 +799,7 @@ def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: i class CustomFromMask(BasePruningMethod): PRUNING_TYPE = "global" - def __init__(self, mask): + def __init__(self, mask) -> None: self.mask = mask def compute_mask(self, t, default_mask): @@ -1025,7 +1025,9 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None): return module -def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): +def global_unstructured( + parameters, pruning_method, importance_scores=None, **kwargs +) -> None: r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. @@ -1212,7 +1214,7 @@ def remove(module, name): ) -def is_pruned(module): +def is_pruned(module) -> bool: r"""Check if a module is pruned by looking for pruning pre-hooks. Check whether ``module`` is pruned by looking for @@ -1241,7 +1243,7 @@ def is_pruned(module): return False -def _validate_pruning_amount_init(amount): +def _validate_pruning_amount_init(amount) -> None: r"""Validate helper to check the range of amount at init. Args: @@ -1271,7 +1273,7 @@ def _validate_pruning_amount_init(amount): ) -def _validate_pruning_amount(amount, tensor_size): +def _validate_pruning_amount(amount, tensor_size) -> None: r"""Validate that the pruning amount is meaningful wrt to the size of the data. Validation helper to check that the amount of parameters to prune @@ -1295,7 +1297,7 @@ def _validate_pruning_amount(amount, tensor_size): ) -def _validate_structured_pruning(t): +def _validate_structured_pruning(t) -> None: r"""Validate that the tensor to be pruned is at least 2-Dimensional. Validation helper to check that the tensor to be pruned is multi- @@ -1342,7 +1344,7 @@ def _compute_nparams_toprune(amount, tensor_size): return round(amount * tensor_size) -def _validate_pruning_dim(t, dim): +def _validate_pruning_dim(t, dim) -> None: r"""Validate that the pruning dimension is within the bounds of the tensor dimension. Args: diff --git a/torch/onnx/_internal/exporter/_dynamic_shapes.py b/torch/onnx/_internal/exporter/_dynamic_shapes.py index e128ecf74e9e..888db138736f 100644 --- a/torch/onnx/_internal/exporter/_dynamic_shapes.py +++ b/torch/onnx/_internal/exporter/_dynamic_shapes.py @@ -67,7 +67,7 @@ def from_dynamic_axes_to_dynamic_shapes( # output names are not needed for dynamic_shapes continue if isinstance(axes, dict): - if any(not isinstance(k, int) for k in axes.keys()): + if any(not isinstance(k, int) for k in axes): raise ValueError( "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]." ) diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index 4def193daf19..c417b354429b 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -32,7 +32,7 @@ def __init__( *, foreach: Optional[bool] = None, maximize: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -77,7 +77,7 @@ def _init_group( col_vars, variances, state_steps, - ): + ) -> bool: for p in group["params"]: if p.grad is None: continue @@ -349,7 +349,7 @@ def _single_tensor_adafactor( eps2: float, maximize: bool, has_complex: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Grad scaling should occur outside of optimizer.step()") @@ -473,7 +473,7 @@ def _multi_tensor_adafactor( eps2: float, maximize: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -624,7 +624,7 @@ def adafactor( eps1: float, eps2: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adafactor algorithm computation. See :class:`~torch.optim.Adafactor` for details. diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index 9b2c76700b35..ba97bc997937 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -33,7 +33,7 @@ def sparse_adam( beta2: float, lr: float, maximize: bool, -): +) -> None: r"""Functional API that performs Sparse Adam algorithm computation. See :class:`~torch.optim.SparseAdam` for details. diff --git a/torch/optim/_muon.py b/torch/optim/_muon.py index 7b7167a40fc1..5b7b9892daf3 100644 --- a/torch/optim/_muon.py +++ b/torch/optim/_muon.py @@ -141,7 +141,7 @@ def _init_group( params_with_grad: list[Tensor], grads: list[Tensor], muon_momentum_bufs: list[Tensor], - ): + ) -> bool: for p in group["params"]: if p.grad is None: continue @@ -337,7 +337,7 @@ def muon( eps: float, adjust_lr_fn: Optional[str], has_complex: bool, -): +) -> None: r"""Functional API that performs Muon algorithm computation. See :class:`~torch.optim.Muon` for details. diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 4a893026451a..75ac77790e30 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -38,7 +38,7 @@ def __init__( capturable: bool = False, maximize: bool = False, differentiable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -257,7 +257,7 @@ def _single_tensor_adadelta( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( @@ -317,7 +317,7 @@ def _multi_tensor_adadelta( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") @@ -427,7 +427,7 @@ def adadelta( eps: float, weight_decay: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adadelta algorithm computation. See :class:`~torch.optim.Adadelta` for details. diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 4d2523b2a16a..519900ab5da6 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -38,7 +38,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, fused: Optional[bool] = None, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -116,7 +116,7 @@ def __setstate__(self, state): float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused) ) - def share_memory(self): + def share_memory(self) -> None: """Calls tensor.share_memory_() on the state sum tensors.""" for group in self.param_groups: for p in group["params"]: @@ -261,7 +261,7 @@ def adagrad( lr_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adagrad algorithm computation. See :class:`~torch.optim.Adagrad` for details. @@ -336,7 +336,7 @@ def _single_tensor_adagrad( maximize: bool, differentiable: bool, has_complex: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -404,7 +404,7 @@ def _multi_tensor_adagrad( maximize: bool, differentiable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") if grad_scale is not None or found_inf is not None: diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 5ceadccce86a..6b8fd5b7e70f 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -47,7 +47,7 @@ def __init__( differentiable: bool = False, fused: Optional[bool] = None, decoupled_weight_decay: bool = False, - ): + ) -> None: if isinstance(lr, Tensor): if foreach and not capturable: raise ValueError( @@ -365,7 +365,7 @@ def _single_tensor_adam( capturable: bool, differentiable: bool, decoupled_weight_decay: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -572,7 +572,7 @@ def _multi_tensor_adam( capturable: bool, differentiable: bool, decoupled_weight_decay: bool, -): +) -> None: if len(params) == 0: return @@ -925,7 +925,7 @@ def adam( weight_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adam algorithm computation. See :class:`~torch.optim.Adam` for details. diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 76d784d6ea76..264451dbb409 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -39,7 +39,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, capturable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -239,7 +239,7 @@ def _single_tensor_adamax( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -319,7 +319,7 @@ def _multi_tensor_adamax( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") @@ -441,7 +441,7 @@ def adamax( beta2: float, lr: float, weight_decay: float, -): +) -> None: r"""Functional API that performs adamax algorithm computation. See :class:`~torch.optim.Adamax` for details. diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 0558cbddd883..2c968fabb698 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -33,7 +33,7 @@ def __init__( capturable: bool = False, differentiable: bool = False, fused: Optional[bool] = None, - ): + ) -> None: super().__init__( params, lr, @@ -152,7 +152,7 @@ def adamw( weight_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs AdamW algorithm computation. See :class:`~torch.optim.AdamW` for details. diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 0008694bda18..0af7f9b4e6f6 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -39,7 +39,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, capturable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -211,7 +211,7 @@ def _single_tensor_asgd( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -292,7 +292,7 @@ def _multi_tensor_asgd( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -442,7 +442,7 @@ def asgd( t0: float, alpha: float, weight_decay: float, -): +) -> None: r"""Functional API that performs asgd algorithm computation. See :class:`~torch.optim.ASGD` for details. diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index ae4b286ffa22..3d138f6a43f7 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -254,7 +254,7 @@ def __init__( tolerance_change: float = 1e-9, history_size: int = 100, line_search_fn: Optional[str] = None, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -304,7 +304,7 @@ def _gather_flat_grad(self): views.append(view) return torch.cat(views, 0) - def _add_grad(self, step_size, update): + def _add_grad(self, step_size, update) -> None: offset = 0 for p in self._params: if torch.is_complex(p): @@ -319,7 +319,7 @@ def _add_grad(self, step_size, update): def _clone_param(self): return [p.clone(memory_format=torch.contiguous_format) for p in self._params] - def _set_param(self, params_data): + def _set_param(self, params_data) -> None: for p, pdata in zip(self._params, params_data, strict=True): p.copy_(pdata) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 71dcb6129a8e..6426283e6542 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -89,7 +89,9 @@ def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]: ] -def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor): +def _update_param_group_val( + param_group: dict[str, Any], key: str, val: float | Tensor +) -> None: """Set param_group[key] to val without aliasing or assignment when they're both tensors. Raises a KeyError if param_group[key] does not exist. """ @@ -196,7 +198,7 @@ def state_dict(self) -> dict[str, Any]: key: value for key, value in self.__dict__.items() if key != "optimizer" } - def load_state_dict(self, state_dict: dict[str, Any]): + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load the scheduler's state. Args: @@ -288,7 +290,7 @@ def step(self, epoch: Optional[int] = None) -> None: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2) self._update_lr(epoch) - def _update_lr(self, epoch: Optional[int] = None): + def _update_lr(self, epoch: Optional[int] = None) -> None: with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 @@ -339,7 +341,7 @@ def __exit__(self, type, value, traceback) -> None: class _initial_mode: - def __init__(self, o: LRScheduler): + def __init__(self, o: LRScheduler) -> None: self.o = o def __enter__(self): @@ -1180,7 +1182,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def recursive_undo(self, sched=None): + def recursive_undo(self, sched=None) -> None: """ Recursively undo any step performed by the initialisation of schedulers. @@ -1659,7 +1661,7 @@ def __init__( cooldown: int = 0, min_lr: Union[list[float], float] = 0, eps: float = 1e-8, - ): # noqa: D107 + ) -> None: # noqa: D107 if factor >= 1.0: raise ValueError("Factor should be < 1.0.") self.factor = factor @@ -1691,7 +1693,7 @@ def __init__( ) self._reset() - def _reset(self): + def _reset(self) -> None: """Reset num_bad_epochs counter and cooldown counter.""" self.best = self.mode_worse self.cooldown_counter = 0 @@ -1724,7 +1726,7 @@ def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[over self._last_lr = _param_groups_val_list(self.optimizer, "lr") - def _reduce_lr(self, epoch): + def _reduce_lr(self, epoch) -> None: if len(self.optimizer.param_groups) != len(self.min_lrs): if self.default_min_lr is None: raise RuntimeError( @@ -1765,7 +1767,7 @@ def _is_better(self, a, best): # noqa: D102 else: # mode == 'max' and epsilon_mode == 'abs': return a > best + self.threshold - def _init_is_better(self, mode, threshold, threshold_mode): + def _init_is_better(self, mode, threshold, threshold_mode) -> None: if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if threshold_mode not in {"rel", "abs"}: @@ -1904,7 +1906,7 @@ def __init__( base_momentum: float = 0.8, max_momentum: float = 0.9, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -1970,7 +1972,7 @@ def __init__( super().__init__(optimizer, last_epoch) self.base_lrs = base_lrs - def _init_scale_fn(self): + def _init_scale_fn(self) -> None: if self._scale_fn_custom is not None: return if self.mode == "triangular": @@ -2155,7 +2157,7 @@ def __init__( T_mult: int = 1, eta_min: float = 0.0, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 if T_0 <= 0 or not isinstance(T_0, int): raise ValueError(f"Expected positive integer T_0, but got {T_0}") if T_mult < 1 or not isinstance(T_mult, int): @@ -2407,7 +2409,7 @@ def __init__( final_div_factor: float = 1e4, three_phase: bool = False, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 # Validate optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 508648a65c14..f83cd4b85d02 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -44,7 +44,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -297,7 +297,7 @@ def _single_tensor_nadam( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -397,7 +397,7 @@ def _multi_tensor_nadam( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -624,7 +624,7 @@ def nadam( weight_decay: float, momentum_decay: float, eps: float, -): +) -> None: r"""Functional API that performs NAdam algorithm computation. See :class:`~torch.optim.NAdam` for details. diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 6a336fa5bab7..c42ea3cfb02d 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -204,7 +204,7 @@ def _device_dtype_check_for_fused( ) -def _view_as_real(params, *state_and_grads): +def _view_as_real(params, *state_and_grads) -> None: for i, p in enumerate(params): if torch.is_complex(p): params[i] = torch.view_as_real(params[i]) diff --git a/torch/optim/radam.py b/torch/optim/radam.py index e13e6806e43a..db69bbb01a04 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -42,7 +42,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -270,7 +270,7 @@ def _single_tensor_radam( maximize: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -377,7 +377,7 @@ def _multi_tensor_radam( maximize: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -586,7 +586,7 @@ def radam( lr: float, weight_decay: float, eps: float, -): +) -> None: r"""Functional API that performs RAdam algorithm computation. See :class:`~torch.optim.RAdam` for details. diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 04981d517d1e..364068ecc9ab 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -41,7 +41,7 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -280,7 +280,7 @@ def _single_tensor_rmsprop( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -357,7 +357,7 @@ def _multi_tensor_rmsprop( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -495,7 +495,7 @@ def rmsprop( weight_decay: float, momentum: float, centered: bool, -): +) -> None: r"""Functional API that performs rmsprop algorithm computation. See :class:`~torch.optim.RMSProp` for details. diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 8ad7faf130e3..c9e1d5eabaee 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -39,7 +39,7 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -235,7 +235,7 @@ def _single_tensor_rprop( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: for i, param in enumerate(params): grad = grads[i] grad = grad if not maximize else -grad @@ -306,7 +306,7 @@ def _multi_tensor_rprop( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -428,7 +428,7 @@ def rprop( step_size_max: float, etaminus: float, etaplus: float, -): +) -> None: r"""Functional API that performs rprop algorithm computation. See :class:`~torch.optim.Rprop` for details. diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 9c2c5a0eab3d..63c80d645cd0 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -39,7 +39,7 @@ def __init__( foreach: Optional[bool] = None, differentiable: bool = False, fused: Optional[bool] = None, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if lr < 0.0: @@ -267,7 +267,7 @@ def sgd( dampening: float, nesterov: bool, maximize: bool, -): +) -> None: r"""Functional API that performs SGD algorithm computation. See :class:`~torch.optim.SGD` for details. @@ -333,7 +333,7 @@ def _single_tensor_sgd( nesterov: bool, maximize: bool, has_sparse_grad: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -394,7 +394,7 @@ def _multi_tensor_sgd( nesterov: bool, maximize: bool, has_sparse_grad: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index ca87e87ce867..ed58c93181ae 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -19,7 +19,7 @@ def __init__( betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, maximize: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 < lr: diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 1ab915d27cd6..ebe3e0702595 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -43,7 +43,9 @@ def get_ema_multi_avg_fn(decay=0.999): ) @torch.no_grad() - def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): + def ema_update( + ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _ + ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( ema_param_list[0] @@ -64,7 +66,7 @@ def swa_update( averaged_param_list: PARAM_LIST, current_param_list: PARAM_LIST, num_averaged: Union[Tensor, int], - ): + ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( averaged_param_list[0] @@ -227,7 +229,7 @@ def __init__( Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] ] = None, use_buffers=False, - ): # noqa: D107 + ) -> None: # noqa: D107 super().__init__() if avg_fn is not None and multi_avg_fn is not None: raise AssertionError( @@ -247,7 +249,7 @@ def forward(self, *args, **kwargs): """Forward pass.""" return self.module(*args, **kwargs) - def update_parameters(self, model: Module): + def update_parameters(self, model: Module) -> None: """Update model parameters.""" self_param = ( # pyrefly: ignore [bad-argument-type] @@ -329,7 +331,7 @@ def update_bn( loader: Iterable[Any], model: Module, device: Optional[Union[int, torch.device]] = None, -): +) -> None: r"""Update BatchNorm running_mean, running_var buffers in the model. It performs one pass over data in `loader` to estimate the activation @@ -367,7 +369,7 @@ def update_bn( was_training = model.training model.train() - for module in momenta.keys(): + for module in momenta: module.momentum = None for input in loader: @@ -378,7 +380,7 @@ def update_bn( model(input) - for bn_module in momenta.keys(): + for bn_module in momenta: bn_module.momentum = momenta[bn_module] model.train(was_training) @@ -434,7 +436,7 @@ def __init__( anneal_epochs=10, anneal_strategy: Literal["cos", "linear"] = "cos", last_epoch=-1, - ): # noqa: D107 + ) -> None: # noqa: D107 swa_lrs = _format_param("swa_lr", optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True): group["swa_lr"] = swa_lr @@ -516,7 +518,7 @@ def get_lr(self): for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True) ] - def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]): + def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]) -> None: self._anneal_strategy = anneal_strategy if anneal_strategy == "cos": self.anneal_func = self._cosine_anneal diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index 57a51ac41cfd..e07b20a83cc6 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -17,230 +17,5 @@ def is_stdlib_module(module: str) -> bool: def _get_stdlib_modules(): - if sys.version_info.major == 3: # noqa: UP036 - if sys.version_info.minor == 9: - return stdlib3_9 - if sys.version_info.minor >= 10: # noqa: YTT204 - return sys.stdlib_module_names # type: ignore[attr-defined] - elif sys.version_info.major > 3: # noqa: UP036 - return sys.stdlib_module_names # type: ignore[attr-defined] - - raise RuntimeError(f"Unsupported Python version: {sys.version_info}") - - -stdlib3_9 = { - "_thread", - "abc", - "aifc", - "argparse", - "array", - "ast", - "asynchat", - "asyncio", - "asyncore", - "atexit", - "audioop", - "base64", - "bdb", - "binascii", - "binhex", - "bisect", - "builtins", - "bz2", - "cProfile", - "calendar", - "cgi", - "cgitb", - "chunk", - "cmath", - "cmd", - "code", - "codecs", - "codeop", - "collections", - "colorsys", - "compileall", - "concurrent", - "configparser", - "contextlib", - "contextvars", - "copy", - "copyreg", - "crypt", - "csv", - "ctypes", - "curses", - "dataclasses", - "datetime", - "dbm", - "decimal", - "difflib", - "dis", - "distutils", - "doctest", - "email", - "encodings", - "ensurepip", - "enum", - "errno", - "faulthandler", - "fcntl", - "filecmp", - "fileinput", - "fnmatch", - "formatter", - "fractions", - "ftplib", - "functools", - "gc", - "getopt", - "getpass", - "gettext", - "glob", - "graphlib", - "grp", - "gzip", - "hashlib", - "heapq", - "hmac", - "html", - "http", - "imaplib", - "imghdr", - "imp", - "importlib", - "inspect", - "io", - "ipaddress", - "itertools", - "json", - "keyword", - "lib2to3", - "linecache", - "locale", - "logging", - "lzma", - "mailbox", - "mailcap", - "marshal", - "math", - "mimetypes", - "mmap", - "modulefinder", - "msilib", - "msvcrt", - "multiprocessing", - "netrc", - "nis", - "nntplib", - "ntpath", - "numbers", - "operator", - "optparse", - "os", - "ossaudiodev", - "parser", - "pathlib", - "pdb", - "pickle", - "pickletools", - "pipes", - "pkgutil", - "platform", - "plistlib", - "poplib", - "posix", - "posixpath", - "pprint", - "profile", - "pstats", - "pty", - "pwd", - "py_compile", - "pyclbr", - "pydoc", - "queue", - "quopri", - "random", - "re", - "readline", - "reprlib", - "resource", - "rlcompleter", - "runpy", - "sched", - "secrets", - "select", - "selectors", - "shelve", - "shlex", - "shutil", - "signal", - "site", - "smtpd", - "smtplib", - "sndhdr", - "socket", - "socketserver", - "spwd", - "sqlite3", - "sre", - "sre_compile", - "sre_constants", - "sre_parse", - "ssl", - "stat", - "statistics", - "string", - "stringprep", - "struct", - "subprocess", - "sunau", - "symbol", - "symtable", - "sys", - "sysconfig", - "syslog", - "tabnanny", - "tarfile", - "telnetlib", - "tempfile", - "termios", - "test", - "textwrap", - "threading", - "time", - "timeit", - "tkinter", - "token", - "tokenize", - "trace", - "traceback", - "tracemalloc", - "tty", - "turtle", - "turtledemo", - "types", - "typing", - "unicodedata", - "unittest", - "urllib", - "uu", - "uuid", - "venv", - "warnings", - "wave", - "weakref", - "webbrowser", - "winreg", - "winsound", - "wsgiref", - "xdrlib", - "xml", - "xmlrpc", - "zipapp", - "zipfile", - "zipimport", - "zlib", - "zoneinfo", -} + assert sys.version_info >= (3, 10) + return sys.stdlib_module_names diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 3f21ce81171d..dfa83f7467cd 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -711,7 +711,7 @@ def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]: events: list[tuple[int, Action, TensorAndID]] = [ (-1, Action.PREEXISTING, (key, version)) - for key, version in snapshot.keys() + for key, version in snapshot if (key, True) not in allocation_times and version == 0 ] @@ -938,7 +938,7 @@ def _set_parameters_using_data_flow(self) -> None: parameter_keys = {key.id for key, _ in candidate_parameters} parameter_keys &= self._any_version_depends_on_gradient() - for key, _ in snapshot.keys(): + for key, _ in snapshot: if key.id in parameter_keys: self._categories.set_by_id(key, Category.PARAMETER) diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c..2c575b06509e 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -103,7 +103,7 @@ def __init__(self, prof: profile) -> None: self.metrics: dict[EventKey, EventMetrics] = {} self.compute_self_time() self.event_keys = sorted( - (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns + self.metrics.keys(), key=lambda x: x.event.start_time_ns ) self.events = [e.event for e in self.event_keys] self.cuda_events: list[_KinetoEvent] = [] @@ -265,7 +265,7 @@ def compute_idle_time(self) -> None: idle_intervals.append(Interval(idle_start, data_point.start)) idle = False - event_list = [e.event for e in self.metrics.keys()] + event_list = [e.event for e in self.metrics] for event in event_list: self.metrics[EventKey(event)].idle_time_ns = EventKey( event @@ -316,7 +316,7 @@ def rank_events(self, length): # Filter out events that are not in the decrease interval event_list = [ event - for event in self.metrics.keys() + for event in self.metrics if event.intervals_overlap(decrease_interval) ] if event_list: @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + +def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier == ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index ee0ea85e1694..893b4078cb9c 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -210,7 +210,8 @@ def prepare_trace(self) -> None: def start_trace(self) -> None: if self.execution_trace_observer: self.execution_trace_observer.start() - assert self.profiler is not None + if self.profiler is None: + raise AssertionError("Profiler must be initialized before starting trace") self.profiler._start_trace() if self.profile_memory: @@ -256,7 +257,8 @@ def start_trace(self) -> None: def stop_trace(self) -> None: if self.execution_trace_observer: self.execution_trace_observer.stop() - assert self.profiler is not None + if self.profiler is None: + raise AssertionError("Profiler must be initialized before stopping trace") self.profiler.__exit__(None, None, None) def export_chrome_trace(self, path: str): @@ -264,7 +266,10 @@ def export_chrome_trace(self, path: str): Exports the collected trace in Chrome JSON format. If kineto is enabled, only last cycle in schedule is exported. """ - assert self.profiler + if self.profiler is None: + raise AssertionError( + "Profiler must be initialized before exporting chrome trace" + ) if path.endswith(".gz"): fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) fp.close() @@ -284,7 +289,8 @@ def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): path (str): save stacks file to this location; metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total" """ - assert self.profiler + if self.profiler is None: + raise AssertionError("Profiler must be initialized before exporting stacks") return self.profiler.export_stacks(path, metric) def toggle_collection_dynamic( @@ -316,7 +322,7 @@ def toggle_collection_dynamic( print(p.key_averages().table( sort_by="self_cuda_time_total", row_limit=-1)) """ - if not self.profiler: + if self.profiler is None: return self.profiler.toggle_collection_dynamic(enable, activities) @@ -333,7 +339,10 @@ def key_averages( To use shape/stack functionality make sure to set record_shapes/with_stack when creating profiler context manager. """ - assert self.profiler + if self.profiler is None: + raise AssertionError( + "Profiler must be initialized before getting key averages" + ) return self.profiler.key_averages( group_by_input_shape, group_by_stack_n, group_by_overload_name ) @@ -343,7 +352,8 @@ def events(self): Returns the list of unaggregated profiler events, to be used in the trace callback or after the profiling is finished """ - assert self.profiler + if self.profiler is None: + raise AssertionError("Profiler must be initialized before accessing events") return self.profiler.function_events def add_metadata(self, key: str, value: str) -> None: @@ -395,7 +405,10 @@ def _memory_profile(self) -> MemoryProfile: if missing: raise ValueError(f"{', '.join(missing)} required for memory profiling.") - assert self.profiler is not None and self.profiler.kineto_results is not None + if self.profiler is None or self.profiler.kineto_results is None: + raise AssertionError( + "Profiler and kineto_results must be initialized for memory profiling" + ) return MemoryProfile(self.profiler.kineto_results) def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: @@ -485,7 +498,8 @@ def schedule( """ def schedule_fn(step: int) -> ProfilerAction: - assert step >= 0 + if step < 0: + raise AssertionError(f"Step must be non-negative. Got {step}.") if step < skip_first: return ProfilerAction.NONE else: @@ -508,9 +522,11 @@ def schedule_fn(step: int) -> ProfilerAction: else ProfilerAction.RECORD_AND_SAVE ) - assert ( - wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0 - ), "Invalid profiler schedule arguments" + if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0: + raise AssertionError( + f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), " + f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)." + ) if warmup == 0: warn( "Profiler won't be using warmup, this can skew profiler results", @@ -717,7 +733,8 @@ def __init__( activities_set.add(ProfilerActivity.CUDA) elif ProfilerActivity.CUDA in activities_set: activities_set.remove(ProfilerActivity.CUDA) - assert len(activities_set) > 0, "No valid profiler activities found" + if len(activities_set) == 0: + raise AssertionError("No valid profiler activities found") super().__init__( activities=activities, diff --git a/torch/serialization.py b/torch/serialization.py index ce5a74d92384..ffa77cec732e 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1250,7 +1250,7 @@ def persistent_id(self, obj): zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder)) # Write each tensor to a file named tensor/the_tensor_key in the zip archive - for key in serialized_storages.keys(): + for key in serialized_storages: name = f"data/{key}" storage = serialized_storages[key] num_bytes = storage.nbytes() @@ -1494,7 +1494,7 @@ def _get_wo_message(message: str) -> str: _check_dill_version(pickle_module) - if "encoding" not in pickle_load_args.keys(): + if "encoding" not in pickle_load_args: pickle_load_args["encoding"] = "utf-8" with _open_file_like(f, "rb") as opened_file: diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 18384b311b93..91f09adf9e81 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -238,6 +238,47 @@ def wrapper(*args, **kwargs): return decorator +def requires_world_size(n: int): + """ + Decorator to request a specific world size for a test. The test harness can + read this attribute to set the number of ranks to spawn. If there are fewer + than `n` CUDA devices available, the test should be skipped by the harness. + + Usage: + @require_world_size(3) + def test_something(self): + ... + """ + + def decorator(func): + func._required_world_size = n + available = torch.cuda.device_count() + return unittest.skipUnless( + available >= n, f"requires {n} GPUs, found {available}" + )(func) + + return decorator + + +def get_required_world_size(obj: Any, default: int) -> int: + """ + Returns the requested world size for the currently running unittest method on `obj` + if annotated via `@require_world_size(n)`, else returns `default`. + """ + try: + # Try MultiProcessTestCase helper first, then unittest fallback + test_name = ( + obj._current_test_name() # type: ignore[attr-defined] + if hasattr(obj, "_current_test_name") and callable(obj._current_test_name) + else obj._testMethodName + ) + fn = getattr(obj, test_name) + value = fn._required_world_size + return int(value) + except Exception: + return default + + # This decorator helps avoiding initializing cuda while testing other backends def nccl_skip_if_lt_x_gpu(backend, x): def decorator(func): @@ -367,6 +408,13 @@ def requires_nccl_version(version, msg): ) +def requires_nccl_shrink(): + """ + Require NCCL shrink support (NCCL available and version >= 2.27). + """ + return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group") + + def requires_nccl(): return skip_but_pass_in_sandcastle_if( not c10d.is_nccl_available(), diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 92f212a3c650..0413c9bf6b6e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14311,7 +14311,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('max', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14320,7 +14320,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True), OpInfo('max', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), supports_out=True, supports_forward_ad=True, @@ -14465,7 +14465,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False,), OpInfo('min', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14474,7 +14474,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('min', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14784,7 +14784,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_fwgrad_bwgrad=True), OpInfo('aminmax', ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), - dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), decorators=(onlyNativeDeviceTypes,), supports_autograd=False, @@ -21126,7 +21126,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), ref=reference_reduction_numpy(np.amax), skips=( # FIXME: reduces all dimensions when dim=[] @@ -21141,7 +21141,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), ref=reference_reduction_numpy(np.amin), skips=( # FIXME: reduces all dimensions when dim=[] diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 0c26738c2f52..00572f969138 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -333,7 +333,7 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) -NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name()) +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', 'mtia', torch._C._get_privateuse1_backend_name()) # used for managing devices testing for torch profiler UTs # for now cpu, cuda and xpu are added for testing torch profiler UTs diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 17140f40684d..f4afca4bd180 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -814,3 +814,7 @@ def map_local_tensor_for_rank(tensor, rank, func): @maybe_run_for_local_tensor def map_local_for_rank(rank, func): return func(rank) + + +def reduce_local_int(val, func): + return func(val.node._local_ints) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index a14f670d788b..8cb9c929d854 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -7050,8 +7050,8 @@ def _validate_execution_trace_nccl(self, et_file: str) -> None: self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0) self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0) - self.assertTrue("in_split_size" in attrs.keys()) - self.assertTrue("out_split_size" in attrs.keys()) + self.assertTrue("in_split_size" in attrs) + self.assertTrue("out_split_size" in attrs) self.assertEqual(attrs.get("global_rank_start", -1), 0) self.assertEqual(attrs.get("global_rank_stride", -1), 1) @@ -9306,7 +9306,7 @@ def get_loss(model_output): "tuple": tuple, "dict": dict, } - for output_type in type_mapping.keys(): + for output_type in type_mapping: for _ in range(6): out = model(inp, output_type=output_type) loss = get_loss(out) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index b7c0dd17a116..21464e514742 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3282,7 +3282,7 @@ def test_debug_info(self): expected.update(autograd_info) # NB: Key ordering is only preserved in python 3.6+. So here, we # manually check keys are equal. - for key in expected.keys(): + for key in expected: self.assertIn(key, info.keys()) for key in info.keys(): diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index fc6cfa8cf7f4..3b38661c69b8 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -103,6 +103,7 @@ def f2(x, y0, y1): "dynamo_bypassing_wrapper", # TODO(soulitzer) "foreach_map", "aoti_call_delegate", + "print", ] torch.library.define( @@ -153,6 +154,7 @@ def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs def fn_for_invoke_subgraph(x): return torch.sin(x) + def simple_invoke_subgraph(x): return fn_for_invoke_subgraph(x) @@ -202,6 +204,7 @@ def body_fn(iter_t, x): return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x)) + def simple_while_loop_stack_output(iter_t, x): def cond_fn(iter_t, x): return iter_t > 0 @@ -209,7 +212,9 @@ def cond_fn(iter_t, x): def body_fn(iter_t, x): return iter_t - 1, x.cos() - return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple()) + return torch._higher_order_ops.while_loop_stack_output( + cond_fn, body_fn, (iter_t, x), tuple() + ) def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): @@ -226,18 +231,21 @@ def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): def simple_local_map_hop(inp1, inp2): def body_gm(inp1, inp2): return inp1.cos() + inp2.sin() + gm = torch.fx.symbolic_trace(body_gm) assert torch.distributed.is_available() from torch.distributed.tensor.placement_types import Replicate + gm.meta["local_map_kwargs"] = { "in_placements": (Replicate(), Replicate(), Replicate()), - "out_placements": ((Replicate(), Replicate(), Replicate()),) + "out_placements": ((Replicate(), Replicate(), Replicate()),), } # TODO: Dynamo would rewrite this op differently return torch._higher_order_ops.local_map_hop(gm, inp1, inp2) + def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( make_tensor, device=device, dtype=dtype, requires_grad=requires_grad @@ -249,7 +257,6 @@ def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): def simple_scan(init, xs): - def combine_fn(carry, x): result = carry @ x + x return result, carry.clone() @@ -264,15 +271,14 @@ def simple_invoke_quant(x): def fn(x, y): return (torch.sin(x) * y,) - return quant_tracer(fn, x, x)[0] * 2. + return quant_tracer(fn, x, x)[0] * 2.0 def simple_invoke_quant_packed(x): def fn(x): return (torch.sin(x),) - return invoke_quant_packed(fn, x)[0] * 2. - + return invoke_quant_packed(fn, x)[0] * 2.0 hop_db = [ @@ -496,6 +502,11 @@ def fn(x): DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), ), - decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")], + decorators=[ + onlyCUDA, + unittest.skipIf( + not torch.distributed.is_available(), "requires distributed build" + ), + ], ), ] diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index bd11e01a8025..6bd34c812d64 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -33,6 +33,10 @@ OrderedSet, ) from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._helion import has_helion +from torch.utils._pallas import has_pallas +from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -43,9 +47,6 @@ LazyVal, TestCase, ) -from torch.utils._config_module import ConfigModule -from torch.utils._helion import has_helion -from torch.utils._triton import has_triton log: logging.Logger = logging.getLogger(__name__) @@ -67,6 +68,8 @@ def test_cpu(): HAS_TRITON = has_triton() +HAS_PALLAS = has_pallas() + HAS_HELION = has_helion() if HAS_TRITON: diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 12ba497efd79..f302a10b8338 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -692,7 +692,7 @@ def __enter__(self) -> None: raise AssertionError( "prior should be empty when entering ConfigPatch" ) - for key in self.changes.keys(): + for key in self.changes: # KeyError on invalid entry prior[key] = config.__getattr__(key) for k, v in self.changes.items(): diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 603625ed97c1..897279bd39b1 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, Self, TypeAlias, TypeIs +from typing import Any, Optional, overload, TypeAlias, TypeVar, Union +from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 09435aa07e68..5a6ee246abf7 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,7 +2,9 @@ import contextlib import functools import traceback -from typing import Any, Callable, Optional, TYPE_CHECKING +import weakref +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -14,6 +16,7 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -56,29 +59,48 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -def _tensor_debug_string(tensor, attributes) -> str: +class TensorIdTracker: + def __init__(self): + self.tensor_memo: dict[WeakIdRef, int] = {} + self.next_tensor_id = 0 + + def _id(self, tensor) -> int: + with torch._C._DisablePythonDispatcher(): + o = WeakIdRef(tensor) + + def del_memo(): + self.tensor_memo.pop(o, None) + + weakref.finalize(tensor, del_memo) + if o not in self.tensor_memo: + self.tensor_memo[o] = self.next_tensor_id + self.next_tensor_id += 1 + return self.tensor_memo[o] + + +def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - + id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" if isinstance(tensor, torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft: {tensor_debug_str}" + return f"ft{id_str}: {tensor_debug_str}" else: - return f"t: {tensor_debug_str}" + return f"t{id_str}: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes) -> str: +def _arg_to_str(arg, attributes, tensor_memo=None) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes) + return _tensor_debug_string(x, attributes, tensor_memo) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -144,8 +166,11 @@ def __init__( # results from dispatch hooks self.record = record self.log = log + self.output_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -153,6 +178,18 @@ def stringify_args(self, attributes: list[str]) -> None: "Subclasses must implement stringify_args(), even if no-op" ) + def stringify_output( + self, + output: Any, + attributes: list[str], + tensor_memo: Optional[TensorIdTracker] = None, + ) -> None: + """Store stringified version of call output in self.output_str""" + if tree_all(lambda x: x is None, output): + return + output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) + self.output_str = f" -> {str(output_str)}" + def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -179,11 +216,16 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.args_str = ", ".join( + _arg_to_str(arg, attributes, tensor_memo) for arg in self.args + ) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes, tensor_memo)}" + for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -215,6 +257,8 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" + if self.output_str: + base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -247,8 +291,10 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes)}" + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -263,7 +309,11 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) placement_str = f"{src_placement_str} -> {dst_placement_str}" - return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + + base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + if self.output_str: + base_str += self.output_str + return base_str def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -288,7 +338,9 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -341,6 +393,8 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, + record_output=False, + record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -378,8 +432,24 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace + # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) + self.record_output: bool = record_output + + # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. + self.record_ids: bool = record_ids + + self.reset() + + def reset(self): self.operators = [] self.call_depth = 0 + self._tensor_memo = TensorIdTracker() + self._output_info: dict[int, object] = {} + + def _track_op_output(self, op_index, result): + """Assign IDs to output tensors and store in output_info""" + # self._track_tensor_ids(result) + self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the β€œeager” backend @@ -390,20 +460,35 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args(self.record_tensor_attributes) + call.stringify_args( + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) self.operators.append(call) + def _record_call_output(self, call, output): + if not self.record_output: + return + call.stringify_output( + output, + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - self._record_call( - _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) + call = _OpCall( + func, args, kwargs, self.call_depth, stack=self.record_stack_trace ) + self._record_call(call) try: self.call_depth += 1 - return func(*args, **kwargs) + result = func(*args, **kwargs) + self._record_call_output(call, result) + return result finally: self.call_depth -= 1 @@ -445,13 +530,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: + self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.operators = [] - self.call_depth = 0 + self.reset() if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py new file mode 100644 index 000000000000..25cc635dbb17 --- /dev/null +++ b/torch/utils/_pallas.py @@ -0,0 +1,82 @@ +import functools + +import torch + + +@functools.cache +def has_jax_package() -> bool: + """Check if JAX is installed.""" + try: + import jax # noqa: F401 # type: ignore[import-not-found] + + return True + except ImportError: + return False + + +@functools.cache +def has_pallas_package() -> bool: + """Check if Pallas (JAX experimental) is available.""" + if not has_jax_package(): + return False + try: + from jax.experimental import ( # noqa: F401 # type: ignore[import-not-found] + pallas as pl, + ) + + return True + except ImportError: + return False + + +@functools.cache +def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, int, int]: + """Get JAX version as (major, minor, patch) tuple.""" + try: + import jax # type: ignore[import-not-found] + + version_parts = jax.__version__.split(".") + major, minor, patch = (int(v) for v in version_parts[:3]) + return (major, minor, patch) + except (ImportError, ValueError, AttributeError): + return fallback + + +@functools.cache +def has_jax_cuda_backend() -> bool: + """Check if JAX has CUDA backend support.""" + if not has_jax_package(): + return False + try: + import jax # type: ignore[import-not-found] + + # Check if CUDA backend is available + devices = jax.devices("gpu") + return len(devices) > 0 + except Exception: + return False + + +@functools.cache +def has_pallas() -> bool: + """ + Check if Pallas backend is fully available for use. + + Requirements: + - JAX package installed + - Pallas (jax.experimental.pallas) available + - CUDA backend available (for GPU support) + """ + if not has_pallas_package(): + return False + + # Only enable Pallas if CUDA is available + # (Pallas primarily targets GPU workloads) + if not torch.cuda.is_available(): + return False + + # Check if JAX has GPU/CUDA backend + if not has_jax_cuda_backend(): + return False + + return True diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 56704bb3f802..147340f58d66 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,10 +36,11 @@ Optional, overload, Protocol, + TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeAlias +from typing_extensions import deprecated, NamedTuple, Self from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 526443577b3f..915d0e5461f1 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,6 +306,24 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary expressions + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: e1 if c1 else (e2 if c2 else (... else eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self._print(expr_i) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self._print(cond_i) + if result is None: + result = expr_str + else: + result = f"({expr_str} if {cond_str} else {result})" + return result if result else "0" + class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -327,6 +345,24 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary operators + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) + if result is None: + result = expr_str + else: + result = f"{cond_str} ? {expr_str} : {result}" + return f"({result})" if result else "0" + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 3b8b62cfde6d..a643314f3b9c 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -803,14 +803,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): def replace_nones(dct, replacement="Could not collect"): - for key in dct.keys(): + for key in dct: if dct[key] is not None: continue dct[key] = replacement return dct def replace_bools(dct, true="Yes", false="No"): - for key in dct.keys(): + for key in dct: if dct[key] is True: dct[key] = true elif dct[key] is False: diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 1ce1c9c07196..2e3bb18c80bb 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -149,7 +149,7 @@ def _collate_helper(conversion, item): tuple_names: list = [] tuple_values: list = [] - for name in conversion.keys(): + for name in conversion: if name not in columns_name: raise RuntimeError("Conversion keys mismatch") diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 865feb9953e3..a289bdb5e094 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -234,7 +234,7 @@ def _remove_biggest_key(self): biggest_key = None biggest_size = 0 result_to_yield = None - for findkey in self.buffer_elements.keys(): + for findkey in self.buffer_elements: if len(self.buffer_elements[findkey]) > biggest_size: biggest_size = len(self.buffer_elements[findkey]) biggest_key = findkey diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index f36382cb42e1..1b6a2bb9bb66 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -334,7 +334,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): # pyrefly: ignore [missing-attribute] ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) - mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()] + mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict] exp = Experiment(hparam_infos=hps, metric_infos=mts) diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 4fab33dc7ff0..0f533ae5b0f5 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -424,7 +424,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag if self.all_writers is None: raise AssertionError("self.all_writers is None") - if fw_tag in self.all_writers.keys(): + if fw_tag in self.all_writers: fw = self.all_writers[fw_tag] else: fw = FileWriter( diff --git a/torch/utils/viz/MemoryViz.js b/torch/utils/viz/MemoryViz.js index 09f8c444f600..dfeae36cebab 100644 --- a/torch/utils/viz/MemoryViz.js +++ b/torch/utils/viz/MemoryViz.js @@ -806,7 +806,29 @@ function format_frames(frames) { } const frame_strings = frames .filter(frameFilter) - .map(f => `${f.filename}:${f.line}:${f.name}`); + .map(f => { + let frame_str = `${f.filename}:${f.line}:${f.name}`; + + // Add FX debug information if available + if (f.fx_node_op || f.fx_node_name || f.fx_node_target) { + const fx_parts = []; + if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`); + if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`); + if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`); + frame_str += `\n >> FX: ${fx_parts.join(', ')}`; + } + + if (f.fx_original_trace) { + frame_str += `\n >> Original Model Code:`; + const original_lines = f.fx_original_trace.trim().split('\n'); + // Show all lines of the original trace + for (const line of original_lines) { + frame_str += `\n ${line}`; + } + } + + return frame_str; + }); return elideRepeats(frame_strings).join('\n'); } diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 07097010f8f2..c9f1b660f02c 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -287,8 +287,7 @@ def error_on_missing_kernels( expected_backend_native_funcs: list[NativeFunction] = [ f for f in native_functions - if f.func.name in expected_backend_op_names.keys() - and f.func.name not in full_codegen + if f.func.name in expected_backend_op_names and f.func.name not in full_codegen ] expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict( list