diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index b7e61115e37d6..748608005e622 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -125,10 +125,10 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks) + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks) CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 + GCC_VERSION=11 VISION=yes KATEX=yes UCX_COMMIT=${_UCX_COMMIT} @@ -146,16 +146,6 @@ case "$tag" in UCC_COMMIT=${_UCC_COMMIT} TRITON=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9) - CUDA_VERSION=12.8.1 - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=9 - VISION=yes - KATEX=yes - UCX_COMMIT=${_UCX_COMMIT} - UCC_COMMIT=${_UCC_COMMIT} - TRITON=yes - ;; pytorch-linux-jammy-py3-clang12-onnx) ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=12 diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index f3636071714f8..242cbaafa059e 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -402,3 +402,6 @@ scikit-build==0.18.1 pyre-extensions==0.0.32 tabulate==0.9.0 #Description: These package are needed to build FBGEMM and torchrec on PyTorch CI + +Jinja2==3.1.6 +#Description: required for torch.distributed.debug diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index b0c607659c72d..7ad10ca946215 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -100,6 +100,347 @@ def check_lib_statically_linked_libstdc_cxx_abi_symbols(lib: str) -> None: ) +def _compile_and_extract_symbols( + cpp_content: str, compile_flags: list[str], exclude_list: list[str] | None = None +) -> list[str]: + """ + Helper to compile a C++ file and extract all symbols. + + Args: + cpp_content: C++ source code to compile + compile_flags: Compilation flags + exclude_list: List of symbol names to exclude. 
Defaults to ["main"]. + + Returns: + List of all symbols found in the object file (excluding those in exclude_list). + """ + import subprocess + import tempfile + + if exclude_list is None: + exclude_list = ["main"] + + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + cpp_file = tmppath / "test.cpp" + obj_file = tmppath / "test.o" + + cpp_file.write_text(cpp_content) + + result = subprocess.run( + compile_flags + [str(cpp_file), "-o", str(obj_file)], + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode != 0: + raise RuntimeError(f"Compilation failed: {result.stderr}") + + symbols = get_symbols(str(obj_file)) + + # Return all symbol names, excluding those in the exclude list + return [name for _addr, _stype, name in symbols if name not in exclude_list] + + +def check_stable_only_symbols(install_root: Path) -> None: + """ + Test TORCH_STABLE_ONLY and TORCH_TARGET_VERSION by compiling test code. + + This approach tests: + 1. WITHOUT macros -> many torch symbols exposed (compilation succeeds) + 2. WITH TORCH_STABLE_ONLY -> compilation fails with #error directive + 3. WITH TORCH_TARGET_VERSION -> compilation fails with #error directive + 4. 
WITH both macros -> compilation fails with #error directive + """ + import subprocess + import tempfile + + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + test_cpp_content = """ +// Main torch C++ API headers +#include +#include + +// ATen tensor library +#include + +// Core c10 headers (commonly used) +#include +#include +#include +#include +#include + +int main() { return 0; } +""" + + base_compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", # Compile only, don't link + ] + + # Compile WITHOUT any macros - should succeed + symbols_without = _compile_and_extract_symbols( + cpp_content=test_cpp_content, + compile_flags=base_compile_flags, + ) + + # We expect constexpr symbols, inline functions used by other headers etc. + # to produce symbols + num_symbols_without = len(symbols_without) + print(f"Found {num_symbols_without} symbols without any macros defined") + assert num_symbols_without != 0, ( + "Expected a non-zero number of symbols without any macros" + ) + + # Helper to verify compilation fails with expected error + def _expect_compilation_failure(compile_flags: list[str], macro_name: str) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + cpp_file = tmppath / "test.cpp" + obj_file = tmppath / "test.o" + + cpp_file.write_text(test_cpp_content) + + result = subprocess.run( + compile_flags + [str(cpp_file), "-o", str(obj_file)], + capture_output=True, + text=True, + timeout=60, + ) + + if result.returncode == 0: + raise RuntimeError( + f"Expected compilation to fail with {macro_name} defined, but it succeeded" + ) + + stderr = result.stderr + expected_error_msg = ( + "This file should not be included when either TORCH_STABLE_ONLY " + "or TORCH_TARGET_VERSION is defined." 
+ ) + + if expected_error_msg not in stderr: + raise RuntimeError( + f"Expected error message to contain:\n '{expected_error_msg}'\n" + f"but got:\n{stderr[:1000]}" + ) + + print(f"Compilation correctly failed with {macro_name} defined") + + compile_flags_with_stable_only = base_compile_flags + ["-DTORCH_STABLE_ONLY"] + _expect_compilation_failure(compile_flags_with_stable_only, "TORCH_STABLE_ONLY") + + compile_flags_with_target_version = base_compile_flags + [ + "-DTORCH_TARGET_VERSION=1" + ] + _expect_compilation_failure( + compile_flags_with_target_version, "TORCH_TARGET_VERSION" + ) + + compile_flags_with_both = base_compile_flags + [ + "-DTORCH_STABLE_ONLY", + "-DTORCH_TARGET_VERSION=1", + ] + _expect_compilation_failure(compile_flags_with_both, "both macros") + + +def check_stable_api_symbols(install_root: Path) -> None: + """ + Test that stable API headers still expose symbols with TORCH_STABLE_ONLY. + The torch/csrc/stable/c/shim.h header is tested in check_stable_c_shim_symbols + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + stable_dir = include_dir / "torch" / "csrc" / "stable" + assert stable_dir.exists(), f"Expected {stable_dir} to be present" + + stable_headers = list(stable_dir.rglob("*.h")) + if not stable_headers: + raise RuntimeError("Could not find any stable headers") + + includes = [] + for header in stable_headers: + rel_path = header.relative_to(include_dir) + includes.append(f"#include <{rel_path.as_posix()}>") + + includes_str = "\n".join(includes) + test_stable_content = f""" +{includes_str} +int main() {{ return 0; }} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_stable = _compile_and_extract_symbols( + cpp_content=test_stable_content, + compile_flags=compile_flags, + ) + num_symbols_stable = len(symbols_stable) + print(f"Found 
{num_symbols_stable} symbols in torch/csrc/stable") + assert num_symbols_stable > 0, ( + f"Expected stable headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_stable} symbols" + ) + + +def check_headeronly_symbols(install_root: Path) -> None: + """ + Test that header-only utility headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # Find all headers in torch/headeronly + headeronly_dir = include_dir / "torch" / "headeronly" + assert headeronly_dir.exists(), f"Expected {headeronly_dir} to be present" + headeronly_headers = list(headeronly_dir.rglob("*.h")) + if not headeronly_headers: + raise RuntimeError("Could not find any headeronly headers") + + # Filter out platform-specific headers that may not compile everywhere + platform_specific_keywords = [ + "cpu/vec", + ] + + filtered_headers = [] + for header in headeronly_headers: + rel_path = header.relative_to(include_dir).as_posix() + if not any( + keyword in rel_path.lower() for keyword in platform_specific_keywords + ): + filtered_headers.append(header) + + includes = [] + for header in filtered_headers: + rel_path = header.relative_to(include_dir) + includes.append(f"#include <{rel_path.as_posix()}>") + + includes_str = "\n".join(includes) + test_headeronly_content = f""" +{includes_str} +int main() {{ return 0; }} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_headeronly = _compile_and_extract_symbols( + cpp_content=test_headeronly_content, + compile_flags=compile_flags, + ) + num_symbols_headeronly = len(symbols_headeronly) + print(f"Found {num_symbols_headeronly} symbols in torch/headeronly") + assert num_symbols_headeronly > 0, ( + f"Expected headeronly headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_headeronly} 
symbols" + ) + + +def check_aoti_shim_symbols(install_root: Path) -> None: + """ + Test that AOTI shim headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # There are no constexpr symbols etc., so we need to actually use functions + # so that some symbols are found. + test_shim_content = """ +#include +int main() { + int32_t (*fp1)() = &aoti_torch_device_type_cpu; + int32_t (*fp2)() = &aoti_torch_dtype_float32; + (void)fp1; (void)fp2; + return 0; +} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_shim = _compile_and_extract_symbols( + cpp_content=test_shim_content, + compile_flags=compile_flags, + ) + num_symbols_shim = len(symbols_shim) + assert num_symbols_shim > 0, ( + f"Expected shim headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_shim} symbols" + ) + + +def check_stable_c_shim_symbols(install_root: Path) -> None: + """ + Test that stable C shim headers still expose symbols with TORCH_STABLE_ONLY. + """ + include_dir = install_root / "include" + assert include_dir.exists(), f"Expected {include_dir} to be present" + + # Check if the stable C shim exists + stable_shim = include_dir / "torch" / "csrc" / "stable" / "c" / "shim.h" + if not stable_shim.exists(): + raise RuntimeError("Could not find stable c shim") + + # There are no constexpr symbols etc., so we need to actually use functions + # so that some symbols are found. 
+ test_stable_shim_content = """ +#include +int main() { + // Reference stable C API functions to create undefined symbols + AOTITorchError (*fp1)(const char*, uint32_t*, int32_t*) = &torch_parse_device_string; + AOTITorchError (*fp2)(uint32_t*) = &torch_get_num_threads; + (void)fp1; (void)fp2; + return 0; +} +""" + + compile_flags = [ + "g++", + "-std=c++17", + f"-I{include_dir}", + f"-I{include_dir}/torch/csrc/api/include", + "-c", + "-DTORCH_STABLE_ONLY", + ] + + symbols_stable_shim = _compile_and_extract_symbols( + cpp_content=test_stable_shim_content, + compile_flags=compile_flags, + ) + num_symbols_stable_shim = len(symbols_stable_shim) + assert num_symbols_stable_shim > 0, ( + f"Expected stable C shim headers to expose symbols with TORCH_STABLE_ONLY, " + f"but found {num_symbols_stable_shim} symbols" + ) + + def check_lib_symbols_for_abi_correctness(lib: str) -> None: print(f"lib: {lib}") cxx11_symbols = grep_symbols(lib, LIBTORCH_CXX11_PATTERNS) @@ -129,6 +470,13 @@ def main() -> None: check_lib_symbols_for_abi_correctness(libtorch_cpu_path) check_lib_statically_linked_libstdc_cxx_abi_symbols(libtorch_cpu_path) + # Check symbols when TORCH_STABLE_ONLY is defined + check_stable_only_symbols(install_root) + check_stable_api_symbols(install_root) + check_headeronly_symbols(install_root) + check_aoti_shim_symbols(install_root) + check_stable_c_shim_symbols(install_root) + if __name__ == "__main__": main() diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 64ee992f566b7..c3b209c216014 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -2d82dc5caa336d179d9b46ac4a0fb8c43d84c5cc +617079d944b0e72632311c30ae2bbdf1168b901e diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index f7df4335cb5b6..d69db191b9464 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ 
-50,6 +50,7 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { "12.6": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | " @@ -67,6 +68,7 @@ "nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux'" ), "12.8": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | " @@ -84,6 +86,7 @@ "nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux'" ), "12.9": ( + "cuda-bindings==12.9.4; platform_system == 'Linux' | " "nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | " "nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | " "nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | " @@ -101,6 +104,7 @@ "nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux'" ), "13.0": ( + "cuda-bindings==13.0.3; platform_system == 'Linux' | " "nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | " "nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | " "nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | " diff --git a/.github/workflows/attention_op_microbenchmark.yml b/.github/workflows/attention_op_microbenchmark.yml index e01bc49621dcf..eec4d21fe2616 100644 --- a/.github/workflows/attention_op_microbenchmark.yml +++ b/.github/workflows/attention_op_microbenchmark.yml @@ -23,7 +23,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '8.0 9.0' test-matrix: | @@ -39,7 +39,7 @@ jobs: needs: attn-microbenchmark-build 
with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.attn-microbenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }} secrets: inherit @@ -51,7 +51,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' test-matrix: | @@ -66,7 +66,7 @@ jobs: needs: opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 408a8f0000504..fa1f083800fe0 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -52,8 +52,7 @@ jobs: pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm, - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks, - pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, + pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.10-clang12, pytorch-linux-jammy-py3.11-clang12, @@ -75,7 +74,8 @@ jobs: pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-linter, pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, - pytorch-linux-jammy-py3-clang12-executorch, + # TODO: Re-enable me when docker 
pin update happens + # pytorch-linux-jammy-py3-clang12-executorch, pytorch-linux-jammy-py3.12-triton-cpu, pytorch-linux-noble-riscv64-py3.12-gcc14 ] diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index 78d38de3ac69a..c973656018944 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -6,10 +6,9 @@ on: branches: [main, release] types: - completed - workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }} + group: ${{ github.workflow }}-${{ github.event.workflow_run.head_branch }} cancel-in-progress: true permissions: @@ -50,9 +49,10 @@ jobs: matrix: runner: [linux.rocm.gfx942.docker-cache] docker-image: [ - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}", - "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" + "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}" + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}", + #"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}" ] runs-on: "${{ matrix.runner }}" steps: diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index b8a6403faffbd..6a22e14af09b7 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -132,7 +132,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - 
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -178,7 +178,7 
@@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | 
nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -224,7 +224,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; 
platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -270,7 +270,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; 
platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -381,7 +381,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | 
nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -427,7 +427,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; 
platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -473,7 +473,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | 
nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -519,7 +519,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; 
platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -630,7 +630,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | 
nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -676,7 +676,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | 
nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -722,7 +722,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; 
platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -768,7 +768,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; 
platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -879,7 +879,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | 
nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -925,7 +925,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; 
platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -971,7 +971,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | 
nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1017,7 +1017,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; 
platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1128,7 +1128,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - 
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1174,7 
+1174,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 
'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1220,7 +1220,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | 
nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1266,7 +1266,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_13t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | 
nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1377,7 +1377,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 
'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1423,7 +1423,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | 
nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1469,7 +1469,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; 
platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1515,7 +1515,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | 
nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1626,7 +1626,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: 
cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1672,7 +1672,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_8 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | 
nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1718,7 +1718,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-12_9 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | 
nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1764,7 +1764,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_14t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | 
nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 21c1d5caa3829..a5f4e85ca58c1 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -127,7 +127,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | 
nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_6-test: # Testing @@ -193,7 +193,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" 
build_name: manywheel-py3_10-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' 
secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_8-test: # Testing @@ -259,7 +259,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 
'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_9-test: # Testing @@ -325,7 +325,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | 
nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda13_0-test: # Testing @@ -793,7 +793,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | 
nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_6-test: # Testing @@ -859,7 +859,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' 
| nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_8-test: # Testing @@ -925,7 +925,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | 
nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_9-test: # Testing @@ -991,7 +991,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + 
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda13_0-test: # Testing @@ -1459,7 +1459,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 
'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_6-test: # Testing @@ -1525,7 +1525,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; 
platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_8-test: # Testing @@ -1591,7 +1591,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; 
platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_9-test: # Testing @@ -1657,7 +1657,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; 
platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda13_0-test: # Testing @@ -2125,7 +2125,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | 
nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_6-test: # Testing @@ -2191,7 +2191,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_8 build_environment: 
linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} 
manywheel-py3_13-cuda12_8-test: # Testing @@ -2257,7 +2257,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; 
platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_9-test: # Testing @@ -2323,7 +2323,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' 
| nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda13_0-test: # Testing @@ -2791,7 +2791,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | 
nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_6-test: # Testing @@ -2857,7 +2857,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 
'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_8-test: # Testing @@ -2923,7 +2923,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; 
platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_9-test: # Testing @@ -2989,7 +2989,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: 
cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda13_0-test: # Testing @@ -3457,7 +3457,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | 
nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_6-test: # Testing @@ -3523,7 +3523,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_8 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system 
== 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_8-test: # Testing @@ -3589,7 +3589,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system 
== 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda12_9-test: # Testing @@ -3655,7 +3655,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 
'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14-cuda13_0-test: # Testing @@ -4123,7 +4123,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_6 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | 
nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_6-test: # Testing @@ -4189,7 +4189,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_8 build_environment: 
linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} 
manywheel-py3_14t-cuda12_8-test: # Testing @@ -4255,7 +4255,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda12_9 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==12.9.4; platform_system == 'Linux' | nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' | nvidia-nvtx-cu12==12.9.79; 
platform_system == 'Linux' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda12_9-test: # Testing @@ -4321,7 +4321,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_14t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-bindings==13.0.3; platform_system == 'Linux' | nvidia-cuda-nvrtc==13.0.88; platform_system == 'Linux' | nvidia-cuda-runtime==13.0.96; platform_system == 'Linux' | nvidia-cuda-cupti==13.0.85; platform_system == 'Linux' | nvidia-cudnn-cu13==9.13.0.50; platform_system == 'Linux' | nvidia-cublas==13.1.0.3; platform_system == 'Linux' | nvidia-cufft==12.0.0.61; platform_system == 'Linux' | nvidia-curand==10.4.0.35; platform_system == 'Linux' | nvidia-cusolver==12.0.4.66; platform_system == 'Linux' | nvidia-cusparse==12.6.3.3; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.27.7; platform_system == 
'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' | nvidia-nvtx==13.0.85; platform_system == 'Linux' | nvidia-nvjitlink==13.0.88; platform_system == 'Linux' | nvidia-cufile==1.15.1.6; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_14t-cuda13_0-test: # Testing diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index a0ae234ab5669..3421e2b9af77d 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -30,14 +30,14 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -46,11 +46,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 628f624240127..764e631819ccc 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -27,14 +27,14 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: 
cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -47,11 +47,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} # disable monitor in perf tests for more investigation diff --git a/.github/workflows/inductor-perf-test-b200.yml b/.github/workflows/inductor-perf-test-b200.yml index 7b59e92386a33..11f5f10a55ad8 100644 --- a/.github/workflows/inductor-perf-test-b200.yml +++ b/.github/workflows/inductor-perf-test-b200.yml @@ -80,7 +80,7 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -90,8 +90,8 @@ jobs: # from trunk. 
Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '10.0' test-matrix: | { include: [ @@ -104,12 +104,12 @@ jobs: secrets: inherit test-periodically: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -121,12 +121,12 @@ jobs: secrets: inherit test-weekly: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -138,11 +138,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm100 + name: cuda12.8-py3.10-gcc11-sm100 uses: ./.github/workflows/_linux-test.yml needs: build with: - 
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 8209bf053a772..1c35fc6794537 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -95,8 +95,8 @@ jobs: # from trunk. Also use a memory-intensive runner here because memory is # usually the bottleneck runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '9.0' test-matrix: | { include: [ @@ -132,7 +132,7 @@ jobs: needs: build if: github.event.schedule == '15 0 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -149,7 +149,7 @@ jobs: needs: build if: github.event.schedule == '0 7 * * 0' with: - 
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -168,7 +168,7 @@ jobs: # needs one round of benchmark if: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' }} with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm90 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-cudagraphs_low_precision-${{ inputs.cudagraphs || 'false' }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 19f72ba453414..88a528ba1b075 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -80,15 +80,15 @@ jobs: opt_out_experiments: lf build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" # Every bit to make perf run faster helps runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - 
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -117,12 +117,12 @@ jobs: secrets: inherit test-nightly: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 1-6' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -133,12 +133,12 @@ jobs: secrets: inherit test-weekly: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event.schedule == '0 7 * * 0' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} @@ -150,12 +150,12 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build if: github.event_name == 'workflow_dispatch' with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 dashboard-tag: training-${{ inputs.training 
}}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index b08d9865d15d3..f3e34d6ecb52f 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -37,8 +37,8 @@ jobs: needs: get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0;8.6' test-matrix: | { include: [ @@ -76,7 +76,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: periodic-dynamo-benchmarks-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.periodic-dynamo-benchmarks-build.outputs.docker-image }} test-matrix: ${{ needs.periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit @@ -138,8 +138,8 @@ jobs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: 
ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -153,7 +153,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-smoke-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.inductor-smoke-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-smoke-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index ca9b57cab2ddb..0902026adb8ce 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -33,8 +33,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -52,7 +52,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 8a913c3b36a11..e524ed548b741 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -49,8 +49,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 - docker-image-name: 
ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.6' runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | @@ -69,7 +69,7 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm86 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm86 docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml index 89d6d63c72875..dd5cd832570f9 100644 --- a/.github/workflows/operator_microbenchmark.yml +++ b/.github/workflows/operator_microbenchmark.yml @@ -25,7 +25,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '8.0 9.0' test-matrix: | @@ -41,7 +41,7 @@ jobs: needs: opmicrobenchmark-build with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 docker-image: ${{ needs.opmicrobenchmark-build.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} secrets: inherit @@ -53,7 +53,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml with: runner: linux.12xlarge.memory - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' test-matrix: | @@ -68,7 +68,7 @@ jobs: needs: 
opmicrobenchmark-build-b200 with: timeout-minutes: 500 - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 5a90db9ab5737..325050392a393 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -90,6 +90,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda12.8-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 + cuda-arch-list: 8.6 test-matrix: | { include: [ { config: "nogpu_AVX512", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -97,7 +98,9 @@ jobs: { config: "nogpu_AVX512", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, + { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, + { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] 
}, ]} secrets: inherit @@ -113,40 +116,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-build: - name: linux-jammy-cuda12.8-py3.10-gcc9 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 - cuda-arch-list: 8.6 - test-matrix: | - { include: [ - { config: "multigpu", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, - { config: "multigpu", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu", owners: ["oncall:distributed"] }, - ]} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-test: - name: linux-jammy-cuda12.8-py3.10-gcc9 - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-build - with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-build.outputs.test-matrix }} - secrets: inherit - - linux-jammy-cuda12_8-py3_10-gcc9-debug-build: - name: linux-jammy-cuda12.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc11-debug-build: + name: linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: 8.9 test-matrix: | { 
include: [ @@ -160,16 +137,16 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-debug-test: - name: linux-jammy-cuda12.8-py3.10-gcc9-debug + linux-jammy-cuda12_8-py3_10-gcc11-debug-test: + name: linux-jammy-cuda12.8-py3.10-gcc11-debug uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-cuda12_8-py3_10-gcc9-debug-build + - linux-jammy-cuda12_8-py3_10-gcc11-debug-build - target-determination with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-debug - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-debug-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-debug + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-debug-build.outputs.test-matrix }} secrets: inherit linux-jammy-cuda13_0-py3_10-gcc11-build: diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 51e211a5ad2ad..f2483dff9a94c 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -318,14 +318,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-build: - name: cuda12.8-py3.10-gcc9-sm75 + linux-jammy-cuda12_8-py3_10-gcc11-inductor-build: + name: cuda12.8-py3.10-gcc11-sm75 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '7.5' test-matrix: | { include: [ @@ -333,14 +333,14 @@ jobs: ]} secrets: inherit - linux-jammy-cuda12_8-py3_10-gcc9-inductor-test: - name: 
cuda12.8-py3.10-gcc9-sm75 + linux-jammy-cuda12_8-py3_10-gcc11-inductor-test: + name: cuda12.8-py3.10-gcc11-sm75 uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build + needs: linux-jammy-cuda12_8-py3_10-gcc11-inductor-build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75 - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm75 + docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-inductor-build.outputs.test-matrix }} secrets: inherit linux-noble-xpu-n-py3_10-build: diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 08fcd33402625..5a0273f0b745e 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -26,14 +26,14 @@ jobs: curr_ref_type: ${{ github.ref_type }} build: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-build.yml needs: - get-default-label-prefix with: runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ @@ -42,11 +42,11 @@ jobs: secrets: inherit test: - name: cuda12.8-py3.10-gcc9-sm80 + name: cuda12.8-py3.10-gcc11-sm80 uses: ./.github/workflows/_linux-test.yml needs: build with: - build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm80 + build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm80 
docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 667c37727045b..eeba4c08a0c68 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -231,8 +231,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-cuda12.8-py3.12-gcc9-sm80 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks + build-environment: linux-jammy-cuda12.8-py3.12-gcc11-sm80 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' secrets: inherit @@ -283,6 +283,7 @@ jobs: name: linux-jammy-py3-clang12-executorch uses: ./.github/workflows/_linux-build.yml needs: get-label-type + if: false # Has been broken for a while with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3-clang12-executorch diff --git a/AGENTS.md b/AGENTS.md index 3d5436a02a85d..718217d3e663d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -10,6 +10,7 @@ - Do NOT run pre-commit, it is not setup - To run lint, run 'lintrunner -a' (which will autoapply changes) - Do NOT attempt to install dependencies, you do not have Internet access +- Do NOT create summary files unless explicitly asked - When you are ready to make a PR, do exactly these steps: - git stash -u - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 46dc550b1f37b..35a729ccc9f39 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -680,7 +680,7 @@ TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) { return false; } if (elem_type->kind() == AnyType::Kind) { - // List of Any can contains heterogenous types + // List of Any can contains 
heterogeneous types return false; } return true; diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 81b4643ac0418..7a650b9cbcf35 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -238,11 +238,18 @@ struct TORCH_CUDA_CPP_API CUDAEvent { } void moveHelper(CUDAEvent&& other) { - std::swap(flags_, other.flags_); - std::swap(is_created_, other.is_created_); - std::swap(was_recorded_, other.was_recorded_); - std::swap(device_index_, other.device_index_); - std::swap(event_, other.event_); + // Transfer ownership of all state from other to this + flags_ = other.flags_; + is_created_ = other.is_created_; + was_recorded_ = other.was_recorded_; + external_ = other.external_; + device_index_ = other.device_index_; + event_ = other.event_; + + // Reset other to a valid empty state to prevent double-free + // The moved-from object must not attempt to destroy the event + other.is_created_ = false; + other.event_ = cudaEvent_t{}; } }; diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 221f621ea1e06..b5f3d91692b9a 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -813,8 +813,43 @@ void smooth_l1_kernel(TensorIteratorBase& iter, double beta) { } void huber_kernel(TensorIterator& iter, double delta) { - AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, iter.dtype(), "huber_cpu", [&]() { + // Special-case kHalf: compute in float for numerical stability + if (iter.dtype() == kHalf) { + const float delta_val(static_cast(delta)); + const Vectorized delta_vec(static_cast(delta)); + const Vectorized point_five_vec(static_cast(0.5)); + cpu_kernel_vec( + iter, + // scalar lambda: convert half -> float, compute in float, cast back to half + [&delta_val] (at::Half a, at::Half b) -> at::Half { + float af = static_cast(a); + float bf = static_cast(b); + float z = std::abs(af - bf); + float out = z < 
delta_val + ? 0.5f * z * z + : delta_val * (z - 0.5f * delta_val); + return static_cast(out); + }, + [&delta_vec, &point_five_vec] (Vectorized a, Vectorized b) { + auto [a0, a1] = convert_half_float(a); + auto [b0, b1] = convert_half_float(b); + auto z = (a0 - b0).abs(); + a0 = Vectorized::blendv( + point_five_vec * z * z, + delta_vec * (z - point_five_vec * delta_vec), + z >= delta_vec); + z = (a1 - b1).abs(); + a1 = Vectorized::blendv( + point_five_vec * z * z, + delta_vec * (z - point_five_vec * delta_vec), + z >= delta_vec); + return convert_float_half(a0, a1); + } + ); + return; + } + else { + AT_DISPATCH_FLOATING_TYPES_AND(kBFloat16, iter.dtype(), "huber_cpu", [&]() { using Vec = Vectorized; const scalar_t delta_val(delta); const Vec delta_val_vec(delta_val); @@ -835,6 +870,7 @@ void huber_kernel(TensorIterator& iter, double delta) { z >= delta_val_vec); }); }); + } } void sigmoid_backward_kernel(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu index a917b0d6163fa..3f4f998d92cd6 100644 --- a/aten/src/ATen/native/cuda/GroupMM.cu +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -346,8 +346,9 @@ void dispatch_bf16_grouped_kernel_on_tile_size( bool small = (M <= 128 || N <= 128); cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); const bool sm10x = properties != nullptr && properties->major == 10; + const bool sm11x = properties != nullptr && properties->major == 11; - if (sm10x) { + if (sm10x || sm11x) { if (small){ bf16bf16_grouped_gemm_impl_sm90_sm100< cutlass::arch::Sm100, diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index fd406829707a1..5c8b98105bb26 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -5,11 +5,69 @@ #include #endif +// ROCm 6.3 is planned to have these functions, but until then here they are. 
#if defined(USE_ROCM) #include #include #include -#define ATOMICADD unsafeAtomicAdd + +__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) + typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; + static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); + union { + __hip_bfloat162_raw bf162_raw; + vec_short2 vs2; + } u{static_cast<__hip_bfloat162_raw>(value)}; + u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2); + return static_cast<__hip_bfloat162>(u.bf162_raw); +#else + static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw)); + union u_hold { + __hip_bfloat162_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} + +__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) { +#if (defined(__gfx942__)) && \ + __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16) + // The api expects an ext_vector_type of half + typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162; + static_assert(sizeof(vec_fp162) == sizeof(__half2_raw)); + union { + __half2_raw h2r; + vec_fp162 fp16; + } u {static_cast<__half2_raw>(value)}; + u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16); + return static_cast<__half2>(u.h2r); +#else + static_assert(sizeof(__half2_raw) == sizeof(unsigned int)); + union u_hold { + __half2_raw h2r; + unsigned int u32; + }; + u_hold old_val, new_val; + old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, 
__HIP_MEMORY_SCOPE_AGENT); + do { + new_val.h2r = __hadd2(old_val.h2r, value); + } while (!__hip_atomic_compare_exchange_strong( + (unsigned int*)address, &old_val.u32, new_val.u32, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)); + return old_val.h2r; +#endif +} +#define ATOMICADD preview_unsafeAtomicAdd #define NATIVE_ZERO_BF16 __float2bfloat16(0.0f) #else #define ATOMICADD atomicAdd diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 382a5a065b300..8971e05094651 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -958,8 +958,9 @@ void dispatch_fp8_rowwise_kernel_on_sm( const bool sm89 = properties != nullptr && properties->major == 8 && properties->minor == 9; const bool sm9x = properties != nullptr && properties->major == 9; const bool sm10x = properties != nullptr && properties->major == 10; + const bool sm11x = properties != nullptr && properties->major == 11; const bool sm12x = properties != nullptr && properties->major == 12; - if (!(sm89 || sm9x || sm10x || sm12x)) { + if (!(sm89 || sm9x || sm10x || sm11x || sm12x)) { TORCH_CHECK( false, "Rowwise scaling is not currently supported on your device"); } @@ -968,7 +969,7 @@ void dispatch_fp8_rowwise_kernel_on_sm( dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< /*ArchTag=*/cutlass::arch::Sm90, Types...>(XQ, WQ, x_scale, w_scale, bias, out); - } else if (sm10x) { + } else if (sm10x || sm11x) { dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose< /*ArchTag=*/cutlass::arch::Sm100, Types...>(XQ, WQ, x_scale, w_scale, bias, out); diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index d9f126938b301..fcdf39b8a9f4b 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -147,6 +147,19 @@ class MetalShaderLibrary { const std::optional alpha = std::nullopt, 
const std::optional scalar_arg_type = std::nullopt); + template + void exec_unary_kernel_with_params( + TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name); + template + void exec_binary_kernel_with_params( + TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name); + protected: virtual MTLLibrary_t getLibrary(); virtual MTLLibrary_t getLibrary( diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index cb488a3f5f117..5ca0ebe3de9bb 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -7,10 +7,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -630,4 +632,147 @@ inline bool needsGather(const TensorBase& t) { return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()); } +template +void MetalShaderLibrary::exec_unary_kernel_with_params(TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name) { + using namespace at::mps; + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_unary_kernel_with_params(sub_iter, name, params, params_type_name); + } + return; + } + + auto inputTensor = iter.input(0); + auto outputTensor = iter.output(0); + uint32_t length = iter.numel(); + if (length == 0) { + return; + } + auto kernel_name = fmt::format("{}_{}_{}_{}{}", + name, + iter.is_contiguous() ? 
"dense" : "strided", + scalarToMetalTypeString(outputTensor), + scalarToMetalTypeString(inputTensor), + fmt::format("_{}", params_type_name)); + @autoreleasepool { + auto cplState = getPipelineStateForFunc(kernel_name); + + MPSStream* mpsStream = getCurrentMPSStream(); + dispatch_sync(mpsStream->queue(), ^() { + auto computeEncoder = mpsStream->commandEncoder(); + + getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor}); + + [computeEncoder setComputePipelineState:cplState]; + bind_iter_tensors(computeEncoder, iter); + if (!iter.is_contiguous()) { + mtl_setArgs<2>(computeEncoder, + outputTensor.sizes(), + inputTensor.strides(), + outputTensor.strides(), + inputTensor.ndimension()); + } + detail::mtl_setArg(computeEncoder, params, iter.is_contiguous() ? 2 : 6); + mtl_dispatch1DJob(computeEncoder, cplState, length); + + getMPSProfiler().endProfileKernel(cplState); + }); + } +} + +template +void MetalShaderLibrary::exec_binary_kernel_with_params(TensorIteratorBase& iter, + const std::string& name, + T params, + const std::string& params_type_name) { + using namespace mps; + // TODO: Figure a better place to downcast double scalars (probably in tensor iterator itself?) 
+ // Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with + // double as common dtype (because Python floating point are always 64-bit values) + TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS"); + + // Skip for empty iterators + if (iter.numel() == 0) { + return; + } + + // Decompose 64-bit tensor into 32-bit ones + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + exec_binary_kernel_with_params(sub_iter, name, params, params_type_name); + } + return; + } + + auto convert_double_scalar = [](Tensor& t) { + if (t.dim() != 0) { + return; + } + if (t.scalar_type() == kDouble) { + t = t.to(kFloat); + } else if (t.scalar_type() == kComplexDouble) { + t = t.to(kComplexFloat); + } + }; + + Tensor input = iter.input(0); + Tensor other = iter.input(1); + Tensor out = iter.output(); + + convert_double_scalar(input); + convert_double_scalar(other); + + MPSStream* mpsStream = getCurrentMPSStream(); + const auto cast_needed = input.scalar_type() != other.scalar_type(); + const auto suffix = iter.is_contiguous() ? "dense" : "strided"; + // TODO: Implicitly pass both input and output types to non-cast kernels + const auto kernel_name = cast_needed + ? 
fmt::format("{}_{}_cast_{}_{}", name, suffix, scalarToMetalTypeString(out), params_type_name) + : fmt::format("{}_{}_{}_{}_{}", + name, + suffix, + scalarToMetalTypeString(out), + scalarToMetalTypeString(input), + params_type_name); + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = mpsStream->commandEncoder(); + auto binaryPSO = getPipelineStateForFunc(kernel_name); + // this function call is a no-op if MPS Profiler is not enabled + getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other}); + [computeEncoder setComputePipelineState:binaryPSO]; + // Set input and output tensors + bind_iter_tensors(computeEncoder, iter); + // Iterator is contiguous if all of its elements are dense in storage, + // i.e. it's true for both row-first and column-first tensors + if (iter.is_contiguous()) { + detail::mtl_setArg(computeEncoder, params, 3); + if (cast_needed) { + std::array size_and_types = {static_cast(c10::elementSize(input.scalar_type())), + static_cast(c10::elementSize(other.scalar_type())), + static_cast(input.scalar_type()), + static_cast(other.scalar_type())}; + mtl_setBytes(computeEncoder, size_and_types, 4); + } + } else { + // Please note that shapes and strides of the iterator might be + // different than that of its operands, for example binary op + // between 4x4 tensor and scalar will result in 1D 16 element iterator + std::array ndim_and_types = {iter.ndim(), + static_cast(input.scalar_type()), + static_cast(other.scalar_type()), + static_cast(out.scalar_type())}; + mtl_setArgs<3>( + computeEncoder, params, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), ndim_and_types); + } + mtl_dispatch1DJob(computeEncoder, binaryPSO, iter.numel()); + getMPSProfiler().endProfileKernel(binaryPSO); + } + }); +} + } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/kernels/Activation.h b/aten/src/ATen/native/mps/kernels/Activation.h new file mode 100644 index 
0000000000000..34ad90dd7a2a3 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Activation.h @@ -0,0 +1,16 @@ +#pragma once + +template +struct ELUParams { + T alpha; + T scale; + T input_scale; +}; + +template +struct ELUBackwardParams { + T alpha; + T scale; + T input_scale; + bool is_result; +}; diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal index ae1fda66c3b38..7d1f3aa5bacf6 100644 --- a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -1,3 +1,4 @@ +#include #include #include #include @@ -99,6 +100,59 @@ REGISTER_BINARY_OP(hardswish_backward, float, float); REGISTER_BINARY_OP(hardswish_backward, half, half); REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat); +struct elu_functor { + template + inline T operator()(const T self_, const ELUParams params) { + using op_T = opmath_t; + auto alpha = static_cast(params.alpha); + auto scale = static_cast(params.scale); + auto input_scale = static_cast(params.input_scale); + auto self = static_cast(self_); + auto neg_res = alpha * (::metal::precise::exp(self * input_scale) - 1); + return static_cast(scale * (self < 0 ? neg_res : self)); + } +}; + +struct elu_backward_functor { + template + inline T operator()( + const T grad_output_, + const T self_, + ELUBackwardParams params) { + using op_T = opmath_t; + auto alpha = static_cast(params.alpha); + auto scale = static_cast(params.scale); + auto input_scale = static_cast(params.input_scale); + auto grad_output = static_cast(grad_output_); + auto self = static_cast(self_); + + if (params.is_result) { + auto neg_coef = input_scale * (self + alpha * scale); + return static_cast(grad_output * (self <= 0 ? neg_coef : scale)); + } else { + auto neg_coef = input_scale * alpha * scale * + ::metal::precise::exp(self * input_scale); + return static_cast(grad_output * (self <= 0 ? 
neg_coef : scale)); + } + } +}; + +#define REGISTER_ELU_OP(T) \ + typedef ELUParams ELUParams_##T; \ + REGISTER_UNARY_ALPHA_OP(elu, T, ELUParams_##T, T); + +REGISTER_ELU_OP(float); +REGISTER_ELU_OP(half); +REGISTER_ELU_OP(bfloat); + +#define REGISTER_ELU_BACKWARD_OP(T) \ + typedef ELUBackwardParams ELUBackwardParams_##T; \ + REGISTER_BINARY_ALPHA_OP(elu_backward, T, ELUBackwardParams_##T, T); + +REGISTER_ELU_BACKWARD_OP(float); +REGISTER_ELU_BACKWARD_OP(half); +REGISTER_ELU_BACKWARD_OP(bfloat); + struct leaky_relu_functor { template inline T operator()(const T x, const T negative_slope) { diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index e437ea5ed7989..802c648c888d5 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -11,8 +11,6 @@ #include #include #include -#include -#include #include #include #include @@ -119,6 +117,10 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(log_softmax_mps_out) (const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), + "log_softmax for complex is not supported for MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "log_softmax for bool is not supported for MPS"); using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -162,6 +164,10 @@ Tensor relu_mps(const Tensor& self) { TORCH_IMPL_FUNC(log_softmax_backward_mps_out) (const Tensor& grad_output, const Tensor& output, int64_t dim, ScalarType input_dtype, const Tensor& out) { + TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(grad_output.scalar_type()), + "log_softmax for complex is not supported for MPS"); + 
TORCH_CHECK_NOT_IMPLEMENTED(grad_output.scalar_type() != kBool, "log_softmax for bool is not supported for MPS"); using namespace mps; using CachedGraph = MPSUnaryGradCachedGraph; @@ -202,6 +208,7 @@ Tensor relu_mps(const Tensor& self) { } std::tuple log_sigmoid_forward_out_mps(const Tensor& self, Tensor& output, Tensor& buffer) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // NOTE: buffer is only used by CPU dispatch, we just ignore it here using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -698,194 +705,6 @@ Tensor log_sigmoid_backward_mps(const Tensor& grad_output, const Tensor& self, c } } -static void elu_variants_out_mps(const Tensor& self, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - const Tensor& result, - std::string func_name) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - auto resultMemFormat = result.suggest_memory_format(); - bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat)); - Tensor out; - if (executeGatherOp) { - out = at::empty_like(result, MemoryFormat::Contiguous); - } - - // Empty output - if (result.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to()) + ":" + - std::to_string(scale.to()) + ":" + std::to_string(input_scale.to()); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - // scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) - - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - - MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[ @1 ] - 
dataType:getMPSDataType(self)]; - - MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(self)]; - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f shape:@[ @1 ] dataType:getMPSDataType(self)]; - - MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:inputScaleTensor - name:nil]; - MPSGraphTensor* exponentTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; - MPSGraphTensor* exponentMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:exponentTensor - secondaryTensor:unitTensor - name:nil]; - MPSGraphTensor* alphaTimesTensor = [mpsGraph multiplicationWithPrimaryTensor:exponentMinusOneTensor - secondaryTensor:alphaTensor - name:nil]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* fusedOutput = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:inputTensor - falsePredicateTensor:alphaTimesTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor:fusedOutput - secondaryTensor:scaleTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? 
out : result, nil, false); - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - if (out.has_storage()) { - result.copy_(out); - } - } -} - -// scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) )) -TORCH_IMPL_FUNC(elu_out_mps) -(const Tensor& self, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale, const Tensor& result) { - elu_variants_out_mps(self, alpha, scale, input_scale, result, "elu_out_mps"); -} - -TORCH_IMPL_FUNC(elu_backward_out_mps) -(const Tensor& grad_output, - const Scalar& alpha, - const Scalar& scale, - const Scalar& input_scale, - bool is_result, - const Tensor& self_or_result, - const Tensor& grad_input) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - auto gradMemFormat = grad_input.suggest_memory_format(); - bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && - grad_input.is_contiguous(gradMemFormat)); - Tensor out; - if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) { - out = at::empty_like(grad_input, MemoryFormat::Contiguous); - } - - // Empty output - if (grad_input.numel() == 0) { - return; - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" + - std::to_string(alpha.to()) + ":" + std::to_string(scale.to()) + ":" + - std::to_string(input_scale.to()) + ":" + std::to_string(is_result); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - MPSGraphTensor* selfOrResultTensor = mpsGraphRankedPlaceHolder(mpsGraph, self_or_result); - MPSGraphTensor* lessThanZeroGradTensor = nil; - - if (is_result) { - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.to() - shape:@[ @1 
] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* resultPlusAlphaTensor = [mpsGraph additionWithPrimaryTensor:selfOrResultTensor - secondaryTensor:alphaTensor - name:nil]; - auto constMul = scale.to() * input_scale.to(); - MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:resultPlusAlphaTensor - secondaryTensor:constMulTensor - name:nil]; - } else { - MPSGraphTensor* inputScaleTensor = [mpsGraph constantWithScalar:input_scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* scaledInputTensor = [mpsGraph multiplicationWithPrimaryTensor:selfOrResultTensor - secondaryTensor:inputScaleTensor - name:nil]; - MPSGraphTensor* expTensor = [mpsGraph exponentWithTensor:scaledInputTensor name:nil]; - auto constMul = scale.to() * input_scale.to() * alpha.to(); - MPSGraphTensor* constMulTensor = [mpsGraph constantWithScalar:constMul - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - lessThanZeroGradTensor = [mpsGraph multiplicationWithPrimaryTensor:expTensor - secondaryTensor:constMulTensor - name:nil]; - } - - MPSGraphTensor* scaleTensor = [mpsGraph constantWithScalar:scale.to() - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0f - shape:@[ @1 ] - dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:selfOrResultTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* gradTensor = [mpsGraph selectWithPredicateTensor:predicateTensor - truePredicateTensor:scaleTensor - falsePredicateTensor:lessThanZeroGradTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = 
selfOrResultTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - }); - - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp); - Placeholder selfOrResultPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp); - Placeholder gradInputPlaceholder = - Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false); - - auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, selfOrResultPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, gradInputPlaceholder); - if (out.has_storage()) { - grad_input.copy_(out); - } - } -} - TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor& output) { using namespace mps; using CachedGraph = MPSUnaryCachedGraph; @@ -896,6 +715,7 @@ static void elu_variants_out_mps(const Tensor& self, if (output.numel() == 0) return; + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // this can't pass anyway because a 0-dimensional tensor has "size" 1, which // can't be evenly halved, but give a nicer error message here. TORCH_CHECK(self.dim() > 0, "glu does not support 0-dimensional tensors"); @@ -1009,6 +829,7 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int (const Tensor& self, const Scalar& beta, const Scalar& threshold, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Not implemented for long"); // Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} * // \log(1 + \exp(\beta * x))` element-wise. 
// For numerical stability the implementation reverts to the linear function @@ -1159,6 +980,8 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int (const Tensor& self, const Tensor& result) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS"); if (result.numel() == 0) return; @@ -1207,6 +1030,8 @@ Tensor glu_backward_mps(const Tensor& grad_output, const Tensor& self, const int Tensor mish_backward_mps(const Tensor& grad_output, const Tensor& self) { using namespace mps; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Mish for complex is not supported for MPS"); Tensor grad_input = at::empty_like(self, self.suggest_memory_format()); if (grad_input.numel() == 0) @@ -1396,6 +1221,7 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { using CachedGraph = MPSUnaryCachedGraph; TORCH_CHECK(self.is_mps()); + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); // Empty output if (result.numel() == 0) diff --git a/aten/src/ATen/native/mps/operations/ActivationKernel.mm b/aten/src/ATen/native/mps/operations/ActivationKernel.mm index cec8bfa2312e4..f6d3ad986ade0 100644 --- a/aten/src/ATen/native/mps/operations/ActivationKernel.mm +++ b/aten/src/ATen/native/mps/operations/ActivationKernel.mm @@ -1,8 +1,10 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include #include #include +#include #include namespace at::native { @@ -41,6 +43,30 @@ static void hardswish_backward_kernel(at::TensorIterator& iter) { lib.exec_binary_kernel(iter, "hardswish_backward"); } +static void elu_kernel(TensorIteratorBase& iter, 
const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) { + AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_mps", [&]() { + ELUParams params{alpha.to(), scale.to(), input_scale.to()}; + lib.exec_unary_kernel_with_params( + iter, "elu", params, fmt::format("ELUParams_{}", mps::scalarToMetalTypeString(iter.common_dtype()))); + }); +} + +static void elu_backward_kernel(TensorIteratorBase& iter, + const Scalar& alpha, + const Scalar& scale, + const Scalar& input_scale, + bool is_result) { + AT_DISPATCH_FLOATING_TYPES_AND2(c10::kHalf, c10::kBFloat16, iter.common_dtype(), "elu_backward_mps", [&]() { + ELUBackwardParams params{ + alpha.to(), scale.to(), input_scale.to(), is_result}; + lib.exec_binary_kernel_with_params( + iter, + "elu_backward", + params, + fmt::format("ELUBackwardParams_{}", mps::scalarToMetalTypeString(iter.common_dtype()))); + }); +} + static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) { lib.exec_unary_kernel(iter, "leaky_relu", negative_slope); } @@ -56,6 +82,8 @@ static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& n REGISTER_DISPATCH(hardsigmoid_backward_stub, hardsigmoid_backward_kernel); REGISTER_DISPATCH(hardswish_stub, hardswish_kernel); REGISTER_DISPATCH(hardswish_backward_stub, hardswish_backward_kernel); +REGISTER_DISPATCH(elu_stub, elu_kernel); +REGISTER_DISPATCH(elu_backward_stub, elu_backward_kernel); REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel); REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel); diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 92f2b9c6fbf74..d75456c1ad3f0 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -80,6 +80,11 @@ static void grid_sampler_2d_mps_impl(Tensor& output, MPSGraphTensor* outputTensor_ = nil; }; + // Crashes with + // 
MPSGraphUtilities.mm:97:0: error: 'mps.sample_grid' op operand #0 must be tensor of mps native type values, but got + // 'tensor<2x3x5x20xcomplex>' + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "grid_sampler_2d is not supported for complex on MPS"); @autoreleasepool { std::string key = "grid_sampler_2d_mps" + getTensorsStringKey({input, grid}) + ":" + std::to_string(interpolation_mode) + ":" + std::to_string(padding_mode) + ":" + std::to_string(align_corners); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index ca19d121bb718..00f9c96b78af8 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -240,7 +240,7 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, bool check_errors) { using namespace mps; - TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), + TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS doesn't support complex types."); TORCH_CHECK(pivot, "linalg.lu_factor(): MPS doesn't allow pivot == False."); @@ -364,8 +364,7 @@ static void linalg_solve_out_mps_impl(const Tensor& A, const Tensor& info) { using namespace mps; - TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), - "linalg.lu_factor(): MPS doesn't support complex types."); + TORCH_CHECK(A.scalar_type() == kFloat && LU.scalar_type() == kFloat, "linalg.lu_factor(): MPS only supports floats."); Tensor A_t, B_t; // If 'left' is false, reinterpret the problem so that Ax = B becomes A^T â‹… (x^T) = B^T // Then we solve the normal "left" case on the transposed matrices and transpose x finally to get the output @@ -1058,7 +1057,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const using namespace mps; checkInputsSolver(A, B, left, "linalg.solve_triangular"); - 
TORCH_CHECK(!A.is_complex() && !B.is_complex(), "linalg.solve.triangular(); Not supported for complex yet!"); + TORCH_CHECK(A.scalar_type() == kFloat && B.scalar_type() == kFloat, + "linalg.solve.triangular(); Only float is supported!"); Tensor A_t, B_t; std::tie(B_t, A_t) = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/ nullptr); at::native::resize_output(out, B_t.sizes()); diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index f0bbcdabfa5cd..11ee09d6e23f2 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -416,6 +416,8 @@ static void nllnd_loss_forward_impl(Tensor& output, int64_t reduction, int64_t ignore_index, bool is2D) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(output.scalar_type()), + "nlld_loss for complex is not supported for MPS"); std::vector reshapedTarget(target_arg.sizes().begin(), target_arg.sizes().end()); reshapedTarget.push_back(1); @@ -824,6 +826,9 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output, Tensor& huber_loss_out_mps(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& output) { std::string op_name = __func__; using namespace mps; + TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "MPS doesn't know how to do square_i64"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "huber_loss for complex is not supported for MPS"); TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.") TORCH_CHECK(target.is_same_size(input), op_name + ": target and input tensors must have identical shapes") TORCH_CHECK(output.is_mps()); diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 2d466f7c79436..ecd5f12df17f8 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -597,6 +597,7 @@ 
static void avg_pool2d_template(const Tensor& input, bool count_include_pad, const std::optional divisor_override, const std::string& op_name) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), "Not implemented for complex"); const Tensor& grad_output = *(at::borrow_from_optional_tensor(grad_output_opt)); const bool is_backward_pass = grad_output.defined(); const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; @@ -915,6 +916,8 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& output, const Tensor& indices) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "Max pooling for complex is not supported for MPS"); bool use_graph = use_graph_for_max_pool2d(kernel_size, stride); if (use_graph) { auto indices_memory_format = indices.suggest_memory_format(); @@ -967,6 +970,8 @@ Tensor mps_max_pool2d_backward(const Tensor& grad_output, bool ceil_mode, const Tensor& indices, const Tensor& grad_input) { + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "Max pooling for complex is not supported for MPS"); mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) { MPSGraph* mpsGraph = cachedGraph.graph(); return [mpsGraph maxPooling2DGradientWithGradientTensor:cachedGraph.gradOutputTensor diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 3747f314adfa1..e634eefee2058 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -269,17 +269,22 @@ static void reduction_out_mps(const Tensor& input_t, name:nil]; castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil]; } else if (reduction_type == MPSReductionType::NANSUM) { - // Create a 0 tensor of the same shape as inputTensor - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 
dataType:castInputTensor.dataType]; - // Find NaNs - MPSGraphTensor* nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; - // Replace NaNs with 0 - MPSGraphTensor* nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask - truePredicateTensor:zeros - falsePredicateTensor:castInputTensor - name:nil]; - // Sum - castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; + // Integral types cannot contain NaN, so just do regular sum + if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) { + castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil]; + } else { + // Create a 0 tensor of the same shape as inputTensor + auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; + // Find NaNs + auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; + // Replace NaNs with 0 + auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask + truePredicateTensor:zeros + falsePredicateTensor:castInputTensor + name:nil]; + // Sum + castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; + } } MPSGraphTensor* outputTensor = castOutputTensor; @@ -442,6 +447,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t, const std::optional& correction, bool keepdim, StdVarType stdVarType) { + TORCH_CHECK_NOT_IMPLEMENTED(input_t.scalar_type() != kLong, "Not implemented for MPS"); using CachedGraph = MPSUnaryCachedGraph; IntArrayRef input_shape = input_t.sizes(); diff --git a/aten/src/ATen/native/mps/operations/SoftMax.mm b/aten/src/ATen/native/mps/operations/SoftMax.mm index 8f70e216dcae8..8eb24d0cb68bf 100644 --- a/aten/src/ATen/native/mps/operations/SoftMax.mm +++ b/aten/src/ATen/native/mps/operations/SoftMax.mm @@ -39,6 +39,7 @@ static void get_shapes(MPSShape* input_shape_readonly, TORCH_IMPL_FUNC(softmax_mps_out) (const Tensor& input_, const int64_t dim, const bool half_to_float, const Tensor& output) { 
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS"); + TORCH_CHECK(c10::isFloatingType(input_.scalar_type()), "softmax only supported for floating types"); static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); if (input_.numel() == 0) { diff --git a/aten/src/ATen/native/mps/operations/SummaryOps.mm b/aten/src/ATen/native/mps/operations/SummaryOps.mm index e709ec2d4f618..21cae885c3685 100644 --- a/aten/src/ATen/native/mps/operations/SummaryOps.mm +++ b/aten/src/ATen/native/mps/operations/SummaryOps.mm @@ -18,6 +18,10 @@ MPSStream* stream = getCurrentMPSStream(); bool has_weights = weights.defined(); + // Crashes with + // MPSGraphUtilities.mm:190:0: error: 'mps.scatter' op operand #2 must be tensor of int values, but got 'tensor<5xi1>' + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kBool, "bincount is not supported for Bool"); + @autoreleasepool { std::string key = "bincount_mps_impl" + getTensorsStringKey({self, weights}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9a1c7c790afaa..fd88794d38f52 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -12064,8 +12064,7 @@ device_check: NoCheck # TensorIterator python_module: nn dispatch: - CPU, CUDA: elu_out - MPS: elu_out_mps + CPU, CUDA, MPS: elu_out - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor structured_delegate: elu.out @@ -12078,8 +12077,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: elu_backward_out - MPS: elu_backward_out_mps + CPU, CUDA, MPS: elu_backward_out - func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor structured_delegate: 
elu_backward.grad_input diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index a522e7ab76cf4..923b7119a42fc 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -65,6 +65,7 @@ list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cuda_event_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_exchange_device_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda_generator_test.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu diff --git a/aten/src/ATen/test/cuda_event_test.cpp b/aten/src/ATen/test/cuda_event_test.cpp new file mode 100644 index 0000000000000..7c58688e1ef9d --- /dev/null +++ b/aten/src/ATen/test/cuda_event_test.cpp @@ -0,0 +1,36 @@ +#include + +#include +#include +#include + +TEST(CUDAEventTest, testCUDAExternalEvent) { + if (!at::cuda::is_available()) { + return; + } + + // Create two external CUDA events + unsigned int flags = cudaEventDefault | cudaEventExternal; + auto event1 = at::cuda::CUDAEvent(flags); + auto event2 = at::cuda::CUDAEvent(flags); + // Ensure external CUDAEvent remain valid and functional after being moved. + auto start_event = std::move(event1); + auto end_event = std::move(event2); + + auto stream = at::cuda::getStreamFromPool(); + at::cuda::setCurrentCUDAStream(stream); + + auto graph = at::cuda::CUDAGraph(); + graph.capture_begin(); + start_event.record(); + at::cuda::sleep(100000); + end_event.record(); + graph.capture_end(); + + // External events should correctly record timestamps even when used inside + // CUDA graphs, and elapsed_time() between them should be positive. 
+ stream.synchronize(); + graph.replay(); + at::cuda::device_synchronize(); + EXPECT_TRUE(start_event.elapsed_time(end_event) > 0); +} diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index 00fc03bbd0fcf..56bc75e01adb1 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -48,7 +48,7 @@ void warnDeprecatedDataPtr() { TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); } -void StorageImpl::incref_pyobject() const { +void StorageImpl::incref_pyobject() const noexcept { // Because intrusive_ptr incref uses relaxed memory order, we need to // do an acquire fence to ensure that the kHasPyObject bit was // observed before the load of the PyObject* below. @@ -59,12 +59,12 @@ void StorageImpl::incref_pyobject() const { (*pyobj_slot_.pyobj_interpreter())->incref(obj); } -void StorageImpl::decref_pyobject() const { +void StorageImpl::decref_pyobject() const noexcept { PyObject* obj = pyobj_slot_.load_pyobj(); (*pyobj_slot_.pyobj_interpreter())->decref(obj); } -bool StorageImpl::try_incref_pyobject() const { +bool StorageImpl::try_incref_pyobject() const noexcept { c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); if (C10_UNLIKELY(!interp)) { return false; diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index c7dbd5c1f005b..8df32f552c754 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -105,11 +105,11 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { data_ptr_.clear(); } - void incref_pyobject() const override final; + void incref_pyobject() const noexcept override final; - void decref_pyobject() const override final; + void decref_pyobject() const noexcept override final; - bool try_incref_pyobject() const override final; + bool try_incref_pyobject() const noexcept override final; size_t nbytes() const { // OK to do this instead of maybe_as_int as nbytes is guaranteed positive diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 
94a7375cc32fb..c890d6d084eb3 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -988,7 +988,7 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { } } -void TensorImpl::incref_pyobject() const { +void TensorImpl::incref_pyobject() const noexcept { // Because intrusive_ptr incref uses relaxed memory order, we need to // do an acquire fence to ensure that the kHasPyObject bit was // observed before the load of the PyObject* below. @@ -999,12 +999,12 @@ void TensorImpl::incref_pyobject() const { (*pyobj_slot_.pyobj_interpreter())->incref(obj); } -void TensorImpl::decref_pyobject() const { +void TensorImpl::decref_pyobject() const noexcept { PyObject* obj = pyobj_slot_.load_pyobj(); (*pyobj_slot_.pyobj_interpreter())->decref(obj); } -bool TensorImpl::try_incref_pyobject() const { +bool TensorImpl::try_incref_pyobject() const noexcept { c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); if (C10_UNLIKELY(!interp)) { return false; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 71a0195dde773..42b6bb1e80d2e 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2178,11 +2178,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return &pyobj_slot_; } - void incref_pyobject() const override final; + void incref_pyobject() const noexcept override final; - void decref_pyobject() const override final; + void decref_pyobject() const noexcept override final; - bool try_incref_pyobject() const override final; + bool try_incref_pyobject() const noexcept override final; private: // See NOTE [std::optional operator usage in CUDA] diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 1ff0c9a12ac78..380e7939ff76c 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -20,22 +20,6 @@ } \ } while (0) -#define C10_CUDA_DRIVER_CHECK_GOTO(EXPR, NEXT) \ - do { \ - CUresult __err = EXPR; \ - if (__err != CUDA_SUCCESS) { \ - const char* err_str; \ - CUresult 
get_error_str_err [[maybe_unused]] = \ - c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ - if (get_error_str_err != CUDA_SUCCESS) { \ - TORCH_WARN("CUDA driver error: unknown error"); \ - } else { \ - TORCH_WARN("CUDA driver error: ", err_str); \ - } \ - goto NEXT; \ - } \ - } while (0) - // The integer in the second column specifies the requested CUDA Driver API // version. The dynamic loader will accept a driver with a newer version, but it // ensures that the requested symbol exists in *at least* the specified version diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 0c8f55f5061ab..f3c4ab0dc7cbc 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -68,6 +68,10 @@ inline bool has_pyobject(uint64_t combined_refcount) { return (combined_refcount & kHasPyObject) != 0; } +inline bool is_uniquely_owned(uint64_t combined_refcount) { + return (combined_refcount & ~detail::kHasPyObject) == detail::kUniqueRef; +} + // The only requirement for refcount increment is that it happens-before // decrement, so no additional memory ordering is needed. inline uint64_t atomic_combined_refcount_increment( @@ -287,9 +291,9 @@ class C10_API intrusive_ptr_target { * These two methods are called when the refcount transitions between one * and two and the object has a PyObject wrapper. 
*/ - virtual void incref_pyobject() const {} - virtual void decref_pyobject() const {} - virtual bool try_incref_pyobject() const { + virtual void incref_pyobject() const noexcept {} + virtual void decref_pyobject() const noexcept {} + virtual bool try_incref_pyobject() const noexcept { return false; } @@ -363,7 +367,7 @@ class intrusive_ptr final { template friend class pybind11::class_; - void retain_() { + void retain_() noexcept { if (target_ != NullType::singleton()) { uint64_t combined = detail::atomic_combined_refcount_increment( target_->combined_refcount_, detail::kReferenceCountOne); @@ -377,9 +381,7 @@ class intrusive_ptr final { // PyObject. In other words, we need to ensure that the PyObject stays // alive now that we have a C++ reference to this object in addition to // the PyObject itself. - if (C10_UNLIKELY( - detail::has_pyobject(combined) && - detail::refcount(combined) == 2)) { + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { target_->incref_pyobject(); } } else { @@ -392,51 +394,60 @@ class intrusive_ptr final { void reset_() noexcept { if (target_ != NullType::singleton()) { - if (is_uniquely_owned()) { - // Both counts are 1, so there are no weak references and - // we are releasing the last strong reference. No other - // threads can observe the effects of this target_ deletion - // call (e.g. calling use_count()) without a data race. - target_->combined_refcount_.store(0, std::memory_order_relaxed); - delete target_; + reset_not_null_(target_); + } + } + + // C10_NOINLINE to keep binary size a bit smaller. We pass TTarget* here + // to avoid an extra pointer dereference in the call from reset_(). + C10_NOINLINE static void reset_not_null_(TTarget* target) noexcept { + if (detail::is_uniquely_owned( + target->combined_refcount_.load(std::memory_order_acquire))) { + // Both counts are 1, so there are no weak references and + // we are releasing the last strong reference. 
No other + // threads can observe the effects of this target deletion + // call (e.g. calling use_count()) without a data race. + target->combined_refcount_.store(0, std::memory_order_relaxed); + delete target; + return; + } + + auto combined_refcount = detail::atomic_combined_refcount_decrement( + target->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined_refcount); + bool has_pyobject = detail::has_pyobject(combined_refcount); + if (new_refcount == 0) { + if (detail::weakcount(combined_refcount) == 1) { + delete target; return; } - - auto combined_refcount = detail::atomic_combined_refcount_decrement( - target_->combined_refcount_, detail::kReferenceCountOne); - uint32_t new_refcount = detail::refcount(combined_refcount); - bool has_pyobject = detail::has_pyobject(combined_refcount); - if (new_refcount == 0) { - bool should_delete = detail::weakcount(combined_refcount) == 1; - // See comment above about weakcount. As long as refcount>0, - // weakcount is one larger than the actual number of weak references. - // So we need to decrement it here. - if (!should_delete) { - // justification for const_cast: release_resources is basically a - // destructor and a destructor always mutates the object, even for - // const objects. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast*>(target_) - ->release_resources(); - should_delete = detail::atomic_weakcount_decrement( - target_->combined_refcount_) == 0; - } - if (should_delete) { - delete target_; - } - } else if constexpr (detail::TargetTraits::can_have_pyobject) { - // If the refcount transitioned from 2 to 1, we need to decref the - // PyObject. In other words, we don't want to keep the PyObject alive if - // there are no C++ references to this object other than the PyObject - // itself. 
- if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) { - target_->decref_pyobject(); - } - } else { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !has_pyobject, - "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + // See comment above about weakcount. As long as refcount>0, + // weakcount is one larger than the actual number of weak references. + // So we need to decrement it here. + release_resources_and_decrement_weakrefs_(target); + } else if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 2 to 1, we need to decref the + // PyObject. In other words, we don't want to keep the PyObject alive if + // there are no C++ references to this object other than the PyObject + // itself. + if (has_pyobject && new_refcount == 1) { + target->decref_pyobject(); } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !has_pyobject, + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } + } + + C10_NOINLINE static void release_resources_and_decrement_weakrefs_( + TTarget* target) noexcept { + // justification for const_cast: release_resources is basically a + // destructor and a destructor always mutates the object, even for + // const objects. 
+ const_cast*>(target)->release_resources(); + if (detail::atomic_weakcount_decrement(target->combined_refcount_) == 0) { + delete target; } } @@ -607,9 +618,8 @@ class intrusive_ptr final { */ bool is_uniquely_owned() const noexcept { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton()); - uint64_t combined = - target_->combined_refcount_.load(std::memory_order_acquire); - return (combined & ~detail::kHasPyObject) == detail::kUniqueRef; + return detail::is_uniquely_owned( + target_->combined_refcount_.load(std::memory_order_acquire)); } /** @@ -1174,9 +1184,7 @@ inline void incref(intrusive_ptr_target* self) { self->combined_refcount_, detail::kReferenceCountOne); #ifndef C10_MOBILE - if (C10_UNLIKELY( - detail::has_pyobject(combined) && - detail::refcount(combined) == 2)) { + if (detail::has_pyobject(combined) && detail::refcount(combined) == 2) { self->incref_pyobject(); } #else diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 3bd9eff0fee63..d7eeb10caba1b 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -893,11 +893,13 @@ class DeviceCachingAllocator { } bool release_cached_blocks(MempoolId_t mempool_id) { + bool streams_synced = false; if (mempool_id.first == 0 && mempool_id.second == 0 && captures_underway.empty()) { synchronize_and_free_events(); // See Note [Safe to Free Blocks on BlockPool] c10::xpu::syncStreamsOnDevice(device_index); + streams_synced = true; release_blocks(large_blocks); release_blocks(small_blocks); @@ -916,6 +918,12 @@ class DeviceCachingAllocator { continue; } } + + if (!streams_synced) { + // See Note [Safe to Free Blocks on BlockPool] + c10::xpu::syncStreamsOnDevice(device_index); + streams_synced = true; + } TORCH_INTERNAL_ASSERT(it->second->use_count == 0); release_blocks(it->second->small_blocks); release_blocks(it->second->large_blocks); @@ -1219,6 +1227,63 @@ class DeviceCachingAllocator { allowed_memory_maximum = static_cast(fraction * 
device_total); set_fraction = true; } + + void createOrIncrefPool( + MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr) { + std::scoped_lock lock(mutex); + create_or_incref_pool(mempool_id, allocator); + } + + int getPoolUseCount(MempoolId_t mempool_id) { + std::scoped_lock lock(mutex); + auto it = graph_pools.find(mempool_id); + if (it == graph_pools.end()) { + return 0; + } + return it->second->use_count; + } + + // Called by XPUGraph::capture_begin + void beginAllocateToPool( + MempoolId_t mempool_id, + std::function filter) { + std::lock_guard lock(mutex); + create_or_incref_pool(mempool_id); + auto not_found = std::all_of( + captures_underway.begin(), + captures_underway.end(), + [&](const auto& entry) { return entry.first != mempool_id; }); + TORCH_CHECK( + not_found, "beginAllocateToPool: already recording to mempool_id"); + captures_underway.emplace_back(mempool_id, std::move(filter)); + } + + // Called by XPUGraph::capture_end + void endAllocateToPool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + + auto it = std::find_if( + captures_underway.begin(), + captures_underway.end(), + [&](const auto& entry) { return entry.first == mempool_id; }); + TORCH_INTERNAL_ASSERT( + it != captures_underway.end(), + "endAllocatePool: not currently recording to mempool_id"); + captures_underway.erase(it); + } + + // Called by XPUGraph::reset and MemPool::~MemPool() + void releasePool(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + auto pp = get_private_pool(mempool_id); + auto uc = --(pp->use_count); + TORCH_INTERNAL_ASSERT(uc >= 0); + if (uc == 0) { + bool inserted = graph_pools_freeable.insert({mempool_id, pp}).second; + TORCH_INTERNAL_ASSERT(inserted); + } + } }; static void local_raw_delete(void* ptr); @@ -1408,6 +1473,39 @@ class XPUAllocator : public DeviceAllocator { ". 
Please set within (0, 1]."); device_allocators[device]->setMemoryFraction(fraction); } + + void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + XPUAllocator* allocator) { + assertValidDevice(device); + device_allocators[device]->createOrIncrefPool( + std::move(mempool_id), allocator); + } + + void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + assertValidDevice(device); + device_allocators[device]->beginAllocateToPool( + std::move(mempool_id), std::move(filter)); + } + + void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + device_allocators[device]->endAllocateToPool(mempool_id); + } + + void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + device_allocators[device]->releasePool(std::move(mempool_id)); + } + + int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + assertValidDevice(device); + return device_allocators[device]->getPoolUseCount(std::move(mempool_id)); + } }; static XPUAllocator allocator; @@ -1464,6 +1562,92 @@ void setMemoryFraction(double fraction, DeviceIndex device) { return allocator.setMemoryFraction(fraction, device); } +void createOrIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + XPUAllocator* allocator_ptr) { + return allocator.createOrIncrefPool(device, mempool_id, allocator_ptr); +} + +void beginAllocateToPool( + c10::DeviceIndex device, + MempoolId_t mempool_id, + std::function filter) { + return allocator.beginAllocateToPool(device, mempool_id, std::move(filter)); +} + +void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.endAllocateToPool(device, mempool_id); +} + +void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { + return allocator.releasePool(device, mempool_id); +} + +int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + return 
allocator.getPoolUseCount(device, mempool_id); +} + REGISTER_ALLOCATOR(kXPU, &allocator) } // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +// uid_ is incremented when a user creates a MemPool, +// +// uuid_ is incremented when XPUGraph creates a MemPool +// as a result of a user not providing a pool. + +std::atomic MemPool::uid_{1}; +std::atomic MemPool::uuid_{1}; + +MemPool::MemPool( + XPUCachingAllocator::XPUAllocator* allocator, + bool is_user_created, + bool use_on_oom) + : allocator_(allocator), is_user_created_(is_user_created) { + if (is_user_created_) { + id_ = {0, uid_++}; + } else { + id_ = {uuid_++, 0}; + } + device_ = c10::xpu::current_device(); + XPUCachingAllocator::createOrIncrefPool(device_, id_, allocator); + if (use_on_oom) { + // XPU doesn't support use_on_oom yet + TORCH_WARN( + "XPUCachingAllocator::MemPool: use_on_oom is not supported on XPU"); + } +} + +MemPool::~MemPool() { + TORCH_INTERNAL_ASSERT(use_count() == 1); + XPUCachingAllocator::releasePool(device_, id_); + c10::xpu::XPUCachingAllocator::emptyCache(id_); // release cached blocks +} + +MempoolId_t MemPool::id() { + return id_; +} + +XPUCachingAllocator::XPUAllocator* MemPool::allocator() { + return allocator_; +} + +int MemPool::use_count() { + return XPUCachingAllocator::getPoolUseCount(device_, id_); +} + +c10::DeviceIndex MemPool::device() { + return device_; +} + +MempoolId_t MemPool::graph_pool_handle(bool is_user_created) { + if (is_user_created) { + return {0, uid_++}; + } + return {uuid_++, 0}; +} + +} // namespace c10::xpu diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index bbb20a5b2ecdf..c55de309032e0 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -33,4 +33,59 @@ C10_XPU_API double getMemoryFraction(DeviceIndex device); C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); +class XPUAllocator; + +C10_XPU_API void createOrIncrefPool( + c10::DeviceIndex device, + 
c10::MempoolId_t mempool_id, + XPUAllocator* allocator = nullptr); + +C10_XPU_API void beginAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id, + std::function filter); + +C10_XPU_API void endAllocateToPool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API void releasePool( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + +C10_XPU_API int getPoolUseCount( + c10::DeviceIndex device, + c10::MempoolId_t mempool_id); + } // namespace c10::xpu::XPUCachingAllocator + +namespace c10::xpu { + +using c10::CaptureId_t; +using c10::MempoolId_t; +struct C10_XPU_API MemPool { + MemPool( + XPUCachingAllocator::XPUAllocator* allocator = nullptr, + bool is_user_created = true, + bool use_on_oom = false); + MemPool(const MemPool&) = delete; + MemPool(MemPool&&) = default; + MemPool& operator=(const MemPool&) = delete; + MemPool& operator=(MemPool&&) = default; + ~MemPool(); + + MempoolId_t id(); + XPUCachingAllocator::XPUAllocator* allocator(); + int use_count(); + c10::DeviceIndex device(); + static MempoolId_t graph_pool_handle(bool is_user_created = true); + + private: + static std::atomic uid_; + static std::atomic uuid_; + XPUCachingAllocator::XPUAllocator* allocator_; + bool is_user_created_; + MempoolId_t id_; + c10::DeviceIndex device_; +}; +} // namespace c10::xpu diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index bac1fa7daac01..5faad21f9f6cd 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -113,6 +113,12 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a") endif() endif() + # We will need to gate against CUDA version, because sm_110a is available on CUDA 13.0+ + if("${_arch}" STREQUAL "110a" AND CUDA_VERSION VERSION_GREATER_EQUAL 13.0) + if(_existing_arch_flags MATCHES ".*compute_110.*") + list(APPEND _file_compile_flags "-gencode;arch=compute_110a,code=sm_110a") + endif() + endif() if("${_arch}" STREQUAL "120a") if(_existing_arch_flags 
MATCHES ".*compute_120.*") list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a") @@ -132,13 +138,13 @@ if(INTERN_BUILD_ATEN_OPS) _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu" - "89;90a;100a;103a;120a;121a") + "89;90a;100a;103a;110a;120a;121a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" "90a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" - "90a;100a;103a") + "90a;100a;103a;110a") endif() diff --git a/docs/source/community/governance.rst b/docs/source/community/governance.rst index cea24593dca83..ebfadf4e0f69b 100644 --- a/docs/source/community/governance.rst +++ b/docs/source/community/governance.rst @@ -132,7 +132,7 @@ The Process for Nomination * Each module has its own process. Please contact module maintainers for more information. However, if there is no process identified, you can file a request to the core - maintainers by submitting `this form `__. + maintainers by submitting `this form `__. Core maintainers are meeting every three months. * If you are submitting a request to the core maintainers, the information in your request must include the following items: diff --git a/docs/source/distributed.md b/docs/source/distributed.md index ca1fe3b5e9099..6840bbb893bf7 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -987,6 +987,24 @@ In addition, `TORCH_DISTRIBUTED_DEBUG=DETAIL` can be used in conjunction with `T collective desynchronization checks will work for all applications that use `c10d` collective calls backed by process groups created with the {func}`torch.distributed.init_process_group` and {func}`torch.distributed.new_group` APIs. + +### torch.distributed.debug HTTP Server + +The `torch.distributed.debug` module provides an HTTP server that can be used to debug distributed applications. 
The server can +be started by calling {func}`torch.distributed.debug.start_debug_server`. This +allows users to collect data across all workers at runtime. + +```{eval-rst} +.. automodule:: torch.distributed.debug + :members: + :undoc-members: + :show-inheritance: + :special-members: __init__ + :member-order: bysource + +``` + + ## Logging In addition to explicit debugging support via {func}`torch.distributed.monitored_barrier` and `TORCH_DISTRIBUTED_DEBUG`, the underlying C++ library of `torch.distributed` also outputs log diff --git a/setup.py b/setup.py index 314f719ea67f0..f15e7bbdd0ac4 100644 --- a/setup.py +++ b/setup.py @@ -1089,6 +1089,60 @@ def check_pydep(importname: str, module: str) -> None: class build_ext(setuptools.command.build_ext.build_ext): + def _wrap_headers_with_macro(self, include_dir: Path) -> None: + """Wrap all header files with #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION). + + Excludes: + - torch/headeronly/* + - torch/csrc/stable/* + - torch/csrc/inductor/aoti_torch/c/ (only shim headers) + - torch/csrc/inductor/aoti_torch/generated/ + + This method is idempotent - it will not wrap headers that are already wrapped. 
+ """ + header_extensions = (".h", ".hpp", ".cuh") + header_files = [ + f for ext in header_extensions for f in include_dir.rglob(f"*{ext}") + ] + + # Paths to exclude from wrapping (relative to include_dir) + exclude_dir_patterns = [ + "torch/headeronly/", + "torch/csrc/stable/", + "torch/csrc/inductor/aoti_torch/c/", + "torch/csrc/inductor/aoti_torch/generated/", + ] + + # Marker to detect if a header is already wrapped + wrap_start_marker = ( + "#if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + for header_file in header_files: + rel_path = header_file.relative_to(include_dir).as_posix() + + if any(rel_path.startswith(pattern) for pattern in exclude_dir_patterns): + report(f"Skipping header: {rel_path}") + continue + + original_content = header_file.read_text(encoding="utf-8") + + # Check if already wrapped (idempotency check) + if original_content.startswith(wrap_start_marker): + report(f"Already wrapped, skipping: {rel_path}") + continue + + wrapped_content = ( + wrap_start_marker + + f"{original_content}" + + "\n#else\n" + + '#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."\n' + + "#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)\n" + ) + + header_file.write_text(wrapped_content, encoding="utf-8") + report(f"Wrapped header: {rel_path}") + def _embed_libomp(self) -> None: # Copy libiomp5.dylib/libomp.dylib inside the wheel package on MacOS build_lib = Path(self.build_lib) @@ -1256,6 +1310,15 @@ def run(self) -> None: super().run() + # Wrap headers with TORCH_STABLE_ONLY and TORCH_TARGET_VERSION guards + build_lib = Path(self.build_lib) + build_torch_include_dir = build_lib / "torch" / "include" + if build_torch_include_dir.exists(): + report( + "-- Wrapping header files with if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)" + ) + self._wrap_headers_with_macro(build_torch_include_dir) + if IS_DARWIN: self._embed_libomp() diff --git 
a/test/complex_tensor/test_complex_tensor.py b/test/complex_tensor/test_complex_tensor.py new file mode 100644 index 0000000000000..dbb14d93f972a --- /dev/null +++ b/test/complex_tensor/test_complex_tensor.py @@ -0,0 +1,238 @@ +# Owner(s): ["module: complex"] +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.distributed as dist + + +# Support both when imported from elsewhere or directly as a file +try: + from .utils import ( + COMPLEX_DTYPES, + Descriptor, + force_test_op_db, + get_overload_packet_from_name, + implemented_op_db, + TestCase, + Variant, + ) +except ImportError: + from utils import ( + COMPLEX_DTYPES, + Descriptor, + force_test_op_db, + get_overload_packet_from_name, + implemented_op_db, + TestCase, + Variant, + ) + +from torch._subclasses.complex_tensor._ops.common import ComplexTensorMode +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + OpDTypes, + ops, +) +from torch.testing._internal.common_utils import ( + run_tests, + TestGradients, + unMarkDynamoStrictTest, +) + + +if TYPE_CHECKING: + from torch.testing._internal.opinfo.core import OpInfo + +aten = torch.ops.aten + +SKIPS = { + Descriptor(op=aten.empty_like, variant=None): "Non-deterministic output", + Descriptor(op=aten.randn_like, variant=None): "Non-deterministic output", + Descriptor(op=aten.angle, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.asinh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.atanh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.reciprocal, variant=Variant.GradCheck + ): "Numerical inconsistency", + Descriptor(op=aten.rsqrt, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.select, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.asin, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.log, 
variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sgn, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cumprod, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.slice, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sqrt, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.tan, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.true_divide, variant=Variant.GradCheck + ): "Numerical inconsistency", + Descriptor(op=aten.prod, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.div, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.expm1, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.var, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.bmm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.diagonal, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sinh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.abs, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.sin, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.atan, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.acos, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.acosh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cos, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.cosh, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.addmm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.pow, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.log1p, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.tanh, 
variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.mm, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor(op=aten.to, variant=Variant.GradCheck): "Numerical inconsistency", + Descriptor( + op=aten.any, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.all, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.allclose, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.conj_physical, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten._conj_physical, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.cumprod, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.index_add, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.diagonal_scatter, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.flip, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.masked_fill, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.masked_scatter, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.rsub, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.ne, variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.squeeze, 
variant=Variant.Distributed + ): "does not have a sharding strategy registered", + Descriptor( + op=aten.index_select, variant=Variant.Distributed + ): "Sharding propagation failed", + Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support", + Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support", +} + +EXTRA_KWARGS = { + Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 2e-5, + "atol": 5e-5, + }, + Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 1e-4, + "atol": 1e-5, + }, + Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Op): { + "rtol": 2e-2, + "atol": 2e-6, + }, + Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-5, + "atol": 5e-5, + }, + Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 1e-4, + "atol": 1e-5, + }, + Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-2, + "atol": 2e-6, + }, + Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): { + "rtol": 2e-6, + "atol": 1e-2, + }, +} + + +class TestComplexTensor(TestCase): + _default_dtype_check_enabled = True + + @ops( + implemented_op_db, + dtypes=OpDTypes.supported, + allowed_dtypes=list(COMPLEX_DTYPES), + ) + def test_consistency(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Op) + + @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_maybe_error(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Op) + + +@unMarkDynamoStrictTest +class TestComplexBwdGradients(TestGradients): + _default_dtype_check_enabled 
= True + + @ops( + implemented_op_db, + dtypes=OpDTypes.supported_backward, + allowed_dtypes=[torch.complex128], + ) + def test_fn_grad(self, device: str, dtype: torch.dtype, op: OpInfo) -> None: + test_info = Descriptor( + op=get_overload_packet_from_name(op.name), + device_type=torch.device(device).type, + dtype=dtype, + variant=Variant.GradCheck, + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + if dtype not in op.supported_backward_dtypes(torch.device(device).type): + self.skipTest(f"Skipped! {dtype=} is not in supported backward dtypes!") + + with ComplexTensorMode(): + op.gradcheck_fast_mode = False + self._grad_test_helper(device, dtype, op, op.get_op()) + + +instantiate_device_type_tests(TestComplexTensor, globals()) +instantiate_device_type_tests(TestComplexBwdGradients, globals()) + + +if dist.is_available(): + from torch.testing._internal.common_distributed import MultiProcessTestCase + + @unMarkDynamoStrictTest + class TestComplexDistributed(TestCase, MultiProcessTestCase): + @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES)) + def test_distributed(self, device, dtype, op: OpInfo): + self.check_consistency(device, dtype, op, Variant.Distributed) + + instantiate_device_type_tests(TestComplexDistributed, globals()) + +if __name__ == "__main__": + run_tests() diff --git a/test/complex_tensor/utils.py b/test/complex_tensor/utils.py new file mode 100644 index 0000000000000..d2a1e1d312264 --- /dev/null +++ b/test/complex_tensor/utils.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from dataclasses import dataclass, field, fields +from enum import auto, Enum +from typing import Any, TYPE_CHECKING + +import torch +import torch.distributed as dist +from torch._subclasses.complex_tensor._ops.common import ( + _as_complex_tensor, + _as_interleaved, + _get_op_name, + COMPLEX_OPS_TABLE, + COMPLEX_TO_REAL, + FORCE_TEST_LIST, + OpOverloadPacket, +) +from 
torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import TestCase as PytorchTestCase +from torch.utils._pytree import tree_flatten + + +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.distributed.tensor import DTensor + from torch.testing._internal.opinfo.core import OpInfo + +COMPLEX_DTYPES = set(COMPLEX_TO_REAL) + + +class Variant(Enum): + Op = auto() + GradCheck = auto() + Distributed = auto() + + +def _as_local(arg: DTensor | Any) -> torch.Tensor | Any: + if not (dist.is_available() and isinstance(arg, dist.tensor.DTensor)): + return arg + + return arg.full_tensor() + + +def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any: + if not isinstance(arg, torch.Tensor): + return arg + + return dist.tensor.DTensor.from_local(_as_complex_tensor(arg)) + + +TRANSFORM_FUNCS = { + Variant.Op: _as_complex_tensor, + Variant.Distributed: _as_complex_dtensor, +} + + +@dataclass(frozen=True, kw_only=True) +class Descriptor: + op: OpOverloadPacket + variant: Variant | None + device_type: str | None = field(default=None) + dtype: torch.dtype | None = field(default=None) + + def matches(self, other: Descriptor) -> bool: + fields1 = fields(self) + fields2 = fields(other) + if fields1 != fields2: + return False + + for f in fields1: + f1 = getattr(self, f.name) + f2 = getattr(other, f.name) + if f1 is not None and f2 is not None and f1 != f2: + return False + + return True + + +class TestCase(PytorchTestCase): + def assertSameResult( + self, + expected: Callable[[], Any], + actual: Callable[[], Any], + *args, + **kwargs, + ) -> None: + try: + result_e = expected() + exception_e = None + except Exception as e: # noqa: BLE001 + result_e = None + exception_e = e + + try: + result_a = actual() + exception_a = None + except Exception as e: # noqa: BLE001 + result_a = None + exception_a = e + + if (exception_e is None) != (exception_a is None): + if exception_a is not None and exception_e is 
None: + raise exception_a + self.assertIs( + type(exception_e), + type(exception_a), + f"\n{exception_e=}\n{exception_a=}", + ) + + if exception_e is None: + flattened_e, spec_e = tree_flatten(result_e) + flattened_a, spec_a = tree_flatten(result_a) + + self.assertEqual( + spec_e, + spec_a, + "Both functions must return a result with the same tree structure.", + ) + for value_e, value_a in zip(flattened_e, flattened_a, strict=True): + value_e = _as_interleaved(_as_local(value_e)) + value_a = _as_interleaved(_as_local(value_a)) + + self.assertEqual(value_e, value_a, *args, **kwargs) + + def check_consistency( + self, device: str, dtype, op: OpInfo, variant: Variant + ) -> None: + try: + from .test_complex_tensor import EXTRA_KWARGS, SKIPS + except ImportError: + from test_complex_tensor import EXTRA_KWARGS, SKIPS + test_info = Descriptor( + op=get_overload_packet_from_name(op.name), + device_type=torch.device(device).type, + dtype=dtype, + variant=variant, + ) + for xfail_info, reason in SKIPS.items(): + if xfail_info.matches(test_info): + self.skipTest(reason) + + kwargs = {} + for extra_info, extra_kw in EXTRA_KWARGS.items(): + if extra_info.matches(test_info): + kwargs = extra_kw + break + sample_inputs = op.sample_inputs(device, dtype) + transform_fn = TRANSFORM_FUNCS[variant] + + for sample_input in sample_inputs: + + def expected(sample_input=sample_input): + return op(sample_input.input, *sample_input.args, **sample_input.kwargs) + + subclass_sample = sample_input.transform(transform_fn) + + def actual(subclass_sample=subclass_sample): + return op( + subclass_sample.input, + *subclass_sample.args, + **subclass_sample.kwargs, + ) + + self.assertSameResult(expected, actual, **kwargs) + + +aten = torch.ops.aten + +complex_op_db = tuple( + filter(lambda op: any(op.supports_dtype(ct, "cpu") for ct in COMPLEX_DTYPES), op_db) +) + + +def get_overload_packet_from_name(name: str) -> OpOverloadPacket: + for domain_name in torch.ops: + op_namespace = getattr(torch.ops, 
domain_name) + op: OpOverloadPacket | None = getattr(op_namespace, name, None) + if op is not None: + return op + + raise RuntimeError(f"No op with {name=} found.") + + +force_test_names = set(map(_get_op_name, FORCE_TEST_LIST)) +implemented_op_names = ( + set(map(_get_op_name, COMPLEX_OPS_TABLE.keys())) - force_test_names +) +implemented_op_db = tuple( + filter(lambda op: op.name in implemented_op_names, complex_op_db) +) +force_test_op_db = tuple(filter(lambda op: op.name in force_test_names, op_db)) + +tested_op_names = {op.name for op in implemented_op_db} | { + op.name for op in force_test_op_db +} +non_tested_ops = { + op for op in COMPLEX_OPS_TABLE if _get_op_name(op) not in tested_op_names +} + + +# TODO (hameerabbasi): There are a number of ops that don't have any associated +# OpInfos. We still need to write tests for those ops. +if len(non_tested_ops) != 0: + import textwrap + import warnings + + list_missing_ops = "\n".join(sorted([str(op) for op in non_tested_ops])) + warnings.warn( + "Not all implemented ops are tested. 
List of ops missing tests:" + f"\n{textwrap.indent(list_missing_ops, ' ')}", + UserWarning, + stacklevel=2, + ) diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu index 7773210a089ee..f8d87f60d9a2e 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/mv_tensor_accessor_cuda.cu @@ -3,7 +3,11 @@ #include "tensor_accessor_kernel.h" +#ifdef USE_ROCM +#include +#else #include +#endif #include #include #include diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp new file mode 100644 index 0000000000000..c560fb0a60af9 --- /dev/null +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/csrc/my_shape.cpp @@ -0,0 +1,20 @@ +#include +#include +#include + +using torch::stable::Tensor; + +torch::headeronly::HeaderOnlyArrayRef my_shape(Tensor t) { + return t.sizes(); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic_2_10, m) { + m.def("my_shape(Tensor t) -> int[]"); +} + +STABLE_TORCH_LIBRARY_IMPL( + libtorch_agnostic_2_10, + CompositeExplicitAutograd, + m) { + m.impl("my_shape", TORCH_BOX(&my_shape)); +} diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py index db1a4fd43033c..a740df8c9e25f 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/libtorch_agnostic_2_10/ops.py @@ -199,6 +199,18 @@ def my_view(t, size) -> Tensor: return 
torch.ops.libtorch_agnostic_2_10.my_view.default(t, size) +def my_shape(t) -> tuple[int]: + """ + Returns a shape of the input tensor. + + Args: + t: Tensor - input tensor + + Returns: tuple - shape of the imput tensor. + """ + return torch.ops.libtorch_agnostic_2_10.my_shape.default(t) + + def get_any_data_ptr(t, mutable) -> int: """ Return data pointer value of the tensor. diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py index ff2aeff5e932b..405944bc0f9bf 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/setup.py @@ -35,7 +35,6 @@ def get_extension(): extra_compile_args = { "cxx": [ "-fdiagnostics-color=always", - "-DTORCH_STABLE_ONLY", "-DTORCH_TARGET_VERSION=0x020a000000000000", ], } diff --git a/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py b/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py index a094c57f8e614..05027a41b6715 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py +++ b/test/cpp_extensions/libtorch_agnostic_2_10_extension/test_version_compatibility.py @@ -22,9 +22,15 @@ from pathlib import Path from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase -from torch.utils.cpp_extension import CUDA_HOME, include_paths as torch_include_paths +from torch.utils.cpp_extension import ( + CUDA_HOME, + include_paths as torch_include_paths, + ROCM_HOME, +) +GPU_HOME = CUDA_HOME or ROCM_HOME + # TODO: Fix this error in Windows: # numba.cuda.cudadrv.driver:driver.py:384 Call to cuInit results in CUDA_ERROR_NO_DEVICE if not IS_WINDOWS: @@ -42,8 +48,8 @@ def setUpClass(cls): f"-I{path}" for path in torch_include_paths(device_type="cpu") ] cls.cuda_includes = [] - if CUDA_HOME: - cuda_include_path = os.path.join(CUDA_HOME, "include") + if GPU_HOME: + 
cuda_include_path = os.path.join(GPU_HOME, "include") if os.path.exists(cuda_include_path): cls.cuda_includes = [f"-I{cuda_include_path}"] @@ -105,13 +111,13 @@ def _compile_cu_file( Compile a CUDA file with TORCH_TARGET_VERSION=2.9.0. Returns (success, error_message). """ - if not CUDA_HOME: - return False, "CUDA_HOME not set" + if not GPU_HOME: + return False, "one of CUDA_HOME and ROCM_HOME should be set but is not" torch_version_2_9 = "0x0209000000000000" cmd = [ - os.path.join(CUDA_HOME, "bin", "nvcc"), + os.path.join(GPU_HOME, "bin", "nvcc" if CUDA_HOME else "hipcc"), "-c", "-std=c++17", f"-DTORCH_TARGET_VERSION={torch_version_2_9}", @@ -120,6 +126,9 @@ def _compile_cu_file( *self.cuda_includes, ] + if ROCM_HOME: + cmd.extend(["-DUSE_ROCM=1"]) + cmd.extend([str(source_file), "-o", str(output_file)]) result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) diff --git a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu index 88c19d0ebf062..1f549630262a6 100644 --- a/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu +++ b/test/cpp_extensions/libtorch_agnostic_2_9_extension/libtorch_agnostic_2_9/csrc/cuda_kernel.cu @@ -1,6 +1,10 @@ #include "kernel.h" +#ifdef USE_ROCM +#include +#else #include +#endif #include #include #include diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp index 3d35b677cd208..3a6f2945d903c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -1,8 +1,275 @@ #include "OpenRegDeviceAllocator.h" +#include 
"OpenRegFunctions.h" + +#include +#include + +using namespace c10::CachingAllocator; namespace c10::openreg { -static OpenRegDeviceAllocator global_openreg_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); +constexpr size_t kAggregate = static_cast(StatType::AGGREGATE); + + +DeviceMemoryAllocator::DeviceMemoryAllocator(c10::DeviceIndex device_index) + : device_index_(device_index) {} + +void* DeviceMemoryAllocator::malloc(size_t nbytes) { + if (nbytes == 0) { + return nullptr; + } + + std::lock_guard lock(mutex_); + + void* data = nullptr; + auto ret = orMalloc(&data, nbytes); + + TORCH_CHECK( + ret == orSuccess && data != nullptr, + "Failed to allocate ", + nbytes, + " bytes on openreg device ", + device_index_, + ". ", + "Allocated: ", + stats_.allocated_bytes[0].current, + " bytes, ", + "Reserved: ", + stats_.reserved_bytes[0].current, + " bytes"); + + // Track allocation size for proper deallocation statistics + allocation_sizes_[data] = nbytes; + + // Update statistics + stats_.allocated_bytes[kAggregate].increase(nbytes); + stats_.reserved_bytes[kAggregate].increase(nbytes); + stats_.num_device_alloc++; + + return data; +} + +void DeviceMemoryAllocator::free(void* ptr) { + if (!ptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto ret = orFree(ptr); + + if (ret == orSuccess) { + auto it = allocation_sizes_.find(ptr); + if (it != allocation_sizes_.end()) { + size_t nbytes = it->second; + + stats_.allocated_bytes[kAggregate].decrease(nbytes); + stats_.reserved_bytes[kAggregate].decrease(nbytes); + stats_.num_device_free++; + + allocation_sizes_.erase(it); + } else { + TORCH_WARN( + "Successfully freed OpenReg memory pointer ", + ptr, + " on device ", + device_index_, + " that was not tracked by the allocator. 
" + "Statistics may be inaccurate."); + } + } else { + // orFree failed + auto it = allocation_sizes_.find(ptr); + if (it != allocation_sizes_.end()) { + TORCH_WARN( + "orFree failed for tracked pointer ", + ptr, + " with size ", + it->second, + " bytes on device ", + device_index_, + ". Return code: ", + ret, + ". Keeping tracking record - this may indicate a double-free or invalid pointer."); + } else { + TORCH_WARN( + "orFree failed for untracked pointer ", + ptr, + " on device ", + device_index_, + ". Return code: ", + ret, + ". This likely indicates a double-free or invalid pointer."); + } + } +} + +c10::CachingDeviceAllocator::DeviceStats DeviceMemoryAllocator::getStats() { + std::lock_guard lock(mutex_); + return stats_; +} + +void DeviceMemoryAllocator::resetAccumulatedStats() { + std::lock_guard lock(mutex_); + + // Reset accumulated statistics for all StatTypes + for (const auto stat_type : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats_.allocated_bytes[stat_type].reset_accumulated(); + stats_.reserved_bytes[stat_type].reset_accumulated(); + stats_.active_bytes[stat_type].reset_accumulated(); + stats_.inactive_split_bytes[stat_type].reset_accumulated(); + stats_.requested_bytes[stat_type].reset_accumulated(); + } + + stats_.num_alloc_retries = 0; + stats_.num_ooms = 0; + stats_.num_sync_all_streams = 0; + stats_.num_device_alloc = 0; + stats_.num_device_free = 0; +} + +void DeviceMemoryAllocator::resetPeakStats() { + std::lock_guard lock(mutex_); + + // Reset peak statistics for all StatTypes + for (const auto stat_type : + c10::irange(static_cast(StatType::NUM_TYPES))) { + stats_.allocated_bytes[stat_type].reset_peak(); + stats_.reserved_bytes[stat_type].reset_peak(); + stats_.active_bytes[stat_type].reset_peak(); + stats_.inactive_split_bytes[stat_type].reset_peak(); + stats_.requested_bytes[stat_type].reset_peak(); + } + + stats_.oversize_allocations.reset_peak(); + stats_.oversize_segments.reset_peak(); +} + +namespace { + 
+OpenRegDeviceAllocator g_allocator; + +void deleteOpenRegMemory(void* ptr) { + g_allocator.freeMemory(ptr); +} + +} + +OpenRegDeviceAllocator::OpenRegDeviceAllocator() { + std::lock_guard lock(mutex_); + const auto device_count = c10::openreg::device_count(); + device_allocators_.resize(device_count); + for (const auto i : c10::irange(device_count)) { + device_allocators_[i] = std::make_unique(i); + } +} + + +at::DataPtr OpenRegDeviceAllocator::allocate(size_t nbytes) { + int current_device_index = -1; + auto ret = orGetDevice(¤t_device_index); + TORCH_CHECK(ret == orSuccess, "Failed to get current OpenReg device"); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + + void* data = nullptr; + if (nbytes > 0) { + // Allocate memory via device-specific allocator + data = device_allocators_[current_device_index]->malloc(nbytes); + + // Track which device owns this pointer + std::lock_guard lock(mutex_); + allocated_blocks_[data] = current_device_index; + } + + return {data, data, &deleteOpenRegMemory, curr_device}; +} + +at::DeleterFnPtr OpenRegDeviceAllocator::raw_deleter() const { + return &deleteOpenRegMemory; +} + +void OpenRegDeviceAllocator::copy_data( + void* dest, + const void* src, + std::size_t count) const { + auto ret = orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + TORCH_CHECK( + ret == orSuccess, "Failed to copy ", count, " bytes on openreg device"); +} + +bool OpenRegDeviceAllocator::initialized() { + std::lock_guard lock(mutex_); + return !device_allocators_.empty(); +} + +void OpenRegDeviceAllocator::freeMemory(void* ptr) { + if (!ptr) { + return; + } + + // Try to find which device owns this pointer + c10::DeviceIndex device_index = -1; + bool found_in_map = false; + + { + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(ptr); + if (it != allocated_blocks_.end()) { + device_index = it->second; + allocated_blocks_.erase(it); + found_in_map = true; + } + } + + if (found_in_map) { + // 
Pointer was tracked - free via device-specific allocator with stats + device_allocators_[device_index]->free(ptr); + } else { + // Pointer not tracked - might be already freed by storage or other path + // Try to free it directly via orFree without updating statistics + auto ret = orFree(ptr); + + // Only warn if orFree actually failed (not just "not found") + // In OpenReg's case, orFree returns orErrorUnknown if pointer not in registry + // which is expected for already-freed memory + if (ret != orSuccess && ret != orErrorUnknown) { + TORCH_WARN( + "orFree failed for untracked OpenReg memory pointer ", + ptr, + ". Error code: ", ret); + } + } +} + +c10::CachingDeviceAllocator::DeviceStats OpenRegDeviceAllocator:: + getDeviceStats(c10::DeviceIndex device) { + return device_allocators_[device]->getStats(); +} + +void OpenRegDeviceAllocator::resetAccumulatedStats(c10::DeviceIndex device) { + device_allocators_[device]->resetAccumulatedStats(); +} + +void OpenRegDeviceAllocator::resetPeakStats(c10::DeviceIndex device) { + device_allocators_[device]->resetPeakStats(); +} + +void OpenRegDeviceAllocator::emptyCache(MempoolId_t mempool_id) { + // OpenReg doesn't implement caching yet + // TODO: When caching is implemented, release all free blocks here +} + +void OpenRegDeviceAllocator::recordStream( + const DataPtr& ptr, + c10::Stream stream) { + // OpenReg doesn't track stream usage yet + // TODO: When stream support is added, track which streams are using this pointer +} +// ============ Global Registration ============ + +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &g_allocator); } // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h index c9aea4a913427..777926e02b18c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h 
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegDeviceAllocator.h @@ -1,43 +1,78 @@ -#include +#pragma once #include +#include #include +#include #include +#include +#include +#include +#include + namespace c10::openreg { -struct OpenRegDeviceAllocator final : at::Allocator { - OpenRegDeviceAllocator() = default; - - static void ReportAndDelete(void* ptr) { - if (!ptr) { - return; - } - orFreeHost(ptr); - } - - at::DataPtr allocate(size_t nbytes) override { - int current_device_index = -1; - orGetDevice(¤t_device_index); - - auto curr_device = - c10::Device(c10::DeviceType::PrivateUse1, current_device_index); - void* data = nullptr; - if (nbytes > 0) { - orMalloc(&data, nbytes); - TORCH_CHECK( - data, "Failed to allocator ", nbytes, " bytes on openreg device."); - } - return {data, data, &ReportAndDelete, curr_device}; - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } - - void copy_data(void* dest, const void* src, std::size_t count) const final { - orMemcpy(dest, src, count, orMemcpyDeviceToDevice); - } + +class DeviceMemoryAllocator { + public: + explicit DeviceMemoryAllocator(c10::DeviceIndex device_index); + + DeviceMemoryAllocator(const DeviceMemoryAllocator&) = delete; + DeviceMemoryAllocator& operator=(const DeviceMemoryAllocator&) = delete; + + void* malloc(size_t nbytes); + + void free(void* ptr); + + c10::CachingDeviceAllocator::DeviceStats getStats(); + + void resetAccumulatedStats(); + + void resetPeakStats(); + + private: + c10::DeviceIndex device_index_; + + c10::CachingDeviceAllocator::DeviceStats stats_; + + std::unordered_map allocation_sizes_; + + std::recursive_mutex mutex_; +}; + + +class OpenRegDeviceAllocator final : public c10::DeviceAllocator { + public: + OpenRegDeviceAllocator(); + + at::DataPtr allocate(size_t nbytes) override; + at::DeleterFnPtr raw_deleter() const override; + void copy_data(void* dest, const void* src, std::size_t count) const final; + + + bool 
initialized() override; + void emptyCache(MempoolId_t mempool_id = {0, 0}) override; + void recordStream(const DataPtr& ptr, c10::Stream stream) override; + c10::CachingDeviceAllocator::DeviceStats getDeviceStats( + c10::DeviceIndex device) override; + void resetAccumulatedStats(c10::DeviceIndex device) override; + void resetPeakStats(c10::DeviceIndex device) override; + + + void freeMemory(void* ptr); + + private: + + // Per-device allocators + std::vector> device_allocators_; + + // Global mapping from pointer to device index + std::recursive_mutex mutex_; + ska::flat_hash_map allocated_blocks_; }; -} // namespace c10::openreg + + + +} diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py index 3d67e16a0f503..b4a64eedc5bfc 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_memory.py @@ -1,9 +1,392 @@ # Owner(s): ["module: PrivateUse1"] +import gc +import time + import torch + +import torch_openreg # noqa: F401 from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +class TestDeviceAllocator(TestCase): + """Test cases for OpenRegDeviceAllocator functionality.""" + + def setUp(self): + """Reset memory state before each test.""" + # Force garbage collection to ensure clean state + gc.collect() + # Note: We can't directly reset allocator stats without C++ API, + # but we can ensure tensors are properly released + + def test_basic_allocation(self): + """Test basic memory allocation with various sizes.""" + # Small allocation + x = torch.empty(100, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.numel(), 100) + # Large allocation + z = torch.empty(10000, device="openreg") + self.assertEqual(z.device.type, "openreg") + self.assertEqual(z.numel(), 10000) + 
# Multi-dimensional allocation + w = torch.empty(10, 20, 30, device="openreg") + self.assertEqual(w.device.type, "openreg") + self.assertEqual(w.shape, torch.Size([10, 20, 30])) + + def test_memory_lifecycle(self): + """Test complete memory allocation and deallocation lifecycle.""" + # Allocate tensor + x = torch.empty(1000, device="openreg") + self.assertEqual(x.device.type, "openreg") + + # Explicitly delete tensor + del x + gc.collect() + + # Allocate again to ensure memory was freed + y = torch.empty(1000, device="openreg") + self.assertEqual(y.device.type, "openreg") + del y + gc.collect() + + def test_tensor_copy_operations(self): + """Test memory operations during tensor copies.""" + # CPU to OpenReg + cpu_tensor = torch.randn(100) + openreg_tensor = cpu_tensor.to("openreg") + self.assertEqual(openreg_tensor.device.type, "openreg") + self.assertEqual(cpu_tensor.shape, openreg_tensor.shape) + + # OpenReg to CPU + back_to_cpu = openreg_tensor.to("cpu") + self.assertEqual(back_to_cpu.device.type, "cpu") + self.assertTrue(torch.allclose(cpu_tensor, back_to_cpu)) + + # OpenReg to OpenReg (clone) + cloned = openreg_tensor.clone() + self.assertEqual(cloned.device.type, "openreg") + self.assertTrue(torch.allclose(openreg_tensor.cpu(), cloned.cpu())) + + def test_inplace_operations(self): + """Test memory stability during inplace operations.""" + x = torch.ones(100, device="openreg") + original_data_ptr = x.data_ptr() + + # Inplace addition + x.add_(1) + self.assertEqual(x.data_ptr(), original_data_ptr) + self.assertTrue(torch.all(x == 2)) + + # Inplace multiplication + x.mul_(2) + self.assertEqual(x.data_ptr(), original_data_ptr) + self.assertTrue(torch.all(x == 4)) + + def test_view_operations(self): + """Test that views share memory correctly.""" + x = torch.randn(100, device="openreg") + original_data_ptr = x.data_ptr() + + # Reshape view + y = x.view(10, 10) + self.assertEqual(y.data_ptr(), original_data_ptr) + self.assertEqual(y.shape, torch.Size([10, 10])) + + 
# Slice view + z = x[10:20] + # Slices may have different data_ptr but should share storage + self.assertEqual(z.numel(), 10) + + def test_different_dtypes(self): + """Test allocation with different data types.""" + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64] + + for dtype in dtypes: + x = torch.empty(100, dtype=dtype, device="openreg") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.dtype, dtype) + self.assertEqual(x.numel(), 100) + + def test_tensor_resize(self): + """Test tensor resizing operations.""" + x = torch.empty(100, device="openreg") + _ = x.data_ptr() + + # Resize to smaller size (should reuse storage) + x.resize_(50) + self.assertEqual(x.numel(), 50) + # Storage should still be available + + # Resize to original size + x.resize_(100) + self.assertEqual(x.numel(), 100) + + def test_empty_cache_operation(self): + """Test empty cache functionality.""" + # Allocate some tensors + x = torch.empty(1000, device="openreg") + y = torch.empty(2000, device="openreg") + + # Delete tensors + del x, y + gc.collect() + + # Note: OpenRegDeviceAllocator.emptyCache is currently a no-op + # This test ensures it doesn't crash + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + def test_memory_format_allocation(self): + """Test allocation with different memory formats.""" + # Channels last format + x = torch.empty(2, 3, 4, 4, device="openreg", memory_format=torch.channels_last) + self.assertEqual(x.device.type, "openreg") + self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) + + # Contiguous format (default) + y = torch.empty( + 2, 3, 4, 4, device="openreg", memory_format=torch.contiguous_format + ) + self.assertEqual(y.device.type, "openreg") + self.assertTrue(y.is_contiguous()) + + def test_large_allocation(self): + """Test large memory allocation.""" + # Allocate a large tensor (10MB approximately) + size = 10 * 1024 * 1024 // 4 # 10MB in float32 + x = torch.empty(size, device="openreg") + 
self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.numel(), size) + + def test_sequential_allocations_and_deallocations(self): + """Test sequential allocation and deallocation patterns.""" + for i in range(10): + x = torch.empty(1000 + i * 100, device="openreg") + self.assertEqual(x.device.type, "openreg") + # Let tensor go out of scope + del x + gc.collect() + + def test_allocation_with_requires_grad(self): + """Test allocation of tensors with gradient tracking.""" + x = torch.empty(100, device="openreg", requires_grad=True) + self.assertEqual(x.device.type, "openreg") + self.assertTrue(x.requires_grad) + + y = torch.randn(100, device="openreg", requires_grad=True) + self.assertEqual(y.device.type, "openreg") + self.assertTrue(y.requires_grad) + + def test_storage_operations(self): + """Test storage-level operations.""" + x = torch.randn(100, device="openreg") + storage = x.storage() + + # Verify storage is on correct device + self.assertTrue(storage.device.type == "openreg") + + # Verify storage size + self.assertGreaterEqual(storage.size(), x.numel()) + + def test_tensor_from_blob(self): + """Test creating tensors that reference existing memory.""" + x = torch.randn(100, device="openreg") + + # Create a view that references the same data + y = x.view_as(x) + + # They should share the same underlying storage + self.assertEqual(x.data_ptr(), y.data_ptr()) + + # Modifying one should affect the other + x.fill_(5.0) + self.assertTrue(torch.all(y == 5.0)) + + +class TestMemoryLeaks(TestCase): + """Test cases for detecting memory leaks in OpenRegDeviceAllocator.""" + + def setUp(self): + """Reset memory state before each test.""" + gc.collect() + time.sleep(0.1) # Allow time for cleanup + + def test_no_leak_simple_allocations(self): + """Test that simple allocations don't leak memory.""" + # Warm-up + for _ in range(10): + x = torch.empty(1000, device="openreg") + del x + gc.collect() + time.sleep(0.1) + + # Perform many allocations and deallocations + 
iterations = 1000 + for i in range(iterations): + x = torch.empty(1000, device="openreg") + del x + + if i % 100 == 0: + gc.collect() + + # Final cleanup + gc.collect() + time.sleep(0.1) + + # If there were leaks, this would have accumulated significant memory + # The test passes if no exception/crash occurred + + def test_no_leak_varying_sizes(self): + """Test that allocations of varying sizes don't leak.""" + iterations = 500 + sizes = [100, 500, 1000, 5000, 10000] + + for i in range(iterations): + size = sizes[i % len(sizes)] + x = torch.empty(size, device="openreg") + del x + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_copies(self): + """Test that tensor copies don't leak memory.""" + iterations = 300 + + for i in range(iterations): + # Create tensor + x = torch.randn(500, device="openreg") + + # Copy to CPU + cpu_copy = x.cpu() + + # Copy back to device + device_copy = cpu_copy.to("openreg") + + # Clone + cloned = device_copy.clone() + + # Delete all + del x, cpu_copy, device_copy, cloned + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_views(self): + """Test that tensor views don't leak memory.""" + iterations = 500 + + for i in range(iterations): + x = torch.randn(1000, device="openreg") + + # Create various views + view1 = x.view(10, 100) + view2 = x[100:200] + view3 = x.reshape(20, 50) + + # Delete views and original + del view1, view2, view3, x + + if i % 100 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_inplace_operations(self): + """Test that inplace operations don't leak memory.""" + iterations = 500 + + for i in range(iterations): + x = torch.ones(1000, device="openreg") + + # Multiple inplace operations + x.add_(1) + x.mul_(2) + x.div_(2) + x.sub_(1) + + del x + + if i % 100 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_with_gradients(self): + """Test that tensors with gradients don't leak.""" 
+ iterations = 300 + + for i in range(iterations): + x = torch.randn(100, device="openreg", requires_grad=True) + y = torch.randn(100, device="openreg", requires_grad=True) + + # Operation that creates computation graph + z = x + y + + # Delete all + del x, y, z + + if i % 50 == 0: + gc.collect() + + gc.collect() + time.sleep(0.1) + + def test_no_leak_repeated_large_allocations(self): + """Test repeated large allocations for memory leaks.""" + # Large tensor size (50MB) + size = 50 * 1024 * 1024 // 4 + iterations = 50 + + for i in range(iterations): + x = torch.empty(size, device="openreg") + del x + gc.collect() + time.sleep(0.05) # Allow time for cleanup + + # Final cleanup + gc.collect() + time.sleep(0.1) + + def test_leak_detection_with_statistics(self): + """Test memory leak detection using allocation patterns.""" + # This test verifies that after many alloc/dealloc cycles, + # the allocator properly frees memory + + num_cycles = 10 + allocations_per_cycle = 100 + + for cycle in range(num_cycles): + tensors = [] + + # Allocate many tensors + for i in range(allocations_per_cycle): + t = torch.empty(1000, device="openreg") + tensors.append(t) + + # Verify all allocated + self.assertEqual(len(tensors), allocations_per_cycle) + + # Delete all + tensors.clear() + gc.collect() + time.sleep(0.05) + + # Final verification - if there were leaks, memory would be exhausted + # The test passes if we can still allocate + final_tensor = torch.empty(10000, device="openreg") + self.assertEqual(final_tensor.device.type, "openreg") + del final_tensor + + class TestPinMemory(TestCase): @skipIfTorchDynamo("unsupported aten.is_pinned.default") def test_pin_memory(self): @@ -27,5 +410,110 @@ def test_pin_memory(self): self.assertTrue(pinned_untyped_storage.is_pinned("openreg")) +class TestMultiDeviceAllocation(TestCase): + """Test basic multi-device allocation functionality.""" + + def setUp(self): + self.device_count = torch.openreg.device_count() + 
self.assertEqual(self.device_count, 2, "This test requires 2 OpenReg devices") + gc.collect() + + def tearDown(self): + """Restore device 0 to avoid affecting subsequent tests.""" + torch.openreg.set_device(0) + gc.collect() + + def test_allocation_on_device_1(self): + torch.openreg.set_device(1) + x = torch.empty(100, device="openreg:1") + self.assertEqual(x.device.type, "openreg") + self.assertEqual(x.device.index, 1) + + def test_simultaneous_device_allocations(self): + """Test allocations on both devices simultaneously.""" + x = torch.empty(100, device="openreg:0") + y = torch.empty(200, device="openreg:1") + + self.assertEqual(x.device.index, 0) + self.assertEqual(y.device.index, 1) + self.assertNotEqual(x.data_ptr(), y.data_ptr()) + + def test_memory_isolation_between_devices(self): + """Test that memory allocations are isolated between devices.""" + + tensors_dev0 = [torch.empty(1000, device="openreg:0") for _ in range(10)] + tensors_dev1 = [torch.empty(1000, device="openreg:1") for _ in range(10)] + + # Verify all device 0 tensors are on device 0 + for t in tensors_dev0: + self.assertEqual(t.device.index, 0) + + # Verify all device 1 tensors are on device 1 + for t in tensors_dev1: + self.assertEqual(t.device.index, 1) + + # Pointers should be different + ptrs_dev0 = {t.data_ptr() for t in tensors_dev0} + ptrs_dev1 = {t.data_ptr() for t in tensors_dev1} + self.assertEqual( + len(ptrs_dev0 & ptrs_dev1), 0, "Devices should not share pointers" + ) + + def test_alternating_device_allocations(self): + """Test alternating allocations between devices.""" + tensors = [] + for i in range(20): + device_idx = i % 2 + t = torch.empty(100 + i, device=f"openreg:{device_idx}") + self.assertEqual(t.device.index, device_idx) + tensors.append(t) + + # Verify all tensors retained correct device assignment + for i, t in enumerate(tensors): + expected_device = i % 2 + self.assertEqual(t.device.index, expected_device) + + +class TestCrossDeviceOperations(TestCase): + """Test 
cross-device tensor operations.""" + + def setUp(self): + self.device_count = torch.openreg.device_count() + self.assertEqual(self.device_count, 2) + gc.collect() + + def tearDown(self): + """Restore device 0 to avoid affecting subsequent tests.""" + torch.openreg.set_device(0) + gc.collect() + + def test_tensor_to_different_device(self): + """Test moving tensor from one device to another.""" + # Create on device 0 + x = torch.randn(100, device="openreg:0") + self.assertEqual(x.device.index, 0) + + # Move to device 1 + y = x.to("openreg:1") + self.assertEqual(y.device.index, 1) + self.assertNotEqual(x.data_ptr(), y.data_ptr()) + + # Values should be the same + self.assertTrue(torch.allclose(x.cpu(), y.cpu())) + + def test_bidirectional_device_transfer(self): + """Test transferring tensor back and forth between devices.""" + original = torch.randn(100, device="openreg:0") + original_cpu = original.cpu() + + # 0 -> 1 + on_dev1 = original.to("openreg:1") + self.assertTrue(torch.allclose(original_cpu, on_dev1.cpu())) + + # 1 -> 0 + back_to_dev0 = on_dev1.to("openreg:0") + self.assertTrue(torch.allclose(original_cpu, back_to_dev0.cpu())) + + if __name__ == "__main__": run_tests() diff --git a/test/cpp_extensions/test_libtorch_agnostic.py b/test/cpp_extensions/test_libtorch_agnostic.py index 48ede590cecbf..ef92fc316daa7 100644 --- a/test/cpp_extensions/test_libtorch_agnostic.py +++ b/test/cpp_extensions/test_libtorch_agnostic.py @@ -711,6 +711,15 @@ def test_my_view(self, device): expected_flat = t.view([-1]) self.assertEqual(result_flat, expected_flat) + @skipIfTorchVersionLessThan(2, 10) + def test_my_shape(self, device): + import libtorch_agnostic_2_10 as libtorch_agnostic + + expected = (3, 5) + t = torch.rand(*expected, device=device) + shape = libtorch_agnostic.ops.my_shape(t) + self.assertEqual(shape, expected) + def test_mv_tensor_accessor(self, device): import libtorch_agnostic_2_9 as libtorch_agnostic diff --git 
a/test/cpp_extensions/torch_stable_test_extension/setup.py b/test/cpp_extensions/torch_stable_test_extension/setup.py deleted file mode 100644 index 062d466e7ae98..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/setup.py +++ /dev/null @@ -1,67 +0,0 @@ -import distutils.command.clean -import shutil -from pathlib import Path - -from setuptools import find_packages, setup - -from torch.utils.cpp_extension import BuildExtension, CppExtension - - -ROOT_DIR = Path(__file__).parent -CSRC_DIR = ROOT_DIR / "torch_stable_test" / "csrc" - - -class clean(distutils.command.clean.clean): - def run(self): - # Run default behavior first - distutils.command.clean.clean.run(self) - - # Remove extension - for path in (ROOT_DIR / "torch_stable_test").glob("**/*.so"): - path.unlink() - # Remove build and dist and egg-info directories - dirs = [ - ROOT_DIR / "build", - ROOT_DIR / "dist", - ROOT_DIR / "torch_stable_test.egg-info", - ] - for path in dirs: - if path.exists(): - shutil.rmtree(str(path), ignore_errors=True) - - -def get_extension(): - extra_compile_args = { - "cxx": ["-fdiagnostics-color=always", "-DTORCH_STABLE_ONLY"], - } - - sources = list(CSRC_DIR.glob("**/*.cpp")) - - return [ - CppExtension( - "torch_stable_test._C", - sources=sorted(str(s) for s in sources), - py_limited_api=True, - extra_compile_args=extra_compile_args, - extra_link_args=[], - ) - ] - - -setup( - name="torch_stable_test", - version="0.0", - author="PyTorch Core Team", - description="Test extension to verify TORCH_STABLE_ONLY flag", - packages=find_packages(exclude=("test",)), - package_data={"torch_stable_test": ["*.dll", "*.dylib", "*.so"]}, - install_requires=[ - "torch", - ], - ext_modules=get_extension(), - cmdclass={ - "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), - "clean": clean, - }, - options={"bdist_wheel": {"py_limited_api": "cp39"}}, -) diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py 
b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp deleted file mode 100644 index c92d56da11ba3..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/csrc/test_extension.cpp +++ /dev/null @@ -1 +0,0 @@ -#include // This should trigger the TORCH_STABLE_ONLY error diff --git a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py b/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py deleted file mode 100644 index 5c5613bb5484e..0000000000000 --- a/test/cpp_extensions/torch_stable_test_extension/torch_stable_test/test_torch_stable.py +++ /dev/null @@ -1,22 +0,0 @@ -# Owner(s): ["module: cpp"] - -from pathlib import Path - -from torch.testing._internal.common_utils import ( - install_cpp_extension, - IS_WINDOWS, - run_tests, - TestCase, -) - - -if not IS_WINDOWS: - - class TestTorchStable(TestCase): - def test_setup_fails(self): - with self.assertRaisesRegex(RuntimeError, "build failed for cpp extension"): - install_cpp_extension(extension_root=Path(__file__).parent.parent) - - -if __name__ == "__main__": - run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index ad3064608960d..076c4de69f44f 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -428,7 +428,14 @@ def test_manual_reshard_with_reshard_after_forward_false(self): @xfailIf(TEST_XPU) # https://github.com/intel/torch-xpu-ops/issues/1571 def test_set_reduce_scatter_divide_factor(self): self.run_subtests( - {"divide_factor": [self.world_size * 2, 
self.world_size]}, + { + "divide_factor": [self.world_size * 2, self.world_size], + "mesh_shape": [ + (self.world_size,), + (self.world_size // 2, 2), + (self.world_size, 1), + ], + }, self._test_set_reduce_scatter_divide_factor, ) self.run_subtests( @@ -436,18 +443,31 @@ def test_set_reduce_scatter_divide_factor(self): self._test_set_reduce_scatter_divide_factor_mixed_prevision, ) - def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): + def _test_set_reduce_scatter_divide_factor( + self, divide_factor: float, mesh_shape: tuple[int] | tuple[int, int] + ): torch.manual_seed(42) model_args = ModelArgs(dropout_p=0.0, weight_tying=False) model = Transformer(model_args) ref_model = copy.deepcopy(model).to(device_type) ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + mesh_dim_names = ("outer",) if len(mesh_shape) == 1 else ("outer", "inner") + mesh = init_device_mesh( + device_type.type, mesh_shape, mesh_dim_names=mesh_dim_names + ) for module in model.modules(): if isinstance(module, TransformerBlock): - fully_shard(module, reshard_after_forward=False) - model = fully_shard(model, reshard_after_forward=False) + fully_shard(module, reshard_after_forward=False, mesh=mesh) + model = fully_shard(model, reshard_after_forward=False, mesh=mesh) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) - model.set_reduce_scatter_divide_factor(divide_factor) + model.set_gradient_divide_factor(divide_factor) + + # Get ref_model params which should have the specific division factor applied + block_params = set() + for ref_mod in ref_model.modules(): + if isinstance(ref_mod, TransformerBlock): + block_params.update(ref_mod.parameters()) + non_block_params = set(ref_model.parameters()) - block_params torch.manual_seed(42 + self.rank) inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type) @@ -456,16 +476,18 @@ def _test_set_reduce_scatter_divide_factor(self, divide_factor: float): ref_loss = ref_model(inp).sum() 
ref_loss.backward() for param in ref_model.parameters(): - param.grad.mul_(1.0 / divide_factor) + factor = divide_factor if param in non_block_params else self.world_size + param.grad.mul_(1.0 / factor) dist.all_reduce(param.grad) loss = model(inp).sum() loss.backward() ref_optim.step() optim.step() - ref_optim.zero_grad() - optim.zero_grad() self.assertEqual(ref_loss, loss) + # Check parity before calling zero_grad so that grads are also checked check_sharded_parity(self, ref_model, model) + ref_optim.zero_grad() + optim.zero_grad() def _test_set_reduce_scatter_divide_factor_mixed_prevision( self, divide_factor: float @@ -484,7 +506,7 @@ def _test_set_reduce_scatter_divide_factor_mixed_prevision( fully_shard(mlp, mp_policy=mp_policy) model = fully_shard(model, mp_policy=mp_policy) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) - model.set_reduce_scatter_divide_factor(divide_factor) + model.set_gradient_divide_factor(divide_factor) torch.manual_seed(42 + self.rank) inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype) diff --git a/test/distributed/tensor/debug/test_comm_mode.py b/test/distributed/tensor/debug/test_comm_mode.py index c87164750c684..d122a9f716fcd 100644 --- a/test/distributed/tensor/debug/test_comm_mode.py +++ b/test/distributed/tensor/debug/test_comm_mode.py @@ -6,7 +6,7 @@ import torch.nn as nn from torch.distributed.tensor import DeviceMesh, DTensor, Shard from torch.distributed.tensor.debug import CommDebugMode -from torch.testing._internal.common_distributed import requires_nccl +from torch.testing._internal.common_distributed import requires_accelerator_dist_backend from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule from torch.testing._internal.distributed.fake_pg import FakeStore @@ -14,6 +14,9 @@ c10d_functional = torch.ops.c10d_functional c10d_ops = torch.ops.c10d +device_type = ( + acc.type if (acc := 
torch.accelerator.current_accelerator(True)) else "cpu" +) class TestCommMode(TestCase): @@ -28,7 +31,7 @@ def setUp(self): dist.init_process_group( backend="fake", rank=1, world_size=self.world_size, store=store ) - self.device_type = "cuda" if torch.cuda.is_available() else "cpu" + self.device_type = device_type self.world_pg = dist.distributed_c10d._get_default_group() def checksAssert(self, comm_mode, key, expected_value, expected_total_value): @@ -111,12 +114,12 @@ def f(x, y): self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1) self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0) - @requires_nccl() + @requires_accelerator_dist_backend(["nccl", "xccl"]) def test_comm_mode_with_c10d(self): - if not torch.cuda.is_available(): + if not torch.accelerator.is_available(): return - inp = torch.rand(2, 8, 16).cuda() + inp = torch.rand(2, 8, 16).to(device_type) all_gather_out = inp.new_empty(self.world_size * 2, 8, 16) comm_mode = CommDebugMode() diff --git a/test/distributed/tensor/test_dtensor.py b/test/distributed/tensor/test_dtensor.py index e99734c6b8437..c47ff79091493 100644 --- a/test/distributed/tensor/test_dtensor.py +++ b/test/distributed/tensor/test_dtensor.py @@ -658,11 +658,11 @@ def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor): @with_comms def test_dtensor_device_mesh_device_conversion(self): - # construct a cuda device mesh + # construct a gpu device mesh mesh = self.build_device_mesh() - # construct from a cpu local tensor with cuda device mesh - # should automatically convert the dist tensor to cuda + # construct from a cpu local tensor with gpu device mesh + # should automatically convert the dist tensor to gpu placements = [Shard(0)] local_tensor = torch.randn(3, 3) dist_tensor = DTensor.from_local(local_tensor, mesh, placements) @@ -711,7 +711,7 @@ def test_dtensor_api_device_mesh_context_manager(self): @with_comms def test_dtensor_2d_mesh(self): mesh_tensor = 
torch.arange(self.world_size).reshape(2, 4) - # construct a cuda device mesh + # construct a gpu device mesh mesh = DeviceMesh(self.device_type, mesh_tensor) # construct a dist tensor on 2d device mesh and test if works @@ -733,7 +733,7 @@ def test_dtensor_2d_mesh(self): @with_comms def test_device_mesh_nd(self): - # construct a cuda device mesh + # construct a gpu device mesh mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) mesh = DeviceMesh(self.device_type, mesh_tensor) # construct a dist tensor on 3d device mesh and test if works @@ -1064,8 +1064,8 @@ def _create_tensor(self, size): # Keep everything deterministic. torch.manual_seed(0) tensor = torch.rand(size) - if self.device_type == "cuda": - return tensor.cuda() + if self.device_type != "cpu": + return tensor.to(self.device_type) else: return tensor diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index ddba3150b05fb..e58b6dda658f3 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -39,6 +39,7 @@ RowwiseParallel, ) from torch.distributed.tensor.placement_types import _StridedShard +from torch.testing._internal.common_device_type import skipXPUIf from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import get_devtype from torch.testing._internal.common_utils import ( @@ -47,8 +48,6 @@ run_tests, skipIfHpu, skipIfTorchDynamo, - TEST_CUDA, - TEST_HPU, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -64,6 +63,54 @@ dev_type = torch.device(get_devtype()) +class PytreeTuple: + """ + Tuple-like values that are treated as leaves of a PyTree. 
+ """ + + def __init__(self, *values): + self._values = tuple(values) + + def __repr__(self): + pr = repr(self._values)[1:-1] + return f"{type(self).__name__}({pr})" + + def __getitem__(self, i): + return self._values[i] + + def __iter__(self): + return iter(self._values) + + def __len__(self): + return len(self._values) + + def __eq__(self, other: object) -> bool: + if isinstance(other, self.__class__): + return self._values == other._values + elif isinstance(other, tuple): + return self._values == other + return False + + def __hash__(self) -> int: + return hash(self._values) + + def __add__(self, other): + if isinstance(other, (self.__class__, tuple)): + return self.__class__(*self, *other) + raise NotImplementedError(type(other)) + + def __radd__(self, other): + if isinstance(other, (self.__class__, tuple)): + return self.__class__(*other, *self) + raise NotImplementedError(type(other)) + + def index(self, value): + return self._values.index(value) + + def count(self, value): + return self._values.count(value) + + class SimpleModel(nn.Module): def __init__(self, device): super().__init__() @@ -95,6 +142,10 @@ def extract_graph(fx_g, _, graph_cell): partition_fn=min_cut_rematerialization_partition, ) +device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" +) + def _apply_sharding(mod: nn.Module, shard_dim: int, device_mesh: DeviceMesh): """ @@ -141,7 +192,7 @@ def tearDown(self): @property def device_type(self) -> str: - return "cuda" if TEST_CUDA else "hpu" if TEST_HPU else "cpu" + return device_type @property def world_size(self) -> int: @@ -160,9 +211,9 @@ def fn(x): res = fn(x) res.to_local().sum().backward() - @unittest.skipIf(not TEST_CUDA, "CUDA not available") + @unittest.skipIf(not torch.accelerator.is_available(), "accelerator not available") def test_dtensor_basic_export(self): - mesh = DeviceMesh("cuda", torch.arange(self.world_size)) + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) param = 
torch.randn(4, 4) param_x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False) @@ -188,10 +239,10 @@ def forward(self, x): ) self.assertExpectedInline( str(ep.graph_module.code).strip(), - """\ + f"""\ def forward(self, b_buffer, x): _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None - to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None + to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}')); x = None view_as = torch.ops.aten.view_as.default(to, to); to = None dtensor___init__0 = self.dtensor___init__0 dtensor_const_func_spec0 = self.dtensor_const_func_spec0 @@ -206,10 +257,10 @@ def forward(self, b_buffer, x): # add is performed in _propagate_tensor_meta_non_cached, hence add_1 instead of add self.assertExpectedInline( str(ep.run_decompositions({}).graph_module.code).strip(), - """\ + f"""\ def forward(self, b_parametrizations_buffer_original0, x): _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None - _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None + _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='{self.device_type}', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None @@ -377,6 +428,7 @@ def fn(x): self.assertEqual(res, ref) @skipIfHpu + 
@skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/1981") def test_dtensor_dynamic_loss_parallel_log_softmax(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -763,6 +815,37 @@ def fn(x): # this fails with an inductor stride assert out_dt.to_local().sum().backward() + def test_dynamo_to_local_grad_placements_sequence(self): + placements = PytreeTuple([Shard(0)]) + + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + def fn(x): + return dt.to_local(grad_placements=placements) + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + + out_ref = fn(dt) + out_test = fn_opt(dt) + self.assertEqual(out_ref, out_test) + + def test_dynamo_to_local_grad_placements_sequence_intermediate(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + def fn(x): + placements = PytreeTuple([Shard(0)]) + return dt.to_local(grad_placements=placements) + 2 + + fn_opt = torch.compile(fn, backend="aot_eager", fullgraph=True) + x = torch.ones(4) + dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False) + + out_ref = fn(dt) + out_test = fn_opt(dt) + self.assertEqual(out_ref, out_test) + def test_dynamo_to_local_kwargs(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) @@ -815,13 +898,13 @@ def fn(x, y, z): out = layer_norm.permute(0, 2, 1) return out - x = torch.randn(4, 2, 4, requires_grad=True, device="cuda") + x = torch.randn(4, 2, 4, requires_grad=True, device=self.device_type) x_dt = DTensor.from_local(x, mesh, [Shard(1)], run_check=False) - y = torch.randn(4, requires_grad=True, device="cuda") + y = torch.randn(4, requires_grad=True, device=self.device_type) y_dt = DTensor.from_local(y, mesh, [Replicate()], run_check=False) - z = torch.randn(4, requires_grad=True, device="cuda") + z = torch.randn(4, requires_grad=True, device=self.device_type) z_dt = DTensor.from_local(z, 
mesh, [Replicate()], run_check=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) @@ -919,7 +1002,7 @@ def test_dtensor_dynamo_device_mesh_attrs(self): # pass in tensor as inputs/outputs, create DTensor and run redistribute # (allgather collective) inside the fn def fn(x_dt): - if x_dt.device_mesh.device_type == "cuda": + if x_dt.device_mesh.device_type == f"{self.device_type}": return x_dt + 1 else: return x_dt + 2 @@ -1051,7 +1134,7 @@ def forward(self, input): model = FakeTransformer().to(self.device_type) - tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) + tp_mesh = init_device_mesh(self.device_type, (2,), mesh_dim_names=("tp",)) # apply sequence parallel parallel_plan = { diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index 139f5fb61fac8..72d95efcfa8c9 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -34,7 +34,11 @@ register_op_strategy, replicate_op_strategy, ) -from torch.distributed.tensor.debug import CommDebugMode +from torch.distributed.tensor.debug import ( + _clear_fast_path_sharding_prop_cache, + _clear_python_sharding_prop_cache, + CommDebugMode, +) from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -479,7 +483,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None): del propagator.op_to_schema_info[op_overload] else: propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema - propagator.propagate_op_sharding.cache.cache_clear() + _clear_fast_path_sharding_prop_cache() + _clear_python_sharding_prop_cache() def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool: @@ -645,6 +650,28 @@ def test_call_with_different_nontensor_args(self): self.assertEqual(out1.full_tensor(), out2.full_tensor()) +class 
TestStrategyOperation(DTensorTestBase): + @property + def world_size(self): + return 2 + + @with_comms + def test_cache_clean(self): + mesh = self.build_device_mesh() + test_op = torch.ops.mylib.numpy_sin + x = torch.randn(2, device=self.device_type) + y = torch.randn(2, device=self.device_type) + x_dt = distribute_tensor(x, mesh, [Shard(0)]) + y_dt = distribute_tensor(y, mesh, [Shard(0)]) + with op_strategy_context(test_op.default, replicate_op_strategy): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + with self.assertRaisesRegex( + NotImplementedError, + f"Operator {test_op.default} does not have a sharding strategy registered", + ): + self._test_op_on_dtensor(test_op, x_dt, y_dt) + + DistTensorReplicateStrategyRegistrationTestWithLocalTensor = ( create_local_tensor_test_class( DistTensorReplicateStrategyRegistrationTest, diff --git a/test/distributed/tensor/test_random_ops.py b/test/distributed/tensor/test_random_ops.py index 61b88ee169e2e..4ff470511f2ad 100644 --- a/test/distributed/tensor/test_random_ops.py +++ b/test/distributed/tensor/test_random_ops.py @@ -6,8 +6,8 @@ import torch import torch.distributed._functional_collectives as funcol import torch.distributed.tensor._random as random +from torch.distributed._local_tensor import LocalTensor, maybe_run_for_local_tensor from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.distributed_c10d import broadcast_object_list from torch.distributed.fsdp import fully_shard from torch.distributed.tensor import ( DeviceMesh, @@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, skip_if_lt_x_gpu, skip_unless_torch_gpu, @@ -34,9 +35,12 @@ from torch.utils._typing_utils import not_none -def get_generator_seed_for_device_type(device_type: str) -> int: - 
device_module = torch.get_device_module(device_type) - return device_module.get_rng_state()[:8].view(torch.int64).item() +def get_generator_seed_for_device_type(device_type: str): + from torch.distributed._local_tensor import ( + get_generator_seed_for_device_type as _get_seed, + ) + + return _get_seed(device_type) class DistTensorRandomInitTest(DTensorTestBase): @@ -134,9 +138,6 @@ def test_meta_tensor_init(self): torch.empty(*size, device="meta"), device_mesh, [Replicate()] ) - # the tensor slice on the current rank - self_slice = slice(1024 * self.rank, 1024 * self.rank + 1024) - # Test 1: enable the distribute region for RNG (by default) self.assertTrue(meta_dtensor.is_meta) # Tensor meta init @@ -150,16 +151,23 @@ def test_meta_tensor_init(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - for other_rank in range(self.world_size): - # the RNG result on each rank are the same because they're replicated - if self.rank != other_rank: - # other rank should have an identical local tensor - other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) - self.assertEqual( - gathered_local_tensors[self_slice, :], - gathered_local_tensors[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(gathered_local_tensors, rank): + # the tensor slice on the current rank + self_slice = slice(1024 * rank, 1024 * rank + 1024) + + # compare with local tensors from other ranks + for other_rank in range(self.world_size): + # the RNG result on each rank are the same because they're replicated + if rank != other_rank: + # other rank should have an identical local tensor + other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) + self.assertEqual( + gathered_local_tensors[self_slice, :], + gathered_local_tensors[other_slice, :], + ) + + compute_rankwise_if_local_tensor(gathered_local_tensors.wait(), self.rank) # Test 2: disable the distribute region for RNG 
self.assertTrue(meta_dtensor.is_meta) @@ -175,15 +183,7 @@ def test_meta_tensor_init(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - for other_rank in range(self.world_size): - # the RNG result on each rank are the same even without the help of DTensor's RNG infra, - # since the default RNG is the same across ranks. - if self.rank != other_rank: - other_slice = slice(1024 * other_rank, 1024 * other_rank + 1024) - self.assertEqual( - local_tensor[self_slice, :], local_tensor[other_slice, :] - ) + compute_rankwise_if_local_tensor(local_tensor.wait(), self.rank) @with_comms @skip_unless_torch_gpu @@ -224,13 +224,17 @@ def test_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - for other_rank in range(self.world_size): - if self.rank != other_rank: - self.assertNotEqual( - weight_local, - weight_gather[other_rank : other_rank + 1, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, weight_gather, rank): + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if rank != other_rank: + self.assertNotEqual( + weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + compute_rankwise_if_local_tensor(weight_local, weight_gather.wait(), self.rank) @with_comms @skip_if_lt_x_gpu(4) @@ -277,13 +281,17 @@ def test_fsdp_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - for other_rank in range(self.world_size): - if self.rank != other_rank: - self.assertNotEqual( - weight_local, - weight_gather[other_rank : other_rank + 1, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, weight_gather, rank): + # verify the weights are initialized differently on all ranks + for other_rank in range(self.world_size): + if rank != other_rank: + self.assertNotEqual( + 
weight_local, + weight_gather[other_rank : other_rank + 1, :], + ) + + compute_rankwise_if_local_tensor(weight_local, weight_gather.wait(), self.rank) class DistTensorRandomOpTest(DTensorTestBase): @@ -291,9 +299,14 @@ class DistTensorRandomOpTest(DTensorTestBase): @skip_unless_torch_gpu def test_rng_tracker_init(self): torch.manual_seed(self.rank) - object_list = [torch.initial_seed()] - broadcast_object_list(object_list) - seed_from_rank_0 = int(object_list[0]) + seed_local = ( + torch.zeros_like(torch.empty(1), device=self.device_type) + + torch.initial_seed() + ) + torch.distributed.broadcast(seed_local, src=0) + # if localtensor, it should automatically reconcile after the broadcast + # since all virtual ranks should have rank 0's initial_seed() + seed_from_rank_0 = seed_local device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) # seed synchronization now does NOT happen after the first `distribute_tensor` @@ -344,15 +357,19 @@ def test_manual_seed(self): @with_comms @skip_unless_torch_gpu def test_manual_seed_submesh(self): - # the current rank is not a part of the mesh - single_rank_device_mesh = DeviceMesh( - self.device_type, [(self.rank + 1) % self.world_size] - ) - with self.assertRaisesRegex( - RuntimeError, - "manual_seed requires the current rank to be a part of the device mesh", - ): - manual_seed(self.rank, single_rank_device_mesh) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(rank): + # the current rank is not a part of the mesh + single_rank_device_mesh = DeviceMesh( + self.device_type, [(rank + 1) % self.world_size], _rank=rank + ) + with self.assertRaisesRegex( + RuntimeError, + "manual_seed requires the current rank to be a part of the device mesh", + ): + manual_seed(rank, single_rank_device_mesh) + + compute_rankwise_if_local_tensor(self.rank) @with_comms @skip_unless_torch_gpu @@ -394,7 +411,7 @@ def test_pipeline_parallel_manual_seed(self): for other_rank in range(self.world_size): if self.rank != 
other_rank: self.assertNotEqual( - spmd_dtensor.to_local(), + spmd_dtensor, tensor_gather[2 * other_rank : 2 * (other_rank + 1), :], ) @@ -428,16 +445,20 @@ def test_deterministic_dropout_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have an identical local tensor - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) @with_comms @skip_unless_torch_gpu @@ -454,16 +475,20 @@ def test_deterministic_rand_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have a different local tensor for shard placement - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertNotEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have a different local tensor 
for shard placement + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertNotEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) # we should set manual seed to the same value on all SPMD ranks torch.manual_seed(0) @@ -472,16 +497,20 @@ def test_deterministic_rand_1d(self): dtensor.to_local(), gather_dim=0, group=(device_mesh, 0) ) - # compare with local tensors from other ranks - self_slice = slice(4 * self.rank, 4 * self.rank + 4) - for other_rank in range(self.world_size): - if self.rank != other_rank: - # other rank should have an identical local tensor for replicate placement - other_slice = slice(4 * other_rank, 4 * other_rank + 4) - self.assertEqual( - local_tensor[self_slice, :], - local_tensor[other_slice, :], - ) + @maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(local_tensor, rank): + # compare with local tensors from other ranks + self_slice = slice(4 * rank, 4 * rank + 4) + for other_rank in range(self.world_size): + if rank != other_rank: + # other rank should have an identical local tensor for replicate placement + other_slice = slice(4 * other_rank, 4 * other_rank + 4) + self.assertEqual( + local_tensor[self_slice, :], + local_tensor[other_slice, :], + ) + + compute_rankwise_if_local_tensor(local_tensor, self.rank) @with_comms @skip_if_lt_x_gpu(4) @@ -539,7 +568,12 @@ def test_deterministic_uniform_2d(self): shard_linear_idx = random._rng_tracker._calc_shard_linear_idx( shard_coord, shard_size ) - self.assertEqual(shard_linear_idx, shard_index[self.rank]) + + @maybe_run_for_local_tensor + def check_shard_index(shard_linear_idx, rank): + self.assertEqual(shard_linear_idx, shard_index[rank]) + + check_shard_index(shard_linear_idx, self.rank) # compute local size and offset _, local_shard_offset = compute_local_shape_and_global_offset( @@ -578,16 +612,27 @@ def test_deterministic_uniform_2d(self): # allgather the local tensors 
full_tensor = dtensor.full_tensor() - # compare local tensor with each other shard - for other_local_shard in local_shard_comb: - other_local_shard_offset, _ = zip(*other_local_shard) - slice_idx = [ - slice(offset, offset + size) for offset, size in other_local_shard - ] - if local_shard_offset == other_local_shard_offset: - self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) - else: - self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) + full_tensor = ( + full_tensor.reconcile() + if isinstance(full_tensor, LocalTensor) + else full_tensor + ) + + @maybe_run_for_local_tensor + def blockwise_iter_if_localtensor(local_tensor, local_shard_offset): + # compare local tensor with each other shard + for other_local_shard in local_shard_comb: + other_local_shard_offset, _ = zip(*other_local_shard) + slice_idx = [ + slice(offset, offset + size) + for offset, size in other_local_shard + ] + if local_shard_offset == other_local_shard_offset: + self.assertEqual(full_tensor[tuple(slice_idx)], local_tensor) + else: + self.assertNotEqual(full_tensor[tuple(slice_idx)], local_tensor) + + blockwise_iter_if_localtensor(local_tensor, local_shard_offset) class DistTensorRandomOpsTest3D(DTensorTestBase): @@ -641,22 +686,46 @@ def test_hsdp_tp_model_meta_init(self): group=WORLD, ) - # verify the weights are initialized differently on all ranks - shard_dim_0_len = self.world_size // 4 - for other_rank in range(self.world_size): - other_rank_dim_0_start = other_rank * shard_dim_0_len - other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len - if self.rank % 4 != other_rank % 4: - self.assertNotEqual( - weight_local, - weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], - ) - else: - self.assertEqual( - weight_local, - weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], - ) + weight_gather = weight_gather.wait() + + weight_gather = ( + weight_gather.reconcile() + if isinstance(weight_gather, LocalTensor) + else weight_gather + ) + 
@maybe_run_for_local_tensor + def compute_rankwise_if_local_tensor(weight_local, rank): + # verify the weights are initialized differently on all ranks + shard_dim_0_len = self.world_size // 4 + for other_rank in range(self.world_size): + other_rank_dim_0_start = other_rank * shard_dim_0_len + other_rank_dim_0_end = other_rank_dim_0_start + shard_dim_0_len + if rank % 4 != other_rank % 4: + self.assertNotEqual( + weight_local, + weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], + ) + else: + self.assertEqual( + weight_local, + weight_gather[other_rank_dim_0_start:other_rank_dim_0_end, :], + ) + + compute_rankwise_if_local_tensor(weight_local, self.rank) + + +DistTensorRandomInitTestWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomInitTest, +) + +DistTensorRandomOpTestWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomOpTest, +) + +DistTensorRandomOpsTest3DWithLocalTensor = create_local_tensor_test_class( + DistTensorRandomOpsTest3D, +) if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index 381660e47927d..86bb567a39616 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -27,8 +27,6 @@ instantiate_parametrized_tests, parametrize, run_tests, - TEST_CUDA, - TEST_HPU, ) from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -541,7 +539,7 @@ def test_redistribute_shard_dim_change(self, dtype): local_out_dt = out_dt.to_local() local_expected_dt = expected_dt.to_local() self.assertEqual(out_dt.to_local(), expected_dt.to_local()) - if TEST_HPU or TEST_CUDA: + if torch.accelerator.is_available(): self.assertEqual( comm_mode.get_comm_counts()[ torch.ops._dtensor.shard_dim_alltoall diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 80968fb52e904..4748db4f7377b 100644 
--- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -296,8 +296,8 @@ def test_zeros_like(self): self.assertEqual(dist_tensor.dtype, torch.float32) self.assertEqual(zeros_like_dt.dtype, torch.bfloat16) - @with_comms @skip_if_lt_x_gpu(4) + @with_comms def test_stack(self): mesh_2d = DeviceMesh( self.device_type, torch.arange(self.world_size).reshape(2, 2) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 426f77e379f8f..a60d3868e4f82 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -30,7 +30,7 @@ from torch.testing._internal.inductor_utils import HAS_GPU -def estimate_aten_runtime(fx_node, compute_multiplier=1.0): +def estimate_aten_runtime(fx_node, override_size=None, compute_multiplier=1.0): # for tests, assume a matmul can hide a single collective if "c10" in str(fx_node.target): return 1.0 @@ -1112,7 +1112,7 @@ def test_multiple_hiding_nodes_bucketing(self): # Use 0.5 compute multiplier so each collective needs 2 matmuls to be fully hidden def estimate_with_half_compute(fx_node, override_size=None): - return estimate_aten_runtime(fx_node, compute_multiplier=0.5) + return estimate_aten_runtime(fx_node, override_size, compute_multiplier=0.5) def func(a, b, *, ranks): # Two all_gathers that will be hidden by multiple compute operations @@ -1162,6 +1162,56 @@ def func(a, b, *, ranks): correct = func(a, b, ranks=ranks) self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch(get_bucket_patches()) + def test_bucketing_with_convert_dtype(self): + """Test that all_gathers with dtype conversion get bucketed and produce correct results.""" + + def func(a, b, c, d, *, ranks): + # Convert inputs to float16 before all_gather + a_fp16 = a.to(torch.float16) + b_fp16 = 
b.to(torch.float16) + + # Two all_gathers with converted dtypes + ag1 = _functional_collectives.all_gather_tensor(a_fp16, 0, ranks) + ag2 = _functional_collectives.all_gather_tensor(b_fp16, 0, ranks) + + # same dtype + ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) + ag4 = _functional_collectives.all_gather_tensor(d, 0, ranks) + + return ag1, ag2, ag3, ag4 + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + a = torch.ones(4, 4, dtype=torch.float32, device=device_type) + b = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 2 + c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3 + d = torch.ones(4, 4, dtype=torch.float64, device=device_type) * 4 + ranks = list(range(self.world_size)) + + func_c = functools.partial(func, ranks=ranks) + compiled = torch.compile(func_c) + out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c, d) + + # Should have 1 bucketed all_gather (both ag1 and ag2 bucketed together) + FileCheck().check_count( + "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True + ).run(aten_graph_str) + + # Verify convert_element_type ops are removed (dtype conversion handled by _pre_bucket_all_gather) + FileCheck().check_not("torch.ops.prims.convert_element_type").run( + aten_graph_str + ) + + # Verify correctness - this tests that dtype conversion is handled correctly + correct = func(a, b, c, d, ranks=ranks) + self.assertTrue(same(out, correct)) + def get_toy_model(device_type: str): """ diff --git a/test/distributed/test_debug.py b/test/distributed/test_debug.py new file mode 100644 index 0000000000000..ff6a203bcf160 --- /dev/null +++ b/test/distributed/test_debug.py @@ -0,0 +1,56 @@ +# Owner(s): ["oncall: distributed"] + +import os + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +import torch +import torch.distributed as dist +from torch.distributed.debug 
import start_debug_server, stop_debug_server +from torch.testing._internal.common_utils import run_tests, TestCase + + +session = requests.Session() +retry_strategy = Retry(total=5, backoff_factor=0.5) +adapter = HTTPAdapter(max_retries=retry_strategy) +session.mount("http://", adapter) +session.mount("https://", adapter) + + +class TestDebug(TestCase): + def test_basics(self) -> None: + store = dist.TCPStore("localhost", 0, 1, is_master=True, wait_for_workers=False) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(store.port) + os.environ["RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" + + port = 25999 + + def fetch(path: str) -> str: + resp = session.get(f"http://localhost:{port}{path}") + resp.raise_for_status() + return resp.text + + start_debug_server(port=port) + + self.assertIn("torch profiler", fetch("/")) + self.assertIn("View 0", fetch("/profile?duration=0.01")) + self.assertIn("test_basics", fetch("/stacks")) + self.assertIn("pg_status", fetch("/fr_trace")) + + if torch.cuda.is_available(): + self.assertIn("pg_status", fetch("/fr_trace_nccl")) + + # test errors + resp = session.get(f"http://localhost:{port}/blah") + self.assertEqual(resp.status_code, 404) + self.assertIn("Handler not found: /blah", resp.text) + + stop_debug_server() + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/test_nvshmem_triton.py b/test/distributed/test_nvshmem_triton.py index 3fec9a01f049c..ad30a7df5d43a 100644 --- a/test/distributed/test_nvshmem_triton.py +++ b/test/distributed/test_nvshmem_triton.py @@ -12,7 +12,6 @@ import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem from torch._inductor.runtime.triton_compat import triton from torch.distributed._symmetric_memory._nvshmem_triton import requires_nvshmem -from torch.testing._internal.common_cuda import SM100OrLater from torch.testing._internal.common_distributed import MultiProcContinuousTest from torch.testing._internal.common_utils import ( 
instantiate_parametrized_tests, @@ -265,10 +264,6 @@ def my_reduce_kernel( nvshmem.reduce(team_handle, dest_tensor, source_tensor, nreduce, operation) -@skip_but_pass_in_sandcastle_if( - SM100OrLater, - "Skipping all NVSHMEM Triton tests due to https://github.com/pytorch/pytorch/issues/162897", -) @instantiate_parametrized_tests class NVSHMEMTritonTest(MultiProcContinuousTest): def _init_device(self) -> None: diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py index de6f2ba612977..c0c4c31cc1a81 100644 --- a/test/distributed/test_overlap_bucketing_unit.py +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -667,6 +667,94 @@ def func(a, b): str(traced.graph) ) + def test_can_bucket_with_convert_dtype_as_hiding_nodes(self): + """ + Test that all_gathers can bucket when convert_element_type ops ARE the hiding nodes. + + Graph structure: + ag1_start -> convert1 (hides ag1) -> ag1_wait -> ag2_start -> convert2 (hides ag2) -> ag2_wait + + The convert_element_type ops ARE hiding nodes - no matmuls. + This tests that dependencies are transferred correctly when convert nodes are erased. 
+ """ + + def func(a, b, c): + group_name = "0" + group_size = 1 + + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, group_name + ) + b = torch.ops.prims.convert_element_type.default(b, torch.float16) + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, group_size, group_name + ) + ag3 = torch.ops._c10d_functional.all_gather_into_tensor( + c, group_size, group_name + ) + + mm = ag1_out @ ag1_out + + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + ag3_out = torch.ops._c10d_functional.wait_tensor(ag3) + + return ag1_out, ag2_out, ag3_out, mm + + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device, dtype=torch.float32) + b = torch.ones(4, 4, device=self.device, dtype=torch.float32) + c = torch.ones(4, 4, device=self.device, dtype=torch.float32) + + traced = make_fx(func)(a, b, c) + + # Find nodes + ag1, ag2, ag3 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + convert1 = traced.graph.find_nodes( + op="call_function", + target=torch.ops.prims.convert_element_type.default, + )[0] + mm = traced.graph.find_nodes( + op="call_function", + target=torch.ops.aten.mm.default, + )[0] + + hiding_annotations = { + ag1: convert1, + ag2: mm, + ag3: mm, + } + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + bucketer.bucket_collectives() + + graph_str = str(traced.graph) + + f = FileCheck() + f.check_count("%all_gather_into_tensor", 1, exactly=True) + 
f.check("pre_bucket_all_gather").check("wait_tensor").check( + "%all_gather_into_tensor_out" + ).run(graph_str) + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 3b4aff724eee4..967bedb9ebaae 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -585,6 +585,10 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): # Annotation: {'stream': 1} mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(2, 1); record_event_default = None + wait_event_default = torch.ops.streams.wait_event.default(2, 0); wait_event_default = None + # Annotation: {'stream': 0} add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None return (add_3, add_2) diff --git a/test/export/test_converter.py b/test/export/test_converter.py index e739e5c346677..5b608503a1168 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1405,7 +1405,7 @@ def func3(x): # noqa: F841 ) # qnnpack not supported on s390x @xfailIfS390X - def test_ts2ep_convert_quantized_model(self): + def test_ts2ep_convert_quantized_model1(self): class Standalone(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 9cf442c27a2bb..866eeaaee3986 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -640,16 +640,13 @@ def forward(self, x): self.assertExpectedInline( without_token_ep.graph_module.code.strip(), """\ -def forward(self, token, obj_attr, x): - with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None - getitem = with_effects[0] - getitem_1 = with_effects[1] - getitem_2 = with_effects[2]; with_effects = None +def forward(self, obj_attr, 
x): + takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(foo = obj_attr, x = x); x = None + getitem_1 = takes_foo_tuple_return_default[0] + getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None - with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None - getitem_3 = with_effects_1[0] - getitem_4 = with_effects_1[1]; with_effects_1 = None - return (getitem_3, getitem_4)""", # noqa: B950 + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None + return (takes_foo_default,)""", # noqa: B950 ) def test_fakify_script_objects(self): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 246122433e06c..adf0986811648 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -461,9 +461,9 @@ def forward(self, x): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) attr = self.attr _guards_fn = self._guards_fn(x); _guards_fn = None - takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) - takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None - add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None + takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) + takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None + add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950 ) self.assertExpectedInline( @@ -1087,10 +1087,12 @@ def forward(self, token, tq, x): 
str(ep.graph_module.graph).strip(), """\ graph(): + %token : [num_users=1] = placeholder[target=token] %tq : [num_users=2] = placeholder[target=tq] %x : [num_users=1] = placeholder[target=x] - %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {}) - return (tq,)""", # noqa: B950 + %with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {}) + %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {}) + return (getitem, tq)""", # noqa: B950 ) def test_deepcopy(self): diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 2c4cf02bc1c8a..e995959afba47 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -870,6 +870,100 @@ def forward(self, primals_2, getitem_1, tangents_1, tangents_token): finally: handle.destroy() + @unittest.skipIf(not TEST_CUDA, "triton") + def test_export_invoke_subgraph(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + recorded_list = [] + + @torch.library.custom_op("mylib::record_memory", mutates_args=()) + def record_memory(prefix: str, module_name: str) -> None: + torch.cuda.synchronize() + mem_alloc = torch.cuda.memory_allocated() / 1024**2 + mem_reserved = torch.cuda.memory_reserved() / 1024**2 + memory_str = f"[{prefix}] {module_name}: allocated={mem_alloc:.2f} MB, reserved={mem_reserved:.2f} MB" + recorded_list.append(memory_str) + + @record_memory.register_fake + def record_memory_fake(prefix, module_name): + return + + record_memory.register_effect(_EffectType.ORDERED) + + class N(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(1024, 1024) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(1024, 1024) + + 
@torch.compiler.nested_compile_region + def forward(self, x): + torch.ops.mylib.record_memory("forward", "N") + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod_list = torch.nn.ModuleList(N() for _ in range(3)) + + def forward(self, x): + for m in self.mod_list: + x = m(x) + torch.ops.mylib.record_memory("forward", "N") + return (x,) + + model = M().to("cuda") + torch.cuda.reset_peak_memory_stats() + + x = torch.randn(32, 1024, requires_grad=True, device="cuda") + + ep = torch.export.export(model, (x,)) + ep = ep.run_decompositions() + self.assertEqual(len(list(ep.graph_module.named_modules())), 2) + + self.assertExpectedInline( + ep.graph_module.code.strip(), + """\ +def forward(self, token, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias, x): + repeated_subgraph0 = self.repeated_subgraph0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', token, x, p_mod_list_0_linear1_weight, p_mod_list_0_linear1_bias, p_mod_list_0_linear2_weight, p_mod_list_0_linear2_bias); repeated_subgraph0 = token = x = p_mod_list_0_linear1_weight = p_mod_list_0_linear1_bias = p_mod_list_0_linear2_weight = p_mod_list_0_linear2_bias = None + getitem = invoke_subgraph[0] + getitem_1 = invoke_subgraph[1]; invoke_subgraph = None + repeated_subgraph0_1 = self.repeated_subgraph0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, getitem_1, p_mod_list_1_linear1_weight, p_mod_list_1_linear1_bias, p_mod_list_1_linear2_weight, p_mod_list_1_linear2_bias); repeated_subgraph0_1 = getitem = getitem_1 = 
p_mod_list_1_linear1_weight = p_mod_list_1_linear1_bias = p_mod_list_1_linear2_weight = p_mod_list_1_linear2_bias = None + getitem_2 = invoke_subgraph_1[0] + getitem_3 = invoke_subgraph_1[1]; invoke_subgraph_1 = None + repeated_subgraph0_2 = self.repeated_subgraph0 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, 'subgraph_0', getitem_2, getitem_3, p_mod_list_2_linear1_weight, p_mod_list_2_linear1_bias, p_mod_list_2_linear2_weight, p_mod_list_2_linear2_bias); repeated_subgraph0_2 = getitem_2 = getitem_3 = p_mod_list_2_linear1_weight = p_mod_list_2_linear1_bias = p_mod_list_2_linear2_weight = p_mod_list_2_linear2_bias = None + getitem_4 = invoke_subgraph_2[0] + getitem_5 = invoke_subgraph_2[1]; invoke_subgraph_2 = None + with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None + getitem_6 = with_effects[0]; with_effects = None + return (getitem_6, getitem_5)""", + ) + + self.assertExpectedInline( + ep.graph_module.repeated_subgraph0.code.strip(), + """\ +def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): + with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.mylib.record_memory.default, 'forward', 'N'); arg0_1 = None + getitem = with_effects[0]; with_effects = None + permute = torch.ops.aten.permute.default(arg2_1, [1, 0]); arg2_1 = None + addmm = torch.ops.aten.addmm.default(arg3_1, arg1_1, permute); arg3_1 = arg1_1 = permute = None + relu = torch.ops.aten.relu.default(addmm); addmm = None + permute_1 = torch.ops.aten.permute.default(arg4_1, [1, 0]); arg4_1 = None + addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, permute_1); arg5_1 = relu = permute_1 = None + return (getitem, addmm_1)""", + ) + + recorded_list.clear() + out2 = ep.module()(x) + self.assertEqual(len(recorded_list), 4) + self.assertTrue(torch.allclose(model(x)[0], out2[0])) + if __name__ == "__main__": run_tests() diff --git 
a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 5f0447c32264e..69f5eb92b58ce 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -7437,6 +7437,50 @@ def forward(self, x): "RAIIAtenTensorHandle buf0(buf0_handle_restrided);" ).run(code) + def test_codegen_int_array_var_fix_memory_leak(self): + """ + Fix https://github.com/pytorch/pytorch/issues/167630 + """ + if self.device != "cuda": + raise unittest.SkipTest("test is only for cuda") + + def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): + layers = [] + d = in_dim + for _ in range(depth): + layers += [nn.Linear(d, hidden), nn.ReLU()] + d = hidden + layers += [nn.Linear(d, out_dim)] + return nn.Sequential(*layers) + + batch = 32 + in_dim = 2048 + hidden = 512 + out_dim = 10 + depth = 6 + + import gc + + allocated_memory = [] + for _ in range(3): + torch.cuda.reset_peak_memory_stats() + + model = make_mlp(in_dim, hidden, out_dim, depth).to(self.device) + example_inputs = (torch.randn(batch, in_dim, device=self.device),) + ep = torch.export.export( + model, + example_inputs, + ) + torch._inductor.aoti_compile_and_package(ep) + + del model, example_inputs, ep + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + allocated_memory.append(torch.cuda.memory_allocated()) + + self.assertTrue(allocated_memory[1] == allocated_memory[2]) + @unittest.skipIf(IS_MACOS, "might have no readelf on Mac") def test_libtorch_free_so(self): class Model(torch.nn.Module): diff --git a/test/inductor/test_multi_kernel.py b/test/inductor/test_multi_kernel.py index 55f54756913db..04799a506b42a 100644 --- a/test/inductor/test_multi_kernel.py +++ b/test/inductor/test_multi_kernel.py @@ -16,7 +16,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - skipIfRocm, skipIfXpu, ) from torch.testing._internal.inductor_utils import ( @@ -108,8 +107,6 @@ def test_softmax(self, expect_multi_kernel=True): 
self.assertFalse(_contains_multi_kernel_code(wrapper_code)) @requires_triton() - # TODO: bobrenjc93 to fix multi-kernel for ROCM - @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2295") def test_triton_gemm(self): @@ -133,13 +130,14 @@ def fn(x, y): # One for the first pass and one for the second pass. # We mainly care about the wrapper for the final pass here. wrapper_code = wrapper_code[-1] - self.assertEqual(ref, act) + if torch.version.hip: + self.assertEqual(ref, act, atol=1e-3, rtol=1e-3) + else: + self.assertEqual(ref, act) self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code)) @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2295") @requires_triton() - # TODO: bobrenjc93 to fix multi-kernel for ROCM - @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_triton_relu_fused_gemm(self): def fn(x, y): @@ -162,7 +160,11 @@ def fn(x, y): # One for the first pass and one for the second pass. # We mainly care about the wrapper for the final pass here. 
wrapper_code = wrapper_code[-1] - self.assertEqual(ref, act) + if torch.version.hip: + self.assertEqual(ref, act, atol=1e-3, rtol=1e-3) + else: + self.assertEqual(ref, act) + self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code)) @parametrize("force_kernel", (0, 1)) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 1c9b39a1bd08d..d1b62feed3b41 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -828,9 +828,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.fractional_max_pool3d": {f16, f32, f64}, "nn.functional.group_norm": {f16}, "nn.functional.hinge_embedding_loss": {f16}, - # Enabling all tests for this test fails randomly - # See https://github.com/pytorch/pytorch/issues/129238 - "nn.functional.huber_loss": {f16}, "nn.functional.interpolate.bicubic": {f16}, "nn.functional.interpolate.bilinear": {f16}, "nn.functional.interpolate.trilinear": {f16}, @@ -948,9 +945,6 @@ def wrapper_noop_set_seed(op, *args, **kwargs): "nn.functional.fractional_max_pool3d": {f16, f32, f64}, "nn.functional.group_norm": {f16}, "nn.functional.hinge_embedding_loss": {f16}, - # Enabling all tests for this test fails randomly - # See https://github.com/pytorch/pytorch/issues/129238 - "nn.functional.huber_loss": {f16}, "nn.functional.interpolate.bicubic": {f16}, "nn.functional.interpolate.bilinear": {f16}, "nn.functional.interpolate.trilinear": {f16}, diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 6284be2aebe9e..522a82cf9a222 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -357,6 +357,9 @@ def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op): @unittest.skipIf(not TEST_NUMPY, 'NumPy not found') @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND], allowed_dtypes=(torch.cfloat, torch.cdouble)) + @toleranceOverride({ + torch.cfloat : tol(2e-4, 1.3e-6), + 
}) def test_reference_nd(self, device, dtype, op): if op.ref is None: raise unittest.SkipTest("No reference implementation") diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 4796153f24f05..e1a518aca6704 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -421,7 +421,7 @@ # inplace or out-variants) # If the function does not modify its arguments, we also check the following properties # pertaining to its output: -# 2) Its TensorImpl has use_count of 1 +# 2) Its TensorImpl has use_count of 1 (or 2 if it has a PyObject) # 3) If the function is a view function, it has the same StorageImpl as that of # the input it is aliased with. Otherwise, its StorageImpl has use_count of 1 # @@ -496,10 +496,10 @@ """ ) -ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate( +ENFORCE_TENSOR_IMPL_USE_COUNT = CodeTemplate( """\ if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) - TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}"); + TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() == expected_fresh_use_count(${tensor_name}), "function: ${fn_name}"); """ ) @@ -1664,7 +1664,7 @@ def check_tensorimpl_and_storage( if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT: stmts_after_call += [ - ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute( + ENFORCE_TENSOR_IMPL_USE_COUNT.substitute( tensor_name=ret_name, fn_name=type_wrapper_name(f) ) ] diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 23976a48473a3..d1de108283b11 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -47,6 +47,18 @@ namespace{ meta->grad_accumulator_.reset(); } } +[[maybe_unused]] size_t expected_fresh_use_count(const Variable& self) { + if (!self.defined()) { + // An UndefinedTensorImpl always has a use count of 0 + return 0; + } + if 
(self.unsafeGetTensorImpl()->pyobj_slot()->load_pyobj() != nullptr) { + // A TensorImpl with a Python object has a use count of 2 + return 2; + } + // A fresh TensorImpl (with no PyObject) has a use count of 1 + return 1; +} } namespace { diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 752bd594d066f..477b35b1811e4 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -100,7 +100,9 @@ class Logger: def _set_static_graph(self) -> None: ... class _WorkerServer: - def __init__(self, socket_path: str) -> None: ... + port: int + + def __init__(self, host_or_file: str, port: int = ...) -> None: ... def shutdown(self) -> None: ... def get_debug_level(): ... @@ -206,6 +208,7 @@ class Store: desired_value: str, ) -> bytes: ... def delete_key(self, key: str) -> bool: ... + def multi_get(self, keys: list[str]) -> list[bytes]: ... def num_keys(self) -> int: ... def set_timeout(self, timeout: timedelta): ... @overload @@ -872,3 +875,15 @@ class ProcessGroupXCCL(Backend): def _set_process_group(pg: ProcessGroup) -> None: ... def _current_process_group() -> ProcessGroup: ... + +class _Request: + def body(self) -> bytes: ... + def get_param(self, str) -> str: ... + +class _Response: + def set_content(self, content: str | bytes, content_type: str) -> None: ... + def set_status(self, status: int) -> None: ... + +def _register_handler( + name: str, handler: Callable[[_Request, _Response], None] +) -> None: ... diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index d60d89a6a4796..de12af50c1855 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -60,6 +60,7 @@ class _ExperimentalConfig: verbose: bool = ..., performance_events: list[str] = ..., enable_cuda_sync_events: bool = ..., + profile_all_threads: bool = ..., ) -> None: ... 
class ProfilerConfig: diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 63a72afa43a6d..1c6283e8a038f 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -201,11 +201,8 @@ def __post_init__(self, /) -> None: num_children = 0 else: assert callable(self._unflatten_func) - num_nodes = 1 - num_leaves = 0 - for child in self._children: - num_nodes += child.num_nodes - num_leaves += child.num_leaves + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) num_children = len(self._children) object.__setattr__(self, "num_nodes", num_nodes) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 326178ef00874..16fa0997c7f83 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1266,6 +1266,19 @@ def method_to_local(self, *args, **kwargs): tx = InstructionTranslator.current_tx() # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function # and rewrite args to have only proxyable args, then insert call_function + + grad_placements_vt = kwargs.get( + "grad_placements", ConstantVariable.create(None) + ) + if isinstance(grad_placements_vt, variables.UserDefinedObjectVariable): + # grad_placement is a sequence-like structure, iterate over the value + grad_placements_vt = variables.BuiltinVariable(tuple).call_function( + tx, [grad_placements_vt], {} + ) + + if kwargs.get("grad_placements") is not None: + kwargs["grad_placements"] = grad_placements_vt + args_as_value = [x.as_python_constant() for x in args] kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()} diff --git a/torch/_functorch/_aot_autograd/graph_capture.py b/torch/_functorch/_aot_autograd/graph_capture.py index b6ea08a802240..f17a516183975 100644 --- a/torch/_functorch/_aot_autograd/graph_capture.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ 
-33,7 +33,7 @@ handle_effect_tokens_fn, ) from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta -from .streams import assign_backward_streams +from .streams import assign_backward_streams, insert_backward_syncs from .utils import ( call_and_expect_output_descs, copy_fwd_metadata_to_bw_nodes, @@ -477,6 +477,8 @@ def aot_dispatch_autograd_graph( # After copying metadata, assign streams to gradient accumulation nodes assign_backward_streams(fx_g) + insert_backward_syncs(fx_g) + fx_g.graph.eliminate_dead_code() if not aot_config.disable_functionalization: # There should be *NO* mutating ops in the graph at this point. diff --git a/torch/_functorch/_aot_autograd/streams.py b/torch/_functorch/_aot_autograd/streams.py index f78a2c6cad1de..1b4f5ded051e3 100644 --- a/torch/_functorch/_aot_autograd/streams.py +++ b/torch/_functorch/_aot_autograd/streams.py @@ -3,6 +3,7 @@ import torch.fx import torch.fx.traceback from torch._dynamo.graph_utils import _get_flat_args +from torch._dynamo.variables.streams import get_current_stream, new_event Node: TypeAlias = torch.fx.Node @@ -12,6 +13,14 @@ def is_gradient_acc(node: Node) -> bool: return node.meta.get("is_gradient_acc", False) +def is_bwd_node(node: Node) -> bool: + return node.meta.get("partitioner_tag") == "is_backward" + + +def get_device(node: Node) -> torch.device: + return node.meta["val"].device + + def get_stream(node: Node) -> Optional[int]: maybe_annotation = node.meta.get("custom", None) if maybe_annotation is not None: @@ -20,6 +29,13 @@ def get_stream(node: Node) -> Optional[int]: return None +def get_stream_or_current_stream(node: Node) -> int: + ind = get_stream(node) + if ind is None: + ind = get_current_stream(get_device(node)) + return ind + + def set_stream(node: Node, ind: int) -> None: if "custom" in node.meta: node.meta["custom"].update({"stream": ind}) @@ -27,6 +43,36 @@ def set_stream(node: Node, ind: int) -> None: node.meta["custom"] = {"stream": ind} +def insert_sync( + 
graph: torch.fx.Graph, + consumer: Node, + producer: Node, + node_to_wait_event_ind: dict[Node, int], +) -> None: + if producer not in node_to_wait_event_ind: + node_to_wait_event_ind[producer] = new_event() + + with graph.inserting_after(producer): + node = graph.call_function( + torch.ops.streams.record_event.default, + ( + node_to_wait_event_ind[producer], + get_stream_or_current_stream(producer), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + with graph.inserting_before(consumer): + node = graph.call_function( + torch.ops.streams.wait_event.default, + ( + node_to_wait_event_ind[producer], + get_stream_or_current_stream(consumer), + ), + ) + node.meta["partitioner_tag"] = "must_be_in_backward" + + def assign_backward_streams(gm: torch.fx.GraphModule) -> None: """Assigns backward streams to gradient accumulation nodes""" @@ -51,3 +97,18 @@ def assign_backward_streams(gm: torch.fx.GraphModule) -> None: if ind is not None: set_stream(node, ind) break + + +def insert_backward_syncs(gm: torch.fx.GraphModule) -> None: + """Inserts stream syncs for backward nodes if consumer and producer are on different streams""" + node_to_wait_event_ind = {} + for node in gm.graph.nodes: + if is_bwd_node(node): + flat_args = _get_flat_args(node, {}) + cur_node_stream = get_stream(node) + + for arg in flat_args: + if is_bwd_node(arg): + arg_stream = get_stream(arg) + if arg_stream != cur_node_stream and get_device(arg).type != "cpu": + insert_sync(gm.graph, node, arg, node_to_wait_event_ind) diff --git a/torch/_guards.py b/torch/_guards.py index 32b796d71eea7..1bd32fc7f08ec 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -713,6 +713,9 @@ def __init__(self) -> None: self.lazy_bwd_cache: dict[ str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] ] = defaultdict(dict) + self.effects_cache: dict[ + str, set + ] = {} # Maps identifier -> set of effect types def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: 
self.dynamo_installed_submodules[fn_id].append(identifier) @@ -751,6 +754,21 @@ def get_lazy_bwd_entry( return self.lazy_bwd_cache[identifier].get(tangent_metadata, (None, None)) + def add_effects(self, identifier: str, effects: set) -> None: + """Store the effect types for a given invoke_subgraph identifier.""" + if prev_effects := self.effects_cache.get(identifier, None): + assert effects == prev_effects, ( + "Different number of effects were found for invoke_subgraph " + f"call with identifier {identifier}. \n" + f"Previously we had the following effects: {prev_effects}.\n" + f"But now we have: {effects}." + ) + self.effects_cache[identifier] = effects + + def get_effects(self, identifier: str) -> Optional[set]: + """Retrieve the effect types for a given invoke_subgraph identifier.""" + return self.effects_cache.get(identifier, None) + class HopDispatchSetCache: def __init__(self) -> None: diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index e22b741631d3f..bb0d6cef3ee6f 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -80,6 +80,7 @@ def __call__( assert all( isinstance(o, (torch.Tensor, int, torch.SymInt, torch.Generator)) for o in operands + if o is not None ), ( f"invoke_subgraph operands must be a list of tensors/ints/SymInts/Generator {operands}" ) @@ -304,6 +305,62 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): def get_output_metadata(subgraph, *operands): + """ + Extract metadata about the subgraph outputs WITHOUT executing the subgraph. + This avoids running side-effectful operations twice (once here, once in forward). + We analyze the graph structure statically to extract metadata. 
+ """ + # Unwrap FunctionalizeCtxWrapper if present + if isinstance(subgraph, FunctionalizeCtxWrapper): + subgraph = subgraph.subgraph + + # If not a GraphModule, fall back to execution-based metadata extraction + if not isinstance(subgraph, torch.fx.GraphModule): + return _get_output_metadata_by_execution(subgraph, *operands) + + output_metadata = OutputMetadata() + + # Extract output arguments from the output node + # The output node has args=(output_values,) where output_values is a tuple/list + output_node = next(reversed(subgraph.graph.find_nodes(op="output"))) + output_metadata.num_fw_outs = len(output_node.args[0]) + + for idx, output_arg in enumerate(output_node.args[0]): + if not isinstance(output_arg, torch.fx.Node): + if isinstance(output_arg, int): + output_metadata.indexes_with_symint.add(idx) + output_metadata.indexes_with_no_grad.add(idx) + continue + + # Check node metadata for type information + if output_arg.meta.get("val") is None: + # If we don't have complete metadata for all outputs, fall back to execution + # This is important for correctness (e.g., detecting SymInts) even though it + # runs side-effectful operations + return _get_output_metadata_by_execution(subgraph, *operands) + + val = output_arg.meta["val"] + if isinstance(val, torch.SymInt): + output_metadata.indexes_with_symint.add(idx) + output_metadata.indexes_with_no_grad.add(idx) + elif isinstance(val, torch.Tensor): + # Check if tensor requires grad from metadata + if hasattr(val, "requires_grad") and not val.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + else: + # Non-tensor, non-symint (shouldn't happen but be safe) + output_metadata.indexes_with_no_grad.add(idx) + + return output_metadata + + +def _get_output_metadata_by_execution(subgraph, *operands): + """ + Fallback: Extract metadata by executing the subgraph. + This should only be used when static analysis fails. + WARNING: This will run side-effectful operations! 
+ """ + with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): # args are functional tensors, generate some example tensors @@ -323,19 +380,15 @@ def get_output_metadata(subgraph, *operands): num_fw_outs = len(fw_outs) - # Collect the indexes of none in the output to check that the grad - # is None at the corresponding index in the backward. This check is - # performed in the autograd.Function - InvokeSubgraphAutogradOp. - # Also collect the indexes of no_grad in the output to filter out - # the grad_outs in the `backward` method. output_metadata = OutputMetadata() - output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): if isinstance(fw_out, torch.SymInt): output_metadata.indexes_with_symint.add(idx) elif not fw_out.requires_grad: output_metadata.indexes_with_no_grad.add(idx) + return output_metadata @@ -562,7 +615,34 @@ def _(ctx, subgraph, identifier, *operands): do_auto_functionalize_v2, ) + # (in the functionalization metadata phase) Capture tokens before + tokens_before = dict(ctx.mode._tokens) + + # Check if this subgraph has effects stored in the cache + invoke_subgraph_cache = get_invoke_subgraph_cache() + effects = None + if invoke_subgraph_cache: + effects = invoke_subgraph_cache.get_effects(identifier) + + if effects: + assert len(effects) == 1, "Multiple effects within a subgraph NYI" + tokens = ctx.mode._tokens + effects = next(iter(effects)) + token_input = tokens[effects] + + operands = (token_input, *operands) + + def wrap_subgraph(subgraph): + def wrapped_subgraph(token, *args): + res = subgraph(*args) + return ctx.unwrap_tensors(ctx.mode._tokens[effects]), *res + + return wrapped_subgraph + + subgraph = wrap_subgraph(subgraph) + unwrapped_operands = ctx.unwrap_tensors(operands) + hop_instance = HopInstance.create(invoke_subgraph, subgraph, identifier, *operands) if can_auto_functionalize(hop_instance): # NOTE: [auto_functionalize x invoke_subgraph caching] @@ -587,6 +667,28 @@ 
def _(ctx, subgraph, identifier, *operands): # of invoke_subgraph ops if input aliasing/mutation is detected. functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph) out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands) + + if effects: + (new_token, *out) = out + ctx.mode._tokens[effects] = new_token + + # (in the functionalization metadata phase) Capture tokens after and see if + # there are any differences (there are new effects or the token value for an + # effect type has changed) + tokens_after = dict(ctx.mode._tokens) + discovered_effects = set() + for effect_type, token in tokens_after.items(): + if effect_type not in tokens_before or tokens_before[effect_type] is not token: + discovered_effects.add(effect_type) + + if discovered_effects: + assert ctx.mode._allow_token_discovery, ( + f"Number of tokens changed by {len(discovered_effects)} when tracing subgraph {subgraph}." + ) + # Store discovered effects in the cache by identifier + if invoke_subgraph_cache: + invoke_subgraph_cache.add_effects(identifier, discovered_effects) + return ctx.wrap_tensors(out) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 1d1687141fb05..1b935283212ad 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -368,7 +368,10 @@ class TensorMeta: @classmethod def from_irnodes( - cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]] + cls, + irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]], + *, + hint_override: Optional[int] = None, ) -> Union[TensorMeta, list[TensorMeta]]: if isinstance(irnodes, Sequence): result: list[Any] = [cls.from_irnodes(x) for x in irnodes] @@ -390,14 +393,17 @@ def from_irnodes( sizes=V.graph.sizevars.size_hints( node.get_size(), fallback=config.unbacked_symint_fallback, + hint_override=hint_override, ), strides=V.graph.sizevars.size_hints( node.get_stride(), fallback=config.unbacked_symint_fallback, + 
hint_override=hint_override, ), offset=V.graph.sizevars.size_hint( node.get_layout().offset, fallback=config.unbacked_symint_fallback, + hint_override=hint_override, ), name=node.get_name(), ) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 61a97fd740cbc..65d356dce0979 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -96,6 +96,7 @@ def __init__(self): self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] self._include_extra_header ) + self.codegen_int_array_var_cache = {} @staticmethod def create( @@ -1636,14 +1637,33 @@ def codegen_memory_format(self, memory_format): self.used_cached_memory_formats.add(memory_format_str) return f"cached_torch_memory_format_{memory_format_str}" - @functools.cache # noqa: B019 def codegen_int_array_var( self, int_array: str, writeline: Callable[..., None], known_statically=False, graph=None, # for per-graph caching - ): + ) -> str: + # Use id(graph) for caching to avoid circular references + cache_key = ( + int_array, + id(writeline), + known_statically, + id(graph) if graph else None, + ) + if cache_key not in self.codegen_int_array_var_cache: + self.codegen_int_array_var_cache[cache_key] = ( + self._codegen_int_array_var_impl(int_array, writeline, known_statically) + ) + + return self.codegen_int_array_var_cache[cache_key] + + def _codegen_int_array_var_impl( + self, + int_array: str, + writeline: Callable[..., None], + known_statically: bool, + ) -> str: # Used for size/stride declaration # # Because the memory planning is done in two passes (see the implementation diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 615913933326e..54d26eb7f1b76 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -19,7 +19,7 @@ SequentialComboKernelGrid, ) from ..scheduler 
import BaseSchedulerNode -from ..utils import Placeholder, triton_version_uses_attrs_dict +from ..utils import is_rocm, Placeholder, triton_version_uses_attrs_dict from ..virtualized import V from .common import ( ArgName, @@ -742,10 +742,13 @@ def kernel_benchmark_extra_args(self) -> list[str]: continue # pyrefly: ignore [missing-argument] if not tree.is_reduction or sub_kernel.inside_reduction: + meta_hint = sub_kernel.hint_override if is_rocm() else None extra_args.append( str( V.graph.sizevars.size_hint( - tree.numel, fallback=config.unbacked_symint_fallback + tree.numel, + fallback=config.unbacked_symint_fallback, + hint_override=meta_hint, ) ) ) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 681aef9afb35f..55279f393d3aa 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -341,12 +341,58 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: - sz_bytes = 0 - for node in fx_node.all_input_nodes: - if (t := node.meta.get("val")) is not None: - numel = get_size_numel(t.size()) - sz_bytes += numel * get_dtype_size(t.dtype) - return sz_bytes + """Estimate the size of a collective operation in bytes, including inputs and outputs.""" + input_bytes = None + + args, kwargs = fx_node.args, fx_node.kwargs + kwargs = dict(kwargs) + + # dont double count pre-allocated buffer passed in + kwargs.pop("out", None) + + def tensor_bytes(t) -> int: + return get_size_numel(t.size()) * get_dtype_size(t.dtype) + + def add_inp_bytes(inp: torch.fx.Node): + t = inp.meta.get("val", None) + if t is None: + return + + nonlocal input_bytes + if input_bytes is None: + input_bytes = 0 + input_bytes += tensor_bytes(t) + + pytree.tree_map_only( + torch.fx.Node, + add_inp_bytes, + (args, kwargs), + ) + + output_tensor = fx_node.meta.get("val", None) + + if input_bytes is None or output_tensor is None: + return 0 + + output_bytes = ( + 
get_size_numel(output_tensor.size()) * output_tensor.element_size() + ) # pyre-ignore + + return input_bytes + output_bytes + + +def estimate_fx_collective_memory_footprint(fx_node: torch.fx.Node) -> int: + """Estimate the memory footprint of a collective operation in bytes. + + This returns the total bytes that need to be live concurrently in memory. + For all_reduce, we divide by 2 since it can be done in-place. + """ + from torch._inductor.fx_passes.bucketing import ( + is_all_reduce_tensor as is_all_reduce, + ) + + size = estimate_fx_collective_size(fx_node) + return size if not is_all_reduce(fx_node) else size // 2 def estimate_nccl_collective_runtime_from_fx_node( diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 5641c4294356f..00737a3b6e3b7 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -489,15 +489,34 @@ def all_reduce_merge_fn_to_trace( return new_outs +# List of all torch dtypes for serialization through custom ops +# TODO: custom ops support list[dtype] input +_ALL_DTYPES = tuple( + [ + getattr(torch, attr) + for attr in dir(torch) + if isinstance(getattr(torch, attr), torch.dtype) + ] +) + + @torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={}) def _pre_bucket_all_gather( ag_ins: list[torch.Tensor], group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[ + int + ], # dtype enum values, that inputs are converted to before all_gather rank: int, ) -> torch.Tensor: - ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + # Convert int indices back to torch.dtype + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] bucket_dtype_size_bytes = dtype.itemsize ins_split_sizes = [ _bytes // bucket_dtype_size_bytes for _bytes in 
ins_split_sizes_bytes @@ -507,8 +526,14 @@ def _pre_bucket_all_gather( new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) - ag_ins_flattened = [ag_in.reshape(-1).view(dtype) for ag_in in ag_ins] - torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) + # View each destination slice as its output dtype, then copy + # The copy operation handles dtype conversion from input dtype to output dtype + foreach_copy_dsts_typed = [ + dst.view(out_dtype) + for dst, out_dtype in zip(foreach_copy_dsts, out_dtypes, strict=True) + ] + ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] + torch._foreach_copy_(foreach_copy_dsts_typed, ag_ins_flattened) return new_ag_out @@ -517,9 +542,14 @@ def _pre_bucket_all_gather_fake( group_size: int, group_name: str, dtype: torch.dtype, # type: ignore[name-defined] + out_dtype_ints: list[int], rank: int, ) -> torch.Tensor: - ins_split_sizes_bytes = [ag_in.numel() * ag_in.element_size() for ag_in in ag_ins] + out_dtypes = [_ALL_DTYPES[d] for d in out_dtype_ints] + ins_split_sizes_bytes = [ + ag_in.numel() * out_dtype.itemsize + for ag_in, out_dtype in zip(ag_ins, out_dtypes, strict=True) + ] bucket_dtype_size_bytes = dtype.itemsize ins_split_sizes = [ _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes @@ -541,12 +571,9 @@ def all_gather_merge_fn_to_trace_custom_ops( out_dtypes: list[torch.dtype], # type: ignore[name-defined] rank: int, ) -> list[torch.Tensor]: - ag_ins = [ - torch._prims.convert_element_type(_ag_in, out_dtype) - if _ag_in.dtype != out_dtype - else _ag_in - for _ag_in, out_dtype in zip(_ag_ins, out_dtypes) - ] + # Don't create convert_element_type ops - _pre_bucket_all_gather handles conversion + # by viewing destination slices as output dtypes and letting copy do the conversion + ag_ins = _ag_ins ins_sizes = [ag_in.shape for ag_in in 
ag_ins] ins_split_sizes_bytes = [ ag_in.numel() * out_dtype.itemsize @@ -557,8 +584,13 @@ def all_gather_merge_fn_to_trace_custom_ops( _bytes // bucket_dtype_size_bytes for _bytes in ins_split_sizes_bytes ] ag_input_numel = sum(ins_split_sizes) + + # Convert out_dtypes to indices for custom_op + # TODO: custom ops support list[dtype] input + out_dtype_ints = [_ALL_DTYPES.index(dt) for dt in out_dtypes] + new_ag_out = torch.ops.bucketing._pre_bucket_all_gather( - ag_ins, group_size, group_name, dtype, rank + ag_ins, group_size, group_name, dtype, out_dtype_ints, rank ) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) wait_tensor = torch.ops.c10d_functional.wait_tensor( @@ -721,6 +753,20 @@ def _insert_fn_trace_before_node( # type: ignore[no-untyped-def] return replacements, new_nodes +def has_mergeable_all_gather_convert_dtype(n: torch.fx.Node) -> bool: + node_in = n.args[0] + return ( + is_all_gather_into_tensor(n) + and isinstance(node_in, torch.fx.Node) + and node_in.op == "call_function" + and ( + node_in.target is torch.ops.prims.convert_element_type.default + or node_in.target is torch.ops.aten._to_copy.default + ) + and len(node_in.users) == 1 + ) + + def process_collective_bucket( g: torch.fx.Graph, bucket_nodes: list[torch.fx.Node], @@ -755,13 +801,7 @@ def process_collective_bucket( # Handle convert_element_type operations (for all_gather) node_in = n.args[0] - if ( - is_all_gather_into_tensor(n) - and isinstance(node_in, torch.fx.Node) # Add type check - and node_in.op == "call_function" - and node_in.target is torch.ops.prims.convert_element_type.default - and len(node_in.users) == 1 - ): + if has_mergeable_all_gather_convert_dtype(n): ag_node_to_pre_nodes[n].append(node_in) node_in = node_in.args[0] diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index 4060a29c7c3db..b6cbf32bfba8e 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py 
+++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -3,12 +3,14 @@ from dataclasses import dataclass from typing import Any, Literal, Optional +import torch import torch.fx as fx from torch._dynamo.utils import counters from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch._inductor.fx_passes.bucketing import ( bucket_key, BucketMode, + has_mergeable_all_gather_convert_dtype, is_all_gather_into_tensor as is_all_gather, is_reduce_scatter_tensor as is_reduce_scatter, is_wait_tensor, @@ -207,6 +209,7 @@ def build_timeline(self, pg: str) -> Optional[PGEvent]: prev_event = event position += 1 + return head def _populate_node_to_event(self, pg: str) -> None: @@ -231,7 +234,6 @@ def _add_hiding_interval_constraints(self) -> None: self.aug_graph.add_extra_dep(n=info.wait_node, dep=hn) def bucket_collectives(self) -> None: - """Main entry point for bucketing collectives.""" # Group collectives by PG first pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet) for start in self.collective_info: @@ -281,6 +283,15 @@ def bucket_collectives(self) -> None: # Apply topological sort with all dependencies from torch._dynamo.graph_deduplication import _stable_topological_sort + for n, deps in additional_deps.items(): + torch._check( + not n._erased, lambda: f"Erased node deps not transferred: {n}" + ) + for d in deps: + torch._check( + not d._erased, lambda: f"Erased node deps not transferred: {d}" + ) + _stable_topological_sort(self.graph, additional_deps) # After topological sort, preserve dependencies using effect tokens @@ -762,6 +773,11 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: old_starts = list(bucket) old_waits = [self.collective_info[n].wait_node for n in bucket] + fused_convert_dtypes = [] + for n in old_starts: + if has_mergeable_all_gather_convert_dtype(n): + fused_convert_dtypes.append(n.args[0]) + # Find where to place the bucketed operations next_node = bucket[0] while next_node in bucket: 
@@ -809,6 +825,22 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: for old_wait in old_waits: erased_to_new[old_wait] = new_wait + # Handle convert_element_type nodes that were fused and erased + # The bucketed operation may have a _pre_bucket op that handles dtype conversion + if fused_convert_dtypes: + # all gather bucketing may fuse in dtype conversion into the bucketing + # if so, we need to transfer hiding deps from the old dtype conversion + # to the new bucketing node + new_convert_dtypes_node = new_start.kwargs["out"] + assert isinstance(new_convert_dtypes_node, fx.Node) + assert ( + new_convert_dtypes_node.target + == torch.ops.bucketing._pre_bucket_all_gather.default + ) + + for n in fused_convert_dtypes: + erased_to_new[n] = new_convert_dtypes_node + # Transfer all dependencies from old nodes to new nodes self.aug_graph.transfer_erased_node_deps(erased_to_new) diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 0649e36f23361..b7617038f4e6a 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -11,7 +11,7 @@ import torch import torch.fx as fx from torch._dynamo.utils import counters, dynamo_timed -from torch._inductor.comm_analysis import estimate_fx_collective_size +from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor from torch._inductor.fx_passes.memory_estimator import ( _is_releasable, @@ -45,21 +45,26 @@ def get_group_name(n: fx.Node) -> str: def get_custom_estimation( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + override_size: int | None = None, ) -> float | None: if custom_runtime_estimation is None: return None - return custom_runtime_estimation(n) + return 
custom_runtime_estimation(n, override_size) def estimate_collective_time( n: fx.Node, override_size: int | None = None, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> float: """Estimate the runtime of a collective operation, optionally with an overridden size.""" - if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: + if ( + est := get_custom_estimation(n, custom_runtime_estimation, override_size) + ) is not None: return est # Use analytical model (benchmarking is handled separately in alignment) @@ -99,7 +104,8 @@ def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]: def benchmark_node_with_cache_key( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> tuple[float, str | None]: """Benchmark a compute node and return (runtime, cache_key).""" assert is_compute_node(n) @@ -142,7 +148,9 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: if unbacked_tensor: return 0, key - if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: + if ( + est := get_custom_estimation(n, custom_runtime_estimation, None) + ) is not None: set_cached_node_time(key, est) return est, key @@ -154,7 +162,8 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: def benchmark_node( n: fx.Node, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, ) -> float: return benchmark_node_with_cache_key(n, custom_runtime_estimation)[0] @@ -236,7 +245,7 @@ def __init__( insert_overlap_deps: bool, compute_overlap_multipler: float, max_coll_distance: int, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None, + custom_runtime_estimation: 
Callable[[fx.Node, int | None], float | None] | None, collective_estimator: Literal["analytical", "benchmark"], ): self.gm = gm @@ -318,7 +327,7 @@ def _identify_collectives(self) -> None: info = CollectiveInfo( start_node=start, wait_node=node, - size_bytes=estimate_fx_collective_size(start), + size_bytes=estimate_fx_collective_memory_footprint(start), estimated_time_ms=coll_time_ms, exposed_time_ms=coll_time_ms, # Initially fully exposed ) @@ -431,7 +440,10 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( # Benchmark CUDA events (non-deterministic, needs alignment) # Skip collectives with custom estimation for n in collective_nodes: - if get_custom_estimation(n, self.custom_runtime_estimation) is not None: + if ( + get_custom_estimation(n, self.custom_runtime_estimation, None) + is not None + ): continue # Benchmark actual size @@ -1000,7 +1012,8 @@ def schedule_overlap_bucketing( insert_overlap_deps: bool = False, compute_overlap_multipler: float = 1.0, max_coll_distance: int = 1000, - custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, collective_estimator: Literal["analytical", "benchmark"] = "analytical", ) -> torch.fx.GraphModule: """Schedule nodes to maximize compute-collective overlap. 
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index d6893b07ee3d9..91b0e5ec66053 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1495,10 +1495,12 @@ def call_kernel( wrapper.generate_workspace_deallocation(self.workspace_arg) def kernel_benchmark_extra_args(self) -> list[str]: + meta_hint = self.hint_override if torch.version.hip else None return [ str(x) for x in self.grid_fn( - *V.graph.sizevars.size_hints(self.call_sizes), self.meta + *V.graph.sizevars.size_hints(self.call_sizes, hint_override=meta_hint), + self.meta, ) ] diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 9df8d114ef67b..e2e2fa9288a08 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -646,6 +646,10 @@ def _get_exceeding_shared_memory_checker( If the device does not report available shared memory, returns None. """ + from ..utils import get_gpu_shared_memory + + sm_available = None + try: device = torch.cuda.current_device() props = torch.cuda.get_device_properties(device) @@ -653,8 +657,16 @@ def _get_exceeding_shared_memory_checker( return None sm_available = int(props.shared_memory_per_block_optin) except Exception: - # If CUDA is not available or properties cannot be queried, return None - return None + pass + + # ROCm specific logic to get shared memory + if torch.version.hip and sm_available is None: + try: + sm_available = get_gpu_shared_memory() + if sm_available == 0: + return None + except Exception: + return None # TODO make a BaseDeviceConfigHeuristics to handle different device configuration in its own implementation. 
def exceeds(gemm_config: BaseConfig, dtype_size: int) -> bool: @@ -1318,6 +1330,7 @@ def _finalize_mm_configs( waves_per_eu, matrix_instr_nonkdim, kpack, + conf.hint_override, ) # Check if gemm specific arg exists - add to key if does @@ -1344,7 +1357,12 @@ def _finalize_mm_configs( } if group_m is not None: kwargs["GROUP_M"] = group_m - yield self.triton_config(**kwargs) + + tc = self.triton_config(**kwargs) + # Preserve hint_override for multi-kernel support + if hasattr(conf, "hint_override") and conf.hint_override is not None: + tc.hint_override = conf.hint_override + yield tc def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfig]: flex_attn_fwd_configs: list[FlexConfig] = [] @@ -1674,6 +1692,12 @@ def _convert_config_to_template_kwargs( group_m = triton_config.kwargs.get("GROUP_M", 8) options_dict["GROUP_M"] = group_m + # Keep ROCm multi-kernel size bucket attached to the config + if torch.version.hip and "hint_override" not in options_dict: + hint_override = getattr(triton_config, "hint_override", None) + if hint_override is not None: + options_dict["hint_override"] = hint_override + return options_dict def _get_acc_type(self, dtype: torch.dtype) -> str: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index f029a2e73f038..f00852721a42b 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -3030,6 +3030,10 @@ def is_gpu(device: Optional[str]) -> bool: return device in GPU_TYPES +def is_rocm() -> bool: + return torch.version.hip is not None + + def device_need_guard(device: str) -> bool: return device != "mps" and is_gpu(device) # TODO: MPS does not expose streams now diff --git a/torch/_library/effects.py b/torch/_library/effects.py index 41fbaa4c1c7b4..3f765f380eab1 100644 --- a/torch/_library/effects.py +++ b/torch/_library/effects.py @@ -35,6 +35,18 @@ def _set_default_effect(self) -> None: if namespace == "higher_order": return + # These classes do not have side effects as they just store 
quantization + # params, so we dont need to mark them as ordered + skip_classes = ( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase", + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase", + "__torch__.torch.classes.quantized.EmbeddingPackedParamsBase", + "__torch__.torch.classes.quantized.LinearPackedParamsBase", + "__torch__.torch.classes.xnnpack.Conv2dOpContext", + "__torch__.torch.classes.xnnpack.LinearOpContext", + "__torch__.torch.classes.xnnpack.TransposeConv2dOpContext", + ) + opname = f"{namespace}::{opname}" if torch._C._get_operation_overload(opname, overload) is not None: # Since we call this when destroying the library, sometimes the @@ -42,6 +54,9 @@ def _set_default_effect(self) -> None: schema = torch._C._get_schema(opname, overload) for arg in schema.arguments: if isinstance(arg.type, torch.ClassType): + type_str = arg.type.str() # pyrefly: ignore[missing-attribute] + if type_str in skip_classes: + continue self._effect = EffectType.ORDERED return diff --git a/torch/_subclasses/complex_tensor/__init__.py b/torch/_subclasses/complex_tensor/__init__.py new file mode 100644 index 0000000000000..1ab4a816261dc --- /dev/null +++ b/torch/_subclasses/complex_tensor/__init__.py @@ -0,0 +1,9 @@ +from ._core import ComplexTensor +from ._ops import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensor", "ComplexTensorMode", "is_complex_tensor"] + +ComplexTensor.__module__ = __name__ +ComplexTensorMode.__module__ = __name__ +is_complex_tensor.__module__ = __name__ diff --git a/torch/_subclasses/complex_tensor/_core.py b/torch/_subclasses/complex_tensor/_core.py new file mode 100644 index 0000000000000..edd7568b2ef06 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_core.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Any, TYPE_CHECKING +from typing_extensions import Self + +import torch +from torch import Tensor +from torch.autograd import Function + + +if TYPE_CHECKING: + from torch._ops 
import OpOverload + from torch._prims_common import DeviceLikeType + from torch.autograd.function import FunctionCtx + + +class ComplexTensor(Tensor): + """A class that decomposes all ops on complex Tensors into their real and imaginary parts.""" + + _re: Tensor + _im: Tensor + + def __new__(cls, real: Tensor, imag: Tensor) -> Self: + """Initialize a ComplexTensor from its real and imaginary parts.""" + from ._ops.common import REAL_TO_COMPLEX + + shape = real.shape + device = real.device + + # TODO (hameerabbasi): `torch.compile` sometimes fails here without making these + # contiguous. Why? + real = real.contiguous() + imag = imag.contiguous() + + # TODO (hameerabbasi): + # What should we do with dtype? + # We could convert to the complex type (float32 -> complex64), but we + # can't use that model for say `bfloat16` which does not have a + # corresponding complex dtype. + # If we want to support this complex rep using any float type (see + # https://github.com/pytorch/pytorch/issues/95100) + # We either need to: + # 1) add the complex types for say `complexbf32`, knowing they can't really be used anywhere + # else. + # 2) We use the real float dtype here, and it is up to the user to know + # that dtype=float here really means complex<2xSize> with dtype + # matching that of re/im parts alone + # I'm going with 1 for now, so that I can make gradcheck and some complex + # ops work properly, but might want to discuss this in the RFP. + dtype = REAL_TO_COMPLEX.get(real.dtype) + if dtype is None: + raise TypeError( + "Unsupported dtype for constituent tensors. Supported dtypes are: " + f"{set(REAL_TO_COMPLEX.keys())!r}." 
+ ) + storage_offset = real.storage_offset() + strides = real.stride() + layout = real.layout + pin_memory = real.is_pinned() + + assert shape == imag.shape, f"Expected imag shape {shape}, got {imag.shape}" + assert device == imag.device, ( + f"Expected imag device {device}, got {imag.device}" + ) + assert real.dtype == imag.dtype, ( + f"Expected imag dtype {real.dtype}, got {imag.dtype}" + ) + assert pin_memory == imag.is_pinned(), ( + f"Expected imag pinning {pin_memory}, got {imag.is_pinned()}" + ) + + res = Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + shape, + device=device, + dtype=dtype, + storage_offset=storage_offset, + strides=strides, + pin_memory=pin_memory, + layout=layout, + requires_grad=False, + ) + res._re = real.clone().detach() + res._im = imag.clone().detach() + + return res + + @property + def re(self) -> Tensor: + return self._re + + @property + def im(self) -> Tensor: + return self._im + + @classmethod + def __torch_dispatch__( + cls, + func: OpOverload, + types: tuple[type, ...], + args: tuple = (), + kwargs: dict | None = None, + ): + from ._ops.common import lookup_complex + + kwargs = {} if kwargs is None else kwargs + + impl = lookup_complex(func, *args, **kwargs) + if impl is None: + return NotImplemented + + return impl(*args, **kwargs) + + @staticmethod + def from_interleaved(t: Tensor) -> ComplexTensor: + t_real = torch.real(t) + t_imag = torch.imag(t) if t.dtype.is_complex else torch.zeros_like(t_real) + return Complex.apply(t_real, t_imag) + + def as_interleaved(self) -> Tensor: + return torch.complex(self.real, self.imag) + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, Tensor], + meta: Any, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + ) -> ComplexTensor: + assert meta is None + re, im = inner_tensors["re"], inner_tensors["im"] + return ComplexTensor(re, im) + + def __tensor_flatten__(self) -> tuple[list[str], Any]: + return ["re", "im"], None + + def 
__repr__(self, *, tensor_contents=None) -> str: + return f"ComplexTensor(real={self.re!r}, imag={self.im!r})" + + def is_pinned(self, device: DeviceLikeType | None = None) -> bool: + return self.re.is_pinned(device) + + +class Complex(Function): + @staticmethod + def forward(ctx: FunctionCtx, real: Tensor, imag: Tensor) -> ComplexTensor: # type: ignore[bad-override] + return ComplexTensor(real, imag) + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: ComplexTensor) -> tuple[Tensor, Tensor]: # type: ignore[bad-override] + return grad_output.real, grad_output.imag diff --git a/torch/_subclasses/complex_tensor/_ops/__init__.py b/torch/_subclasses/complex_tensor/_ops/__init__.py new file mode 100644 index 0000000000000..c07bdf6099b65 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/__init__.py @@ -0,0 +1,5 @@ +from . import aten, prims +from .common import ComplexTensorMode, is_complex_tensor + + +__all__ = ["ComplexTensorMode", "is_complex_tensor", "aten", "prims"] diff --git a/torch/_subclasses/complex_tensor/_ops/aten.py b/torch/_subclasses/complex_tensor/_ops/aten.py new file mode 100644 index 0000000000000..15e09c3b314f0 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/aten.py @@ -0,0 +1,921 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from .._core import ComplexTensor +from .common import ( + _get_func_name, + COMPLEX_TO_REAL, + complex_to_real_dtype, + is_complex, + OpType, + promote_tensors, + register_binary_nonlinear, + register_complex, + register_error, + register_force_test, + register_simple, + split_complex_arg, + split_complex_tensor, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from typing import Any + +aten = torch.ops.aten + + +def register_binary_linear(op: OpType): + def impl_with_alpha( + lhs: ComplexTensor, rhs: ComplexTensor, *args, alpha, **kwargs + ) -> ComplexTensor: + return op(lhs, aten.mul(rhs, alpha, *args, **kwargs), *args, 
**kwargs) + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + return impl_with_alpha(lhs, rhs, *args, alpha=alpha, **kwargs) + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + u = op(a_r, b_r, *args, **kwargs) + v = op(a_i, b_i, *args, **kwargs) + return ComplexTensor(u.to(out_dt), v.to(out_dt)) + + return register_complex(op, impl) + + +@register_complex(aten.real) +def real_impl(self: ComplexTensor) -> torch.Tensor: + re, _ = split_complex_tensor(self) + return re + + +@register_complex(aten.imag) +def imag_impl(self: ComplexTensor) -> torch.Tensor: + _, im = split_complex_tensor(self) + return im + + +@register_complex(aten.is_pinned) +def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool: + return self.is_pinned(device) + + +SIMPLE_OPS_LIST = [ + aten.slice, + aten.flatten, + aten.view, + aten.diagonal, + aten.expand, + aten.unsqueeze, + aten.unsqueeze_, + aten.mean, + aten.sum, + aten.clone, + aten.neg, + aten.flip, + aten.permute, + aten.repeat, + aten.index_select, + aten.split, + aten.split_with_sizes, + aten.cumsum, + aten.detach, + aten.select, + aten.squeeze, + aten.zero_, + aten.transpose, + aten.t, + aten.gather, +] + +for simple_op in SIMPLE_OPS_LIST: + globals()[_get_func_name(simple_op)] = register_simple(simple_op) + +# TODO (hameerabbasi): Not being tested +SIMPLE_FORCE_TESTED_OPS = [ + aten.copy, + aten.col2im, + aten.alias, + aten.lift_fresh, + aten._unsafe_view, + aten.index, + aten._neg_view, + aten.avg_pool2d, + aten.avg_pool3d, + aten.avg_pool2d_backward, + aten.avg_pool3d_backward, + aten.masked_scatter_backward, + aten.select_backward, + aten.slice_backward, + aten.embedding, +] + +for simple_op in SIMPLE_FORCE_TESTED_OPS: + globals()[_get_func_name(simple_op)] = register_force_test( + simple_op, register_simple(simple_op) 
+ ) + +del simple_op + +# some binary ops which we can stamp out +mul_impl = register_binary_nonlinear(aten.mul) +mul__impl = register_binary_nonlinear(aten.mul_) +mm_impl = register_binary_nonlinear(aten.mm) +dot_impl = register_binary_nonlinear(aten.dot) +bmm_impl = register_binary_nonlinear(aten.bmm) + +# TODO (hameerabbasi): Not being tested +convolution_impl = register_force_test( + aten.convolution, register_binary_nonlinear(aten.convolution) +) + +slice_scatter_impl = register_force_test( + aten.slice_scatter, register_binary_linear(aten.slice_scatter) +) +select_scatter_impl = register_force_test( + aten.select_scatter, register_binary_linear(aten.select_scatter) +) + +add_impl = register_binary_linear(aten.add) +add__impl = register_binary_linear(aten.add_) +sub_impl = register_binary_linear(aten.sub) +sub__impl = register_binary_linear(aten.sub_) +diagonal_scatter_impl = register_binary_linear(aten.diagonal_scatter) +fill__impl = register_binary_linear(aten.fill_) + + +@register_complex(aten.rsub) +def rsub_impl(lhs: ComplexTensor, rhs: ComplexTensor, alpha=None) -> ComplexTensor: + if alpha is None: + return torch.sub(rhs, lhs) # type: ignore[bad-return] + return torch.sub(rhs, lhs, alpha=alpha) # type: ignore[bad-return] + + +@register_complex(aten.div) +@register_complex(aten.true_divide) +def div_impl(lhs: ComplexTensor, rhs: ComplexTensor, *, rounding_mode=None): + if rounding_mode is not None: + raise NotImplementedError( + "`rounding_mode` other than `None` not implemented for`ComplexTensor`." 
+ ) + a_r, a_i = split_complex_tensor(lhs) + if not is_complex(rhs): + return ComplexTensor(a_r / rhs, a_i / rhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + num_r = a_r * b_r + a_i * b_i + num_i = a_i * b_r - a_r * b_i + den = b_r * b_r + b_i * b_i + return ComplexTensor( + (num_r / den).to(out_dt), + (num_i / den).to(out_dt), + ) + + +@register_complex(aten.reciprocal) +def reciprocal_impl(self: ComplexTensor): + self_r, self_i = split_complex_tensor(self) + out_dt, (self_r, self_i) = promote_tensors(self_r, self_i) + den = self_r * self_r + self_i * self_i + return ComplexTensor( + aten.div(self_r, den).to(out_dt), + aten.div(-self_i, den).to(out_dt), + ) + + +# reductions +@register_complex(aten.prod) +def prod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + out_dt, (self,) = promote_tensors(self) + dtype = kwargs.pop("dtype", out_dt) + kwargs["dtype"] = complex_to_real_dtype(self.dtype) + + prod_r = torch.prod(torch.abs(self), *args, **kwargs) + sum_phi = torch.sum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v).to(dtype) # type: ignore[bad-return] + + +@register_complex(aten.pow) +def pow_impl(self: ComplexTensor, exponent: ComplexTensor) -> ComplexTensor: + out_dt, (self, exponent) = promote_tensors(self, exponent) + return torch.exp(exponent * torch.log(self)).to(out_dt) # type: ignore[bad-return] + + +@register_complex(aten.cumprod) +def cumprod_impl(self: ComplexTensor, *args, **kwargs) -> ComplexTensor: + dtype = kwargs.pop("dtype", self.dtype) + kwargs["dtype"] = complex_to_real_dtype(dtype) + + prod_r = torch.cumprod(torch.abs(self), *args, **kwargs) + sum_phi = torch.cumsum(torch.angle(self), *args, **kwargs) + u = prod_r * torch.cos(sum_phi) + v = prod_r * torch.sin(sum_phi) + return ComplexTensor(u, v) + + +# unary funcs, +# most of these are simple or require some kind of identity 
@register_complex(aten.abs)
def abs_impl(self: ComplexTensor) -> torch.Tensor:
    """Return the element-wise magnitude, computed as ``hypot(re, im)``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)
    result = torch.hypot(x, y)
    return result.to(out_dt)


@register_complex(aten.angle)
def angle_impl(self: ComplexTensor) -> torch.Tensor:
    """Return the element-wise argument, computed as ``atan2(im, re)``."""
    x, y = split_complex_tensor(self)
    return torch.atan2(y, x)


@register_complex(aten.acos)
def acos_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``acos(z)`` from ``acosh(z)``.

    The real part of the result is ``|Im(acosh z)|``; the imaginary part is
    ``|Re(acosh z)|`` with a sign chosen from the sign bit of ``Im(z)``.
    """
    _, y = split_complex_tensor(self)
    acosh_z = torch.acosh(self)
    assert isinstance(acosh_z, ComplexTensor)
    acosh_z_re, acosh_z_im = split_complex_tensor(acosh_z)
    # +1 where the sign bit of Im(z) is set, -1 otherwise.
    sign_im = 2 * torch.signbit(y) - 1
    return ComplexTensor(torch.abs(acosh_z_im), sign_im * torch.abs(acosh_z_re))


@register_complex(aten.asin)
def asin_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``asin(z)`` via the identity ``asin(z) = -i * asinh(i * z)``."""
    x, y = split_complex_tensor(self)
    # ComplexTensor(-y, x) is i * z.
    asinh_iz = torch.asinh(ComplexTensor(-y, x))
    assert isinstance(asinh_iz, ComplexTensor)
    asinh_iz_re, asinh_iz_im = split_complex_tensor(asinh_iz)
    # Multiplying (a + bi) by -i gives (b - ai).
    return ComplexTensor(asinh_iz_im, -asinh_iz_re)


@register_complex(aten.atan)
def atan_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``atan(z)`` via the identity ``atan(z) = -i * atanh(i * z)``."""
    x, y = split_complex_tensor(self)
    atanh_iz = torch.atanh(ComplexTensor(-y, x))
    assert isinstance(atanh_iz, ComplexTensor)
    atanh_iz_re, atanh_iz_im = split_complex_tensor(atanh_iz)
    return ComplexTensor(atanh_iz_im, -atanh_iz_re)


@register_complex(aten.asinh)
def asinh_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``asinh(z)`` as ``log(z + sqrt(z * z + 1))``."""
    out_dt, (self,) = promote_tensors(self)
    return torch.log(self + torch.sqrt(self * self + 1)).to(out_dt)  # type: ignore[bad-return]


@register_complex(aten.acosh)
def acosh_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``acosh(z)`` as ``log(z + sqrt(z * z - 1))``."""
    out_dt, (self,) = promote_tensors(self)
    return torch.log(self + torch.sqrt(self * self - 1)).to(out_dt)  # type: ignore[bad-return]


@register_complex(aten.atanh)
def atanh_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``atanh(z)`` as ``0.5 * (log(1 + z) - log(1 - z))``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)

    ret = 0.5 * (
        torch.log(ComplexTensor(1 + x, y)) - torch.log(ComplexTensor(1 - x, -y))
    )
    assert isinstance(ret, ComplexTensor)
    ret_re, ret_im = split_complex_tensor(ret)

    return ComplexTensor(ret_re.to(out_dt), ret_im.to(out_dt))


@register_complex(aten.cos)
def cos_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``cos(z)`` via the identity ``cos(z) = cosh(i * z)``."""
    x, y = split_complex_tensor(self)
    return torch.cosh(ComplexTensor(-y, x))  # type: ignore[bad-return]


@register_complex(aten.cosh)
def cosh_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``cosh(x + iy) = cosh(x)cos(y) + i * sinh(x)sin(y)``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)
    u = torch.cosh(x) * torch.cos(y)
    v = torch.sinh(x) * torch.sin(y)
    return ComplexTensor(u.to(out_dt), v.to(out_dt))


@register_complex(aten.sin)
def sin_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``sin(z)`` via the identity ``sin(z) = -i * sinh(i * z)``."""
    x, y = split_complex_tensor(self)
    sinh_iz = torch.sinh(ComplexTensor(-y, x))
    assert isinstance(sinh_iz, ComplexTensor)
    sinh_iz_re, sinh_iz_im = split_complex_tensor(sinh_iz)
    return ComplexTensor(sinh_iz_im, -sinh_iz_re)


@register_complex(aten.sinh)
def sinh_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``sinh(x + iy) = sinh(x)cos(y) + i * cosh(x)sin(y)``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)
    u = torch.sinh(x) * torch.cos(y)
    v = torch.cosh(x) * torch.sin(y)
    return ComplexTensor(u.to(out_dt), v.to(out_dt))


@register_complex(aten.tan)
def tan_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``tan(z)`` via the identity ``tan(z) = -i * tanh(i * z)``."""
    x, y = split_complex_tensor(self)
    tanh_iz = torch.tanh(ComplexTensor(-y, x))
    assert isinstance(tanh_iz, ComplexTensor)
    tanh_iz_re, tanh_iz_im = split_complex_tensor(tanh_iz)
    return ComplexTensor(tanh_iz_im, -tanh_iz_re)


@register_complex(aten.tanh)
def tanh_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``tanh(x + iy) = (sinh(2x) + i * sin(2y)) / (cosh(2x) + cos(2y))``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)

    _2x = 2 * x
    _2y = 2 * y
    _d = torch.cosh(_2x) + torch.cos(_2y)
    _2xsh = torch.sinh(_2x)

    out_re = _2xsh / _d
    out_im = torch.sin(_2y) / _d

    return ComplexTensor(out_re.to(out_dt), out_im.to(out_dt))


@register_complex(aten.exp)
def exp_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``exp(x + iy) = exp(x) * (cos(y) + i * sin(y))``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)
    ex = torch.exp(x)
    u = ex * torch.cos(y)
    v = ex * torch.sin(y)
    return ComplexTensor(u.to(out_dt), v.to(out_dt))


@register_complex(aten.expm1)
def expm1_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``exp(z) - 1`` by subtracting 1 from the real part of ``exp(z)``."""
    x, y = split_complex_tensor(self)
    out_dt, (x, y) = promote_tensors(x, y)
    # TODO (hameerabbasi): The two lines below may have numerical issues
    ex = torch.exp(x)
    u = ex * torch.cos(y) - 1
    v = ex * torch.sin(y)
    return ComplexTensor(u.to(out_dt), v.to(out_dt))


@register_complex(aten.log)
def log_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``log(z) = log(|z|) + i * angle(z)``."""
    out_dt, (self,) = promote_tensors(self)
    re = torch.log(torch.abs(self))
    im = torch.angle(self)
    return ComplexTensor(re, im).to(out_dt)  # type: ignore[bad-return]


@register_complex(aten.log1p)
def log1p_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``log(1 + z)`` by adding 1 to the real part and calling ``log``."""
    x, y = split_complex_tensor(self)
    # TODO (hameerabbasi): The line below may have numerical issues
    return torch.log(ComplexTensor(x + 1, y))  # type: ignore[bad-return]


@register_complex(aten.any)
def any_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
    """Return whether any element is nonzero (nonzero real or imaginary part).

    Extra positional/keyword arguments are forwarded to ``torch.any``.
    """
    x, y = split_complex_tensor(self)
    return torch.any(x, *args, **kwargs) | torch.any(y, *args, **kwargs)


@register_complex(aten.all)
def all_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
    """Return whether all elements are nonzero.

    A complex element is nonzero when its real or imaginary part is nonzero,
    so the reduction has to run over the per-element mask; reducing the two
    components separately gives wrong answers (e.g. for ``[1+0j]``).
    Extra positional/keyword arguments are forwarded to ``torch.all``.
    """
    return torch.all(elemwise_nonzero(self), *args, **kwargs)


@register_complex(aten.eq)
def eq_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
    """Element-wise equality: both real and imaginary parts must match."""
    a_r, a_i = split_complex_arg(self)
    b_r, b_i = split_complex_arg(rhs)
    return torch.eq(a_r, b_r, *args, **kwargs) & torch.eq(a_i, b_i, *args, **kwargs)


@register_complex(aten.ne)
def ne_impl(self: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> torch.Tensor:
    """Element-wise inequality: real or imaginary parts differ."""
    # `split_complex_arg` (as in `eq_impl`) so that either operand may be a
    # plain tensor or scalar rather than a ComplexTensor.
    a_r, a_i = split_complex_arg(self)
    b_r, b_i = split_complex_arg(rhs)
    return torch.ne(a_r, b_r, *args, **kwargs) | torch.ne(a_i, b_i, *args, **kwargs)


@register_complex(aten.isnan)
def isnan_impl(self: ComplexTensor) -> torch.Tensor:
    """True where the real or the imaginary part is NaN."""
    re, im = split_complex_tensor(self)
    return torch.isnan(re) | torch.isnan(im)


@register_complex(aten.isinf)
def isinf_impl(self: ComplexTensor) -> torch.Tensor:
    """True where the real or the imaginary part is infinite."""
    re, im = split_complex_tensor(self)
    return torch.isinf(re) | torch.isinf(im)


@register_complex(aten.isfinite)
def isfinite_impl(self: ComplexTensor) -> torch.Tensor:
    """True where both the real and the imaginary part are finite."""
    re, im = split_complex_tensor(self)
    return torch.isfinite(re) & torch.isfinite(im)


@register_complex(aten.isclose)
def isclose_impl(
    self: ComplexTensor,
    rhs: ComplexTensor,
    rtol=1e-5,
    atol=1e-8,
    equal_nan: bool = False,
) -> torch.Tensor:
    """Element-wise ``|self - rhs| <= rtol * |rhs| + atol``.

    With ``equal_nan=True``, elements of ``self`` containing a NaN also compare
    close when ``rhs`` has NaNs in exactly the same components.
    """
    abs_diff = torch.abs(self - rhs)
    abs_other = torch.abs(rhs)
    basic_condition = abs_diff <= (rtol * abs_other + atol)

    # This is the nontrivial part
    if equal_nan:
        # `split_complex_arg` so that either operand may be a plain tensor.
        a_r, a_i = split_complex_arg(self)
        b_r, b_i = split_complex_arg(rhs)

        a_r_nan = torch.isnan(a_r)
        b_r_nan = torch.isnan(b_r)
        a_i_nan = torch.isnan(a_i)
        b_i_nan = torch.isnan(b_i)
        a_nan = a_r_nan | a_i_nan

        # This logical expression makes sure that the isnan of both the real and imaginary parts
        # matches (so 1 + nan*i doesn't equal nan + 1*i)
        equal_nan_condition = ((a_r_nan == b_r_nan) & (a_i_nan == b_i_nan)) & a_nan
        return basic_condition | equal_nan_condition

    return basic_condition


# Ops for which an error-raising implementation is registered below.
ERROR_OPS_LIST = [
    aten.lt,
    aten.le,
    aten.gt,
    aten.ge,
    aten.amin,
    aten.amax,
    aten.clamp,
    aten.ceil,
    aten.floor,
    aten.minimum,
    aten.maximum,
    aten.trunc,
    aten.sign,
    aten.argmax,
    aten.argmin,
    aten.sort,
    aten.topk,
    aten.round,
    aten.fmod,
]


# Exception type per op; ops not listed here use `NotImplementedError`.
ERROR_TYPES = {
    aten.minimum: RuntimeError,
    aten.maximum: RuntimeError,
    aten.argmax: RuntimeError,
    aten.argmin: RuntimeError,
    aten.sort: RuntimeError,
    aten.topk: RuntimeError,
}


for err_op in ERROR_OPS_LIST:
    globals()[_get_func_name(err_op)] = register_error(
        err_op, ERROR_TYPES.get(err_op, NotImplementedError)
    )

del err_op


@register_complex(aten.masked_scatter)
def masked_scatter_impl(
    self: ComplexTensor, mask: torch.Tensor, source: ComplexTensor
) -> ComplexTensor:
    """Apply ``torch.masked_scatter`` to the real and imaginary parts separately."""
    self_r, self_i = split_complex_tensor(self)
    source_r, source_i = split_complex_arg(source)
    ret_r = torch.masked_scatter(self_r, mask, source_r)
    ret_i = torch.masked_scatter(self_i, mask, source_i)

    return ComplexTensor(ret_r, ret_i)


@register_complex(aten.where)
def where_impl(mask: torch.Tensor, x: ComplexTensor, y: ComplexTensor) -> ComplexTensor:
    """Apply ``torch.where`` to the real and imaginary parts separately."""
    x_r, x_i = split_complex_arg(x)
    y_r, y_i = split_complex_arg(y)

    ret_r = torch.where(mask, x_r, y_r)
    ret_i = torch.where(mask, x_i, y_i)

    return ComplexTensor(ret_r, ret_i)


@register_complex(aten.full_like)
def full_like_impl(
    input: ComplexTensor,
    fill_value: complex,
    *args,
    dtype: torch.dtype | None = None,
    **kwargs,
) -> torch.Tensor | ComplexTensor:
    """``full_like`` for a ComplexTensor input.

    A non-complex ``dtype`` yields a plain tensor shaped like the real part;
    otherwise the real and imaginary parts are filled from the corresponding
    components of ``fill_value``.
    """
    # Note: Cannot be merged with the cases below due to the `fill_value` argument
    input_r, input_i = split_complex_tensor(input)
    if dtype is not None and dtype not in COMPLEX_TO_REAL:
        return torch.full_like(input_r, fill_value, *args, dtype=dtype, **kwargs)

    if dtype is not None:
        kwargs["dtype"] = COMPLEX_TO_REAL[dtype]

    fv_r, fv_i = split_complex_arg(fill_value)
    ret_r = torch.full_like(input_r, fv_r, *args, **kwargs)
    ret_i = torch.full_like(input_i, fv_i, *args, **kwargs)

    return ComplexTensor(ret_r, ret_i)


def register_like(op: OpType) -> Callable[..., torch.Tensor | ComplexTensor]:
    """Register a ``*_like``-style factory op for ComplexTensor inputs.

    The generated implementation returns a plain tensor when a non-complex
    ``dtype`` is requested, and otherwise applies ``op`` to the real and
    imaginary parts (with ``dtype`` mapped to its real counterpart).
    """

    def impl(
        self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs
    ) -> torch.Tensor | ComplexTensor:
        self_re, self_im = split_complex_tensor(self)

        if dtype is not None and dtype not in COMPLEX_TO_REAL:
            return op(self_re, *args, dtype=dtype, **kwargs)

        if dtype is not None:
            kwargs["dtype"] = COMPLEX_TO_REAL[dtype]

        ret_re = op(self_re, *args, **kwargs)
        ret_im = op(self_im, *args, **kwargs)

        return ComplexTensor(ret_re, ret_im)

    func_name = _get_func_name(op)
    impl.__name__ = func_name
    impl.__qualname__ = func_name

    return register_complex(op, impl)


LIKE_OPS_LIST = [
    aten.empty_like,
    aten.zeros_like,
    aten.randn_like,
    aten.new_zeros,
]

for like_op in LIKE_OPS_LIST:
    globals()[_get_func_name(like_op)] = register_like(like_op)

del like_op


@register_complex(aten.cat)
def cat_impl(tensors: Sequence[ComplexTensor], dim: int = 0) -> ComplexTensor:
    """Concatenate the real parts and the imaginary parts separately along ``dim``."""
    tensors_r = []
    tensors_i = []

    for t in tensors:
        t_r, t_i = split_complex_arg(t)
        tensors_r.append(t_r)
        tensors_i.append(t_i)

    ret_r = torch.cat(tensors_r, dim=dim)
    ret_i = torch.cat(tensors_i, dim=dim)

    return ComplexTensor(ret_r, ret_i)


@register_complex(aten.sgn)
def sgn_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``z / |z|`` element-wise, with 0 where ``z`` is zero."""
    self_r, self_i = split_complex_tensor(self)
    out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
    abs_self = torch.abs(ComplexTensor(self_r, self_i))
    mask = (self_r != 0) | (self_i != 0)
    masked_sgn = ComplexTensor(
        (self_r / abs_self).to(out_dt), (self_i / abs_self).to(out_dt)
    )
    return torch.where(mask, masked_sgn, 0)  # type: ignore[bad-return]


@register_complex(aten.sqrt)
def sqrt_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``sqrt(z)`` in polar form: ``sqrt(|z|)`` at half the angle of ``z``."""
    self_r, self_i = split_complex_tensor(self)
    out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
    self = ComplexTensor(self_r, self_i)
    self_abs_sqrt = torch.sqrt(torch.abs(self))
    self_half_angle = 0.5 * torch.angle(self)

    ret_r = self_abs_sqrt * torch.cos(self_half_angle)
    ret_i = self_abs_sqrt * torch.sin(self_half_angle)

    return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))


@register_complex(aten.rsqrt)
def rsqrt_impl(self: ComplexTensor) -> ComplexTensor:
    """Compute ``1 / sqrt(z)`` in polar form: ``rsqrt(|z|)`` at minus half the angle."""
    self_r, self_i = split_complex_tensor(self)
    out_dt, (self_r, self_i) = promote_tensors(self_r, self_i)
    self = ComplexTensor(self_r, self_i)
    self_abs_rsqrt = torch.rsqrt(torch.abs(self))
    self_neg_half_angle = -0.5 * torch.angle(self)

    ret_r = self_abs_rsqrt * torch.cos(self_neg_half_angle)
    ret_i = self_abs_rsqrt * torch.sin(self_neg_half_angle)

    return ComplexTensor(ret_r.to(out_dt), ret_i.to(out_dt))


@register_complex(aten.addmm)
def addmm_impl(
    input: ComplexTensor,
    mat1: ComplexTensor,
    mat2: ComplexTensor,
    out_dtype: torch.dtype | None = None,
    beta: complex = 1,
    alpha: complex = 1,
) -> ComplexTensor:
    """Compute ``beta * input + alpha * (mat1 @ mat2)``.

    When ``out_dtype`` is given it is looked up in ``COMPLEX_TO_REAL`` and both
    components of the result are cast to the corresponding real dtype.
    """
    ret = beta * input + alpha * torch.mm(mat1, mat2)
    assert isinstance(ret, ComplexTensor)
    ret_r, ret_i = split_complex_tensor(ret)
    if out_dtype is not None:
        out_dtype = COMPLEX_TO_REAL[out_dtype]
        ret_r, ret_i = ret_r.to(out_dtype), ret_i.to(out_dtype)
    return ComplexTensor(ret_r, ret_i)


def elemwise_nonzero(self: ComplexTensor) -> torch.Tensor:
    """Boolean mask that is True where the real or imaginary part is nonzero."""
    re, im = split_complex_tensor(self)
    return (re != 0) | (im != 0)


def register_nonzero_impl(op: OpType):
    """Register a binary logical op that acts on the nonzero masks of its operands."""

    def nonzero_impl(
        self: ComplexTensor, other: ComplexTensor, *args, **kwargs
    ) -> torch.Tensor:
        return op(elemwise_nonzero(self), elemwise_nonzero(other), *args, **kwargs)

    func_name = _get_func_name(op)
    nonzero_impl.__name__ = func_name
    nonzero_impl.__qualname__ = func_name

    return register_complex(op, nonzero_impl)


logical_and_impl = register_nonzero_impl(aten.logical_and)
logical_or_impl = register_nonzero_impl(aten.logical_or)
logical_xor_impl = register_nonzero_impl(aten.logical_xor)


@register_complex(aten.logical_not)
def logical_not_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor:
    """Logical NOT of the element-wise nonzero mask."""
    return torch.logical_not(elemwise_nonzero(self), *args, **kwargs)


@register_complex(aten.view_as_real)
def view_as_real_impl(self: ComplexTensor) -> torch.Tensor:
    """Stack the real and imaginary parts along a new trailing dimension."""
    re, im = split_complex_tensor(self)
    return torch.stack([re, im], dim=-1)

+@register_complex(aten.linalg_vector_norm) +def linalg_vector_norm_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + return torch.linalg.vector_norm(torch.abs(self), *args, **kwargs) + + +@register_force_test(aten.copy_) +def copy__impl(self: ComplexTensor, src, *args, **kwargs): + self_re, self_im = split_complex_tensor(self) + src_re, src_im = split_complex_arg(src) + + ret_re = self_re.copy_(src_re, *args, **kwargs) + ret_im = self_im.copy_(src_im, *args, **kwargs) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten._local_scalar_dense) +def _local_scalar_dense_impl(self: ComplexTensor, *args, **kwargs) -> complex: + x, y = split_complex_tensor(self) + u = aten._local_scalar_dense(x, *args, **kwargs) + v = aten._local_scalar_dense(y, *args, **kwargs) + return complex(u, v) + + +@register_complex(aten.allclose) +def allclose_impl( + input: torch.Tensor, + other: torch.Tensor, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + return torch.all( + torch.isclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan) + ).item() # type: ignore[bad-return] + + +@register_complex(aten.stack) +def stack_impl(self: list[ComplexTensor], *args, **kwargs) -> ComplexTensor: + re_im_tuples = [split_complex_arg(self_i) for self_i in self] + u = torch.stack([c[0] for c in re_im_tuples], *args, **kwargs) + v = torch.stack([c[1] for c in re_im_tuples], *args, **kwargs) + return ComplexTensor(u, v) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj_physical) +@register_complex(aten.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, -im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten._conj) +def _conj_impl(self: ComplexTensor) -> ComplexTensor: + re, im = split_complex_tensor(self) + return ComplexTensor(re, torch._neg_view(im)) + + +@register_complex(aten.index_add) +def 
index_add_impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add(dim, index, source_re) + ret_im = self_im.index_add(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.index_add_) +def index_add__impl( + self: ComplexTensor, dim: int, index: torch.Tensor, source: ComplexTensor, **kwargs +) -> ComplexTensor: + alpha = kwargs.pop("alpha", None) + if alpha is not None: + source = source * alpha + + self_re, self_im = split_complex_arg(self) + source_re, source_im = split_complex_arg(source) + + ret_re = self_re.index_add_(dim, index, source_re) + ret_im = self_im.index_add_(dim, index, source_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.masked_fill) +def masked_fill_impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill(mask, value_re) + ret_im = self_im.masked_fill(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +# TODO (hameerabbasi): Not being tested +@register_complex(aten.masked_fill_) +def masked_fill__impl( + self: ComplexTensor, mask: torch.Tensor, value: complex +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + value_re, value_im = split_complex_arg(value) + + ret_re = self_re.masked_fill_(mask, value_re) + ret_im = self_im.masked_fill_(mask, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.constant_pad_nd) +def constant_pad_nd_impl( + self: ComplexTensor, pad, value: complex | None = None +) -> ComplexTensor: + self_re, self_im = split_complex_tensor(self) + if value is 
None: + ret_re = aten.constant_pad_nd(self_re, pad) + ret_im = aten.constant_pad_nd(self_im, pad) + else: + value_re, value_im = split_complex_arg(value) + ret_re = aten.constant_pad_nd(self_re, pad, value_re) + ret_im = aten.constant_pad_nd(self_im, pad, value_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.var) +def var_impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor: + self_re, self_im = split_complex_tensor(self) + return torch.var(self_re, *args, **kwargs) + torch.var(self_im, *args, **kwargs) + + +@register_complex(aten.scatter_add) +def scatter_add_impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + ret_re = torch.scatter_add(self_re, dim, index, src_re) + ret_im = torch.scatter_add(self_im, dim, index, src_im) + + return ComplexTensor(ret_re, ret_im) + + +@register_complex(aten.scatter_add_) +def scatter_add__impl( + self: ComplexTensor, dim, index, src: ComplexTensor +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + src_re, src_im = split_complex_arg(src) + + out_re = self_re.scatter_add_(dim, index, src_re) + out_im = self_im.scatter_add_(dim, index, src_im) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.index_put_) +def index_put__impl( + self: ComplexTensor, + indices: tuple[torch.Tensor, ...], + values: ComplexTensor, + accumulate: bool = False, +) -> ComplexTensor: + self_re, self_im = split_complex_arg(self) + values_re, values_im = split_complex_arg(values) + + out_re = self_re.index_put_(indices, values_re, accumulate=accumulate) + out_im = self_im.index_put_(indices, values_im, accumulate=accumulate) + + return ComplexTensor(out_re, out_im) + + +@register_complex(aten.tanh_backward) +def tanh_backward(out_grad: torch.Tensor, y: torch.Tensor): + return out_grad * (1.0 - y * y).conj_physical() + + +@register_complex(aten.diagonal_backward) +def 
diagonal_backward( + grad_output: torch.Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int +): + grad_input = grad_output.new_zeros(input_sizes) + return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2) + + +def _dt_to_real(dt: torch.dtype | Any) -> torch.dtype | Any: + if not isinstance(dt, torch.dtype): + return dt + + return COMPLEX_TO_REAL[dt] + + +def register_to_impl(op: OpType): + """Register an op similar to `aten.to`, but may have different signatures.""" + + def impl(self: ComplexTensor, *args, **kwargs) -> torch.Tensor | ComplexTensor: + x, y = split_complex_tensor(self) + try: + args = tuple(_dt_to_real(a) for a in args) + kwargs = {k: _dt_to_real(v) for k, v in kwargs.items()} + except KeyError: + return op(x, *args, **kwargs) + + return ComplexTensor(op(x, *args, **kwargs), op(y, *args, **kwargs)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +to_impl = register_to_impl(aten.to) +_to_copy_impl = register_to_impl(aten._to_copy) diff --git a/torch/_subclasses/complex_tensor/_ops/common.py b/torch/_subclasses/complex_tensor/_ops/common.py new file mode 100644 index 0000000000000..88532efe224bb --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/common.py @@ -0,0 +1,317 @@ +from collections.abc import Callable +from typing import Any, overload, TypeAlias +from typing_extensions import TypeIs + +import torch +from torch import Tensor +from torch._decomp import get_decompositions +from torch._ops import OpOverload, OpOverloadPacket +from torch._refs import is_complex as _is_complex +from torch.types import Number +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from .._core import ComplexTensor + + +OpType: TypeAlias = OpOverloadPacket | OpOverload + +TableType: TypeAlias = dict[OpType, Callable] + +# Mapping from ops to implementations 
COMPLEX_OPS_TABLE: TableType = {}

# Complex dtype -> dtype of its real/imaginary components
COMPLEX_TO_REAL = {
    torch.complex128: torch.float64,
    torch.complex64: torch.float32,
    torch.complex32: torch.float16,
}

REAL_TO_COMPLEX = {v: k for k, v in COMPLEX_TO_REAL.items()}

# Used to promote dtypes in `promote_tensors`
PROMOTE_TYPES = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.complex32: torch.complex64,
}


def is_complex_tensor(obj: Any, /) -> TypeIs[ComplexTensor]:
    r"""Returns True if the input is a ComplexTensor, else False

    Args:
        obj: any input

    Examples:

        >>> # xdoctest: +SKIP
        >>> from torch.complex import ComplexTensor
        >>> data = torch.zeros((3, 2), dtype=torch.complex64)
        >>> ct = ComplexTensor.from_interleaved(data)
        >>> is_complex_tensor(ct)
        True
    """
    return isinstance(obj, ComplexTensor)


@overload
def promote_tensors(
    *tensors: ComplexTensor,
) -> tuple[torch.dtype, tuple[ComplexTensor, ...]]: ...


@overload
def promote_tensors(
    *tensors: Tensor,
) -> tuple[torch.dtype, tuple[Tensor, ...]]: ...


def promote_tensors(
    *tensors: Tensor | ComplexTensor,
) -> tuple[torch.dtype, tuple[Tensor | ComplexTensor, ...]]:
    """
    Promotes all tensors to a common dtype.

    The common dtype of the tensor arguments is returned as the first element.
    The arguments themselves are converted to that dtype after it has been
    widened through `PROMOTE_TYPES` (e.g. `float16` -> `float32`), so callers
    can compute in the wider dtype and cast the result back to the returned
    dtype. Non-tensor arguments are converted with `torch.asarray`.

    Raises:
        ValueError: if none of the arguments is a `Tensor`.
    """
    tensor = next((t for t in tensors if isinstance(t, Tensor)), None)
    if tensor is None:
        # Raise explicitly rather than leaking a bare StopIteration from `next`.
        raise ValueError("`promote_tensors` requires at least one `Tensor` argument.")

    out_dt = tensor.dtype
    for t in tensors:
        if isinstance(t, Tensor):
            out_dt = torch.promote_types(out_dt, t.dtype)

    prom_dt = PROMOTE_TYPES.get(out_dt, out_dt)
    return out_dt, tuple(
        t.to(prom_dt) if isinstance(t, Tensor) else torch.asarray(t, dtype=prom_dt)
        for t in tensors
    )


def register_complex(
    op: OpType,
    func_impl: Callable | None = None,
):
    """Register an implementation of `op` in `COMPLEX_OPS_TABLE`.

    Usable as a decorator (`@register_complex(op)`) or called directly with
    the implementation as the second argument; both forms return the
    implementation unchanged.

    Raises:
        RuntimeError: if a different function is already registered for `op`.
    """

    def inner(func):
        if COMPLEX_OPS_TABLE.get(op, func) is not func:
            raise RuntimeError(f"Attempted to register multiple functions for {op}")
        COMPLEX_OPS_TABLE[op] = func
        return func

    if func_impl is None:
        return inner

    return inner(func_impl)


FORCE_TEST_LIST: list[OpType] = []


def register_force_test(op: OpType, *args, **kwargs):
    """Will attempt to test these ops even if they err on "normal" inputs"""
    FORCE_TEST_LIST.append(op)
    return register_complex(op, *args, **kwargs)


DECOMPOSITIONS = get_decompositions(list(torch.ops.aten))  # type: ignore[no-matching-overload]


def lookup_complex(func: OpOverload, *args, **kwargs) -> Callable | None:
    """
    Lookup an impl from the table.

    Try the particular overload first, then the overload packet.

    If nothing is found, try the decompositions with both.

    Returns `None` when no implementation or decomposition is found.
    The extra positional/keyword arguments are accepted but not used.
    """
    return COMPLEX_OPS_TABLE.get(
        func,
        COMPLEX_OPS_TABLE.get(
            func.overloadpacket,
            DECOMPOSITIONS.get(func, DECOMPOSITIONS.get(func.overloadpacket)),
        ),
    )


def is_complex(x: Any, /) -> bool:
    """Utility to detect if a given object is (known) to be complex."""
    return (isinstance(x, Tensor) and _is_complex(x)) or isinstance(x, complex)


@overload
def split_complex_arg(
    arg: Tensor | ComplexTensor,
) -> tuple[Tensor, Tensor]: ...


@overload
def split_complex_arg(
    arg: complex | Number,
) -> tuple[Number, Number]: ...
+ + +def split_complex_arg( + arg: Tensor | ComplexTensor | complex | Number, +) -> tuple[Tensor, Tensor] | tuple[Number, Number]: + """ + Split a complex argument into a real/imaginary component. + + If real, use zero for the imaginary part. + """ + if isinstance(arg, ComplexTensor): + return split_complex_tensor(arg) + if isinstance(arg, Tensor): + if is_complex(arg): + return arg.real, arg.imag + return arg, torch.zeros_like(arg) + # TODO (hameerabbasi): Should there be a `torch.SymComplex`? + if isinstance(arg, complex): + return arg.real, arg.imag + if isinstance(arg, float | torch.SymFloat): + return arg, 0.0 + if isinstance(arg, int | torch.SymInt): + return arg, 0 + if isinstance(arg, bool | torch.SymBool): + return arg, False + raise TypeError(f"Expected tensor or number got, {type(arg)}") + + +def split_complex_tensor(complex_tensor: ComplexTensor) -> tuple[Tensor, Tensor]: + """Split a ComplexTensor into its real and imaginary parts.""" + return complex_tensor.re, complex_tensor.im + + +def complex_to_real_dtype(dtype: torch.dtype) -> torch.dtype: + """Convert a complex dtype to the dtype of its real part. Return other dtypes as-is.""" + return COMPLEX_TO_REAL.get(dtype, dtype) + + +def _get_op_name(op: OpType) -> str: + """Get the op name from the op.""" + if isinstance(op, OpOverload): + op = op.overloadpacket + return str(op).split(".", 1)[1] + + +def _get_func_name(op: OpType) -> str: + """Get the name of the implementation function from the op.""" + return f"{_get_op_name(op)}_impl" + + +def register_error(op: OpType, exc_type: type[Exception] = NotImplementedError): + msg = f"`aten.{_get_op_name(op)}` not implemented for `{ComplexTensor.__name__}`." 
+ + def ordered_impl(*args, **kwargs): + raise exc_type(msg) + + func_name = _get_func_name(op) + ordered_impl.__name__ = func_name + ordered_impl.__qualname__ = func_name + + return register_force_test(op, ordered_impl) + + +def register_binary_nonlinear(op: OpType) -> Callable: + """Register a "multiplication-style" op, e.g. aten.mul, aten.mm, ...""" + + def impl(lhs: ComplexTensor, rhs: ComplexTensor, *args, **kwargs) -> ComplexTensor: + a_r, a_i = split_complex_arg(lhs) + b_r, b_i = split_complex_arg(rhs) + out_dt, (a_r, a_i, b_r, b_i) = promote_tensors(a_r, a_i, b_r, b_i) + real = op(a_r, b_r, *args, **kwargs) - op(a_i, b_i, *args, **kwargs) + imag = op(a_r, b_i, *args, **kwargs) + op(a_i, b_r, *args, **kwargs) + return ComplexTensor(real.to(out_dt), imag.to(out_dt)) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def register_simple(op: OpType): + """Register an op which can be applied independently to the real and complex parts to get the result.""" + + def impl( + self: ComplexTensor, *args, dtype: torch.dtype | None = None, **kwargs + ) -> ComplexTensor: + x, y = split_complex_tensor(self) + if dtype is not None and dtype not in COMPLEX_TO_REAL: + raise RuntimeError( + "Non-complex `dtype` specified, please write custom impl." 
+ ) + + if dtype in COMPLEX_TO_REAL: + assert dtype is not None + kwargs["dtype"] = COMPLEX_TO_REAL[dtype] + + u = op(x, *args, **kwargs) + v = op(y, *args, **kwargs) + + u_flat, u_spec = tree_flatten(u) + v_flat, v_spec = tree_flatten(v) + assert u_spec == v_spec + out_flat = [ + ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False) + ] + return tree_unflatten(out_flat, u_spec) + + func_name = _get_func_name(op) + impl.__name__ = func_name + impl.__qualname__ = func_name + + return register_complex(op, impl) + + +def _as_complex_tensor(arg: Tensor | Any) -> Tensor | ComplexTensor | Any: + """Convert a Tensor with complex dtypes to a ComplexTensor. Pass along other args as-is.""" + if ( + not isinstance(arg, ComplexTensor) + and isinstance(arg, Tensor) + and arg.dtype in COMPLEX_TO_REAL + ): + return ComplexTensor.from_interleaved(arg) + return arg + + +def _as_interleaved(arg: ComplexTensor | Any) -> Tensor | Any: + """Convert a ComplexTensor to a Tensor with a complex dtype. Pass other arguments as-is.""" + if isinstance(arg, ComplexTensor): + return arg.as_interleaved() + return arg + + +class ComplexTensorMode(TorchDispatchMode): + _compile: bool + + """ A TorchDispatchMode to replace any Tensor that has a complex dtype with a ComplexTensor for the computation. """ + + def __init__(self, _dispatch_key=None, *, _compile: bool = False): + """Initialize a ComplexTensorMode. 
+ + Args: + _dispatch_key: passed on to TorchDispatchMode + _compile: Compile the op before the computation + """ + super().__init__(_dispatch_key) + self._compile = _compile + + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[type], + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ): + if kwargs is None: + kwargs = {} + + # TODO (hameerabbasi): Test perf with `_compile` set to `True` + if self._compile: + func = torch.compile(func) # type: ignore[bad-assignment] + + args = tree_map(_as_complex_tensor, args) + kwargs = tree_map(_as_complex_tensor, kwargs) + + return tree_map(_as_interleaved, func(*args, **kwargs)) diff --git a/torch/_subclasses/complex_tensor/_ops/prims.py b/torch/_subclasses/complex_tensor/_ops/prims.py new file mode 100644 index 0000000000000..9a237b32d9904 --- /dev/null +++ b/torch/_subclasses/complex_tensor/_ops/prims.py @@ -0,0 +1,34 @@ +import torch + +from .._core import ComplexTensor +from .common import ( + complex_to_real_dtype, + register_complex, + register_force_test, + split_complex_tensor, +) + + +prims = torch.ops.prims +aten = torch.ops.aten + + +# TODO (hameerabbasi): Not being tested +@register_force_test(prims.convert_element_type) +def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor: + dtype = complex_to_real_dtype(dtype) + u, v = split_complex_tensor(x) + u_out = prims.convert_element_type(u, dtype) + v_out = prims.convert_element_type(v, dtype) + + return ComplexTensor(u_out, v_out) + + +@register_complex(prims.conj_physical) +def conj_physical_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj_physical(self) + + +@register_complex(prims.conj) +def conj_impl(self: ComplexTensor) -> ComplexTensor: + return aten._conj(self) diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index d580809460811..adba98beb2724 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -138,7 +138,7 @@ inline void PyErr_SetString(PyObject* type, 
const std::string& message) { throw; \ } \ } \ - catch (const std::exception& e) { \ + catch (const std::exception&) { \ torch::translate_exception_to_python(std::current_exception()); \ return retval; \ } diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index a13cc70270ccb..7470344cc05f7 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -390,31 +390,27 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ torch::profiler::impl::ActivityType::CPU}; -#if defined(USE_KINETO) && \ - (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) - if (at::hasMTIA()) { - activities.insert(torch::profiler::impl::ActivityType::MTIA); - } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } +#if defined(USE_KINETO) +#if (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) if (at::getNumGPUs() > 0) { activities.insert(torch::profiler::impl::ActivityType::CUDA); } -#elif defined(USE_KINETO) +#endif // (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) +#if (!defined(LIBKINETO_NOXPUPTI)) if (at::hasXPU()) { activities.insert(torch::profiler::impl::ActivityType::XPU); } - if (at::hasHPU()) { - activities.insert(torch::profiler::impl::ActivityType::HPU); - } +#endif // (!defined(LIBKINETO_NOXPUPTI)) if (at::hasMTIA()) { activities.insert(torch::profiler::impl::ActivityType::MTIA); } + if (at::hasHPU()) { + activities.insert(torch::profiler::impl::ActivityType::HPU); + } if (c10::get_privateuse1_backend() != "privateuseone") { activities.insert(torch::profiler::impl::ActivityType::PrivateUse1); } -#endif +#endif // defined(USE_KINETO) return activities; }); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 6d0bf5d0a8579..de7f3dc53c323 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp 
@@ -1200,25 +1200,27 @@ get_thread_local_native_sharding_propagator_cache() { py::reinterpret_borrow(PyThreadState_GetDict()); // We need to clean up before Python detaches from the thread if // the thread is being destroyed. - thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = - py::capsule(new std::thread::id(this_thread_id), [](void* p) { - auto* ptid = reinterpret_cast(p); - { - std::lock_guard inner_lock( - native_sharding_propagator_cache_cleanup_mutex); - auto it = all_thread_caches.find(*ptid); - if (it != all_thread_caches.end()) { - // We need to both: - // 1) free python objects, and - it->second->reset(); - // 2) make sure we don't try to come back and mess with - // a destroyed thread-local at module unload (e.g., - // process exit) time. - all_thread_caches.erase(it); + if (!thread_dict.contains("__DTensor_fastpath_thread_cache_cleanup")) { + thread_dict["__DTensor_fastpath_thread_cache_cleanup"] = + py::capsule(new std::thread::id(this_thread_id), [](void* p) { + auto* ptid = reinterpret_cast(p); + { + std::lock_guard inner_lock( + native_sharding_propagator_cache_cleanup_mutex); + auto it = all_thread_caches.find(*ptid); + if (it != all_thread_caches.end()) { + // We need to both: + // 1) free python objects, and + it->second->reset(); + // 2) make sure we don't try to come back and mess with + // a destroyed thread-local at module unload (e.g., + // process exit) time. 
+ all_thread_caches.erase(it); + } } - } - delete ptid; - }); + delete ptid; + }); + } } return native_sharding_propagator_cache_DO_NOT_USE.value(); } diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 9f79a09d236e5..b888e315021ac 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -81,7 +81,7 @@ c10::intrusive_ptr ProcessGroup::getBackend( ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED}; try { backendType = deviceTypeToBackendType_.at(deviceType); - } catch (const std::out_of_range& e) { + } catch (const std::out_of_range&) { TORCH_CHECK( false, "No backend type associated with device type ", deviceType); } diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index f3ff9e623043e..7427848b8445b 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -246,7 +246,7 @@ class UvTcpServer : public UvTcpSocket { uv_err_name(uv_res), uv_strerror(uv_res))); res->cacheSocketPort(); - } catch (std::exception& ex) { + } catch (std::exception&) { res->close(); throw; } @@ -322,7 +322,7 @@ class UvTcpServer : public UvTcpSocket { uv_err_name(uv_res), uv_strerror(uv_res))); res->cacheSocketPort(); - } catch (std::exception& ex) { + } catch (std::exception&) { res->close(); throw; } diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp index 10274d053b995..fe8f831a23bb1 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.cpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) { res.setStatus(200); }}; +RegisterHandler frTracehandler( + "fr_trace_json", + 
[](const Request&, Response& res) { + auto trace = ::c10d::dump_fr_trace_json(true, true); + res.setContent(std::move(trace), "application/json"); + res.setStatus(200); + }); + } // namespace void registerHandler(const std::string& name, HandlerFunc f) { diff --git a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp index 70333a3a4844c..58ae9368ea212 100644 --- a/torch/csrc/distributed/c10d/control_plane/Handlers.hpp +++ b/torch/csrc/distributed/c10d/control_plane/Handlers.hpp @@ -18,6 +18,14 @@ class TORCH_API Request { virtual const std::string& body() const = 0; virtual const std::multimap& params() const = 0; + + std::string getParam(const std::string& key) const { + auto it = params().find(key); + if (it != params().end()) { + return it->second; + } + return ""; + } }; // Response represents a response to the handler. This conceptually maps to an diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 8bbe857620790..eda6ee3a91488 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) { TORCH_CHECK( server_.bind_to_port(hostOrFile, 80), fmt::format("Error binding to {}", hostOrFile)); + } else if (port == 0) { + C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); + port_ = server_.bind_to_any_port(hostOrFile); + TORCH_CHECK( + port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port)); } else { C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port); TORCH_CHECK( server_.bind_to_port(hostOrFile, port), fmt::format("Error binding to {}:{}", hostOrFile, port)); + port_ = port; } serverThread_ = std::thread([this]() { diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp 
b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp index 41c1356fc01f3..20d05b7509e92 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp @@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target { void shutdown(); + int port() { + return port_; + } + private: httplib::Server server_; std::thread serverThread_; + int port_; }; } // namespace c10d::control_plane diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 94a8c0bbe228b..255e793eaa4df 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -46,6 +46,7 @@ #include #include +#include #include #include #include @@ -4209,7 +4210,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. }), py::arg("host_or_file"), py::arg("port") = -1) - .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown); + .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown) + .def_property_readonly( + "port", &::c10d::control_plane::WorkerServer::port); module.def( "_get_handler", @@ -4225,6 +4228,25 @@ such as `dist.all_reduce(tensor, async_op=True)`. Returns the handler with the specified name. )"); + module.def( + "_register_handler", + [](const std::string& name, const py::function& handler) { + ::c10d::control_plane::registerHandler( + name, + [handler]( + const ::c10d::control_plane::Request& req, + ::c10d::control_plane::Response& res) { + py::gil_scoped_acquire acquire; + handler(std::ref(req), std::ref(res)); + }); + }, + + py::arg("name"), + py::arg("handler"), + R"( + Registers a handler by name. + )"); + module.def( "_get_handler_names", &::c10d::control_plane::getHandlerNames, @@ -4242,12 +4264,9 @@ such as `dist.all_reduce(tensor, async_op=True)`. // Default constructor. 
.def(py::init<>()) .def("body", &::c10d::control_plane::Request::body) - .def("params", &::c10d::control_plane::Request::params); + .def("get_param", &::c10d::control_plane::Request::getParam); - py::class_< - ::c10d::control_plane::Response, - std::shared_ptr<::c10d::control_plane::Response>, - PythonResponse>( + py::class_<::c10d::control_plane::Response, PythonResponse>( module, "_Response", R"( diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index f83d42df4ac68..6352330c3872c 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -517,11 +517,6 @@ static void init_multicast_for_block( using McHandleType = std::conditional_t; - McHandleType invalidator; - std::memset(&invalidator, UINT8_MAX, sizeof(McHandleType)); - - // Phase 1: export handle (rank 0 only) - McHandleType mc_exported_handle{}; if (rank == 0) { CUmulticastObjectProp mc_prop{}; mc_prop.numDevices = world_size; @@ -530,82 +525,68 @@ static void init_multicast_for_block( // create a multicast object, which acts as a handle that allows multiple // devices or processes to access the same memory allocation coherently. - try { - C10_CUDA_DRIVER_CHECK( - driver_api->cuMulticastCreate_(&mc_handle, &mc_prop)); - // using the CUDA Driver API to export a multicast object into a POSIX file - // descriptor. 
- C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( - &mc_exported_handle, mc_handle, handleType, 0)); - } catch (const std::exception& e) { - // Allow peers gracefully skip multicast initialization by sending -1 - mc_exported_handle = invalidator; + auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop); + if (err != CUDA_SUCCESS) { + const char* err_str; + CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str); + if (get_error_str_err != CUDA_SUCCESS) { + err_str = "unknown cuda driver error"; + } LOG(WARNING) - << "SymmetricMemory: fail to export multicast handle.\n" - << e.what(); + << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str + << "\". Gracefully skipping multicast initialization. " + << "However, this is unexpected. Please report the issue on GitHub."; + // Allow peers gracefully skip multicast initialization by sending -1 + // TODO: allow graceful skip for fabric + if constexpr (!use_fabric_handle) { + ipc_channel.broadcast_fds(rank, 0, pids, -1); + } + return; } - } - - // Phase 2: Exchange handle - McHandleType recv_handle; - if constexpr (!use_fabric_handle) { - recv_handle = ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle); - } else { - // TODO implement storeExchange.broadcast - auto gathered_handles = storeExchange.all_gather(store, rank, world_size, mc_exported_handle); - recv_handle = std::move(gathered_handles[0]); - } - - // Check exchange result - if (memcmp(&recv_handle, &invalidator, sizeof(McHandleType)) == 0) { - LOG(WARNING) << "Gracefully skipping multicast initialization."; - return; - } - // Flip to true after all CUDA steps finish - bool success_end = false; + McHandleType mc_exported_handle; + // using the CUDA Driver API to export a multicast object into a POSIX file + // descriptor. 
+ C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_( + &mc_exported_handle, mc_handle, handleType, 0)); + if constexpr (!use_fabric_handle) { + ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle); + // Ref count is incremented as soon as SCM_RIGHTS send happens + close(mc_exported_handle); + } else { + // TODO implement storeExchange.broadcast + storeExchange.all_gather(store, rank, world_size, mc_exported_handle); + } - // Phase 3: Import handle (non-0 ranks only) - if (rank != 0) { + } else { if constexpr (!use_fabric_handle) { + int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1); + if (mc_fd == -1) { + return; + } // Convert back to a handle from the broadcasted POSIX file descriptor. - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_( + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( &mc_handle, - (void*)(uintptr_t)recv_handle, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), check_all); + (void*)(uintptr_t)mc_fd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(mc_fd); } else { - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_( - &mc_handle, (void*)&(recv_handle), CU_MEM_HANDLE_TYPE_FABRIC), check_all); + CUmemFabricHandle null_handle{}; + auto mc_handles = + storeExchange.all_gather(store, rank, world_size, null_handle); + C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( + &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC)); } } - // Phase 4: Bind memory // All rank adds their physical allocation to the multicast object - C10_CUDA_DRIVER_CHECK_GOTO( - driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx), check_all); - C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMulticastBindMem_( - mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0), check_all); - - success_end = true; - -check_all: - // Whether all ranks have succeeded - bool all_succeed = true; - auto rank_successes = storeExchange.all_gather(store, rank, 
world_size, success_end); - for (int r = 0; r < world_size; ++r) { - all_succeed &= rank_successes[r]; - } - // Close the file descriptor before exit - if constexpr (!use_fabric_handle) { - close(recv_handle); - } - if (!all_succeed) { - LOG(WARNING) << "Gracefully skipping multicast initialization."; - return; - } + C10_CUDA_DRIVER_CHECK( + driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx)); + C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_( + mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0)); - // Phase 5: Map to virtual memory map_block(&mc_addr, mc_handle, block->block_size, block->device_idx); + storeExchange.barrier(store, rank, world_size); #endif } diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index 11659cc24eb89..117324796e7f8 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -353,7 +353,7 @@ static PyObject* NodeBase__update_args_kwargs( Py_CLEAR(node->_kwargs); node->_kwargs = map_aggregate(args[1], visit_fn); Py_RETURN_NONE; - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; } } @@ -397,7 +397,7 @@ static PyObject* NodeBase__replace_input_with( PyObject* update_args[2] = {new_args.get(), new_kwargs.get()}; return NodeBase__update_args_kwargs(self, update_args, 2); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; } } @@ -802,7 +802,7 @@ static PyObject* py_map_aggregate( // args[0]: aggregate, args[1]: callable fn return map_aggregate( args[0], [fn](PyObject* a) { return PyObject_CallOneArg(fn, a); }); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; // error should already be set } } @@ -824,7 +824,7 @@ static PyObject* py_map_arg( } return Py_NewRef(a); }); - } catch (const PythonError& e) { + } catch (const PythonError&) { return nullptr; // error should already be set } } diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 
4fb746ea15271..2eda2b218e705 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -38,9 +38,9 @@ // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check -#include -#include -#include +#include +#include +#include #ifdef __cplusplus extern "C" { diff --git a/torch/csrc/jit/python/pybind.h b/torch/csrc/jit/python/pybind.h index 066ff7f77f56c..845beb540c9f1 100644 --- a/torch/csrc/jit/python/pybind.h +++ b/torch/csrc/jit/python/pybind.h @@ -117,7 +117,7 @@ struct type_caster { try { value = torch::jit::toTypeInferredIValue(src); return true; - } catch (std::exception& e) { + } catch (std::exception&) { return false; } } @@ -142,7 +142,7 @@ struct type_caster { std::string src_str; try { src_str = py::cast(src); - } catch (std::exception& e) { + } catch (std::exception&) { return false; } value = torch::jit::Symbol::fromQualString(src_str); diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 15ac8e539e76b..0e09eeb7f7b14 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -281,11 +281,11 @@ struct FromImpl> { TORCH_ERROR_CODE_CHECK( torch_new_list_reserve_size(val.size(), &new_list_handle)); for (const auto& elem : val) { - TORCH_ERROR_CODE_CHECK( - torch_list_push_back(new_list_handle, from(elem))); + TORCH_ERROR_CODE_CHECK(torch_list_push_back( + new_list_handle, torch::stable::detail::from(elem))); } - return from(new_list_handle); - } catch (const std::runtime_error& e) { + return torch::stable::detail::from(new_list_handle); + } catch (const std::runtime_error&) { if (new_list_handle != nullptr) { // clean up memory if an error was thrown TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle)); @@ -553,7 +553,7 @@ struct ToImpl> { } TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); return result; - } catch (const std::runtime_error& e) { + } 
catch (const std::runtime_error&) { // clean up memory if an exception is thrown, and rethrow TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); throw; diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index c186694df94e7..db03d26227911 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -76,7 +76,11 @@ from torch.nested._internal.nested_int import NestedIntNode from torch.utils import _pytree as pytree from torch.utils._mode_utils import no_dispatch -from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode_stack, + return_and_correct_aliasing, + TorchDispatchMode, +) from torch.utils.checkpoint import get_device_states, set_device_states @@ -86,6 +90,12 @@ from . import _c10d +def _is_in_fake_tensor_mode() -> bool: + return any( + isinstance(mode, FakeTensorMode) for mode in _get_current_dispatch_mode_stack() + ) + + def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: return ( isinstance(op, OpOverload) @@ -256,21 +266,31 @@ def _for_each_rank_run_func( a.wait() if isinstance(a, AsyncCollectiveTensor) else a for a in flat_args ] - # NB: Before invoking an op we are collecting rng states from CPU and - # CUDA devices such that we can reset to the same before invoking op - # for each rank. This is not very efficient and will likely be revisited - # to support per rank rng state. 
- rng_state = _get_rng_state() + lm = enabled_local_tensor_mode() + use_per_rank_rng = lm is not None and len(lm._per_rank_rng_states) > 0 + + global_rng_state = None if use_per_rank_rng else _get_rng_state() + flat_rank_rets = {} default_value: Tensor | None = None for r in sorted(ranks): - _set_rng_state(*rng_state) + if use_per_rank_rng: + assert lm is not None + _set_rng_state(*lm._per_rank_rng_states[r]) + else: + assert global_rng_state is not None + _set_rng_state(*global_rng_state) + rank_flat_args = [_map_to_rank_local_val(a, r) for a in flat_args] rank_args, rank_kwargs = pytree.tree_unflatten(rank_flat_args, args_spec) rank_ret = func(*rank_args, **rank_kwargs) flat_rank_rets[r] = rank_ret + if use_per_rank_rng: + assert lm is not None + lm._per_rank_rng_states[r] = _get_rng_state() + if default_value is None and func is torch.ops.aten.split.Tensor: # If split happens over the dimension smaller than the number of chunks # it is possible that some ranks will produce shorter lists of chunks. @@ -437,6 +457,247 @@ def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode": return ConstantIntNode(num) +class _LocalDeviceHandle: + """ + Wrapper around device module (e.g., torch.cuda) with automatic LocalTensor semantics. + + This class wraps device modules and automatically handles per-rank operations in + LocalTensor mode: + - get_rng_state() returns a LocalTensor with per-rank states + - set_rng_state(LocalTensor) sets per-rank states + + When not in LocalTensor mode, it delegates directly to the underlying device handle. + """ + + def __init__(self, device_handle, device_type: str): + """ + Initialize the local device handle wrapper. + + Args: + device_handle: The underlying device module (e.g., torch.cuda) + device_type: Device type string (e.g., "cuda", "cpu") + """ + self._device_handle = device_handle + self._device_type = device_type + + def get_rng_state(self): + """ + Get RNG state, automatically returning LocalTensor in LocalTensor mode. 
+ + Returns: + LocalTensor in LocalTensor mode, regular Tensor otherwise + """ + lm = enabled_local_tensor_mode() + if not lm: + return self._device_handle.get_rng_state() + + original_state = _get_rng_state() + per_rank_states = {} + + try: + for rank in lm.ranks: + # We need to set-then-get instead of directly copying lm._per_rank_rng_states[rank] + # because they have different structures: + # - lm._per_rank_rng_states[rank] is a tuple: (cpu_state, {device_idx: cuda_state}) + # - self._device_handle.get_rng_state() returns just the device-specific tensor + # So we temporarily restore the full RNG state (CPU + all CUDA devices) for this rank, + # then extract only the specific device's state tensor that we need. + if rank in lm._per_rank_rng_states: + _set_rng_state(*lm._per_rank_rng_states[rank]) + + per_rank_states[rank] = self._device_handle.get_rng_state() + finally: + _set_rng_state(*original_state) + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return LocalTensor(per_rank_states) + + def set_rng_state(self, state): + """ + Set RNG state, automatically handling LocalTensor input. + + Args: + state: Regular Tensor or LocalTensor with per-rank states + """ + if isinstance(state, LocalTensor): + lm = enabled_local_tensor_mode() + assert lm is not None + + # Similar to get_rng_state but in reverse: we need to convert from + # device-specific tensor format to full state tuple format. + # - state._local_tensors[rank] contains just the device-specific RNG state tensor + # - lm._per_rank_rng_states[rank] needs a tuple: (cpu_state, {device_idx: cuda_state}) + # So we set the device's state with the rank-specific tensor, then _get_rng_state() + # captures both CPU and CUDA states into the tuple format that _per_rank_rng_states expects. 
+ for rank, rank_state in state._local_tensors.items(): + self._device_handle.set_rng_state(rank_state.to("cpu")) + lm._per_rank_rng_states[rank] = _get_rng_state() + else: + self._device_handle.set_rng_state(state.to("cpu")) + + def __getattr__(self, name): + """Delegate all other attributes to the underlying device module.""" + return getattr(self._device_handle, name) + + +class _LocalOffsetBasedRNGTracker: + """ + LocalTensor-specific RNG tracker for DTensor random operations. + + This class manages per-rank RNG states when running in LocalTensor mode, + using _LocalPhiloxState to track different offsets for each virtual rank. + It is instantiated and used by OffsetBasedRNGTracker when in LocalTensor mode. + + Much of this is derived from OffsetBasedRNGTracker: + https://github.com/pytorch/pytorch/blob/402c46503002f98ccfc023a733081fb0719223a1/torch/distributed/tensor/_random.py#L182 + """ + + def __init__(self, device_type: str = "cuda"): + """Initialize the LocalTensor RNG tracker.""" + from torch.distributed.device_mesh import _get_device_handle + + self._device_type = device_type + self._device_handle = _LocalDeviceHandle( + _get_device_handle(device_type), device_type + ) + self.distribute_region_enabled = True + self._device_mesh = None + + @property + def _device(self): + return torch.device(self._device_type, torch.cuda.current_device()) + + def _set_pre_op_offset(self, state, spec) -> None: + """Compute and set per-rank offsets before the random operation.""" + from torch.distributed.tensor._ops.utils import prod + from torch.distributed.tensor._utils import ( + _compute_local_shape_and_global_offset, + ) + from torch.distributed.tensor.placement_types import Shard + + lm = enabled_local_tensor_mode() + assert lm is not None + + state._per_rank_offsets = {} + + for rank in lm.ranks: + # compute this rank's coordinate in the mesh + mesh_coords = [] + for mesh_dim_idx in range(spec.mesh.ndim): + mesh_dim_size = spec.mesh.size(mesh_dim_idx) + # calculate 
rank's coordinate in this mesh dimension + num_chunks_after = 1 + for j in range(mesh_dim_idx + 1, spec.mesh.ndim): + num_chunks_after *= spec.mesh.size(j) + coord = (rank // num_chunks_after) % mesh_dim_size + mesh_coords.append(coord) + + # compute local shape and global offset for this rank + local_shape, global_offset = _compute_local_shape_and_global_offset( + spec.shape, spec.mesh.shape, mesh_coords, spec.placements + ) + + # compute shard offset based on placements + shard_offset = 1 + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + shard_offset *= global_offset[shard_dim] + 1 + + # get current offset for this rank + current_offset = int( + state._per_rank_states[rank][8:].view(dtype=torch.int64).item() + ) + + # compute local size + local_size = prod(local_shape) + + # compute new offset (must be multiple of 4) + shard_linear_idx = shard_offset - 1 + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + state._per_rank_offsets[rank] = current_offset + offset_incr + + def _set_post_op_offset(self, state, spec, old_offset) -> None: + """Set per-rank offsets after the random operation.""" + from torch.distributed.tensor._ops.utils import prod + + lm = enabled_local_tensor_mode() + assert lm is not None + + dtensor_shape = spec.shape + numel = prod(dtensor_shape) + # offset must be multiple of 4 + numel = (numel + 3) // 4 * 4 + + if not hasattr(state, "_per_rank_offsets"): + state._per_rank_offsets = {} + + # handle LocalIntNode old_offset (different values per rank) + if isinstance(old_offset, SymInt) and isinstance(old_offset.node, LocalIntNode): + for rank in lm.ranks: + rank_old_offset = old_offset.node._local_ints[rank] + state._per_rank_offsets[rank] = rank_old_offset + numel + else: + # same old_offset for all ranks + old_offset_int = ( + int(old_offset) if isinstance(old_offset, SymInt) else old_offset + ) + for rank in lm.ranks: + state._per_rank_offsets[rank] = old_offset_int + 
numel + + @contextlib.contextmanager + def _distribute_region(self, spec, generator=None): + """Context manager for LocalTensor mode distribute region.""" + lm = enabled_local_tensor_mode() + assert lm is not None + + # get base state + if generator is not None: + base_state_tensor = generator.get_state() + per_rank_states = {rank: base_state_tensor.clone() for rank in lm.ranks} + # pyrefly: ignore [bad-argument-type, bad-argument-count] + base_state_tensor = LocalTensor(per_rank_states) + else: + base_state_tensor = self._device_handle.get_rng_state() + + state = _LocalPhiloxState(base_state_tensor) + + if self.distribute_region_enabled: + # sync to rank 0's state if no explicit generator + if generator is None: + rank_0_state = lm._per_rank_rng_states[0] + rank_0_cpu, rank_0_cuda = rank_0_state + + if self._device.type == "cuda": + assert self._device.index in rank_0_cuda + rank_0_device_state = rank_0_cuda[self._device.index] + else: + rank_0_device_state = rank_0_cpu + + from torch.distributed.tensor._random import _PhiloxState + + rank_0_philox = _PhiloxState(rank_0_device_state) + state.seed = rank_0_philox.seed + state.offset = rank_0_philox.offset + + old_offset = state.offset + self._set_pre_op_offset(state, spec) + state.apply_to_local_tensor_mode(self._device_handle) + + try: + yield + finally: + self._set_post_op_offset(state, spec, old_offset) + state.apply_to_local_tensor_mode(self._device_handle) + else: + yield + + # maybe reset generator to rank 0's state + if generator is not None: + rank_0_state = state._per_rank_states[0] + generator.set_state(rank_0_state) + + _LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_" @@ -597,6 +858,7 @@ def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor": local_tensors_copy = { r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items() } + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors_copy, self.requires_grad) def __repr__(self) -> str: # type: 
ignore[override] @@ -636,6 +898,7 @@ def __tensor_unflatten__( local_tensors = { _from_local_tensor_attr(a): t for a, t in inner_tensors.items() } + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors) @classmethod @@ -774,12 +1037,28 @@ def __init__(self, ranks: Union[int, frozenset[int]]): self.ranks = ranks self._disable = False self._old_get_coordinate = None + self._old_torch_manual_seed: Any = None + self._old_torch_initial_seed: Any = None + self._per_rank_rng_states: dict[ + int, tuple[torch.Tensor, dict[int, torch.Tensor]] + ] = {} def __enter__(self) -> "LocalTensorMode": self._disable = False self._patch_device_mesh() + self._patch_random_functions() _LOCAL_TENSOR_MODE.append(self) + # _distribute_region will compute correct per-shard offsets + # but we want all ranks to start with the same state + if not _is_in_fake_tensor_mode(): + cpu_state, cuda_states = _get_rng_state() + for rank in self.ranks: + self._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + return super().__enter__() def __exit__( @@ -790,6 +1069,7 @@ def __exit__( ) -> None: self._disable = True self._unpatch_device_mesh() + self._unpatch_random_functions() _LOCAL_TENSOR_MODE.pop() super().__exit__(exc_type, exc_val, exc_tb) @@ -936,6 +1216,7 @@ def tensor_map( m = cb(r, tensor._local_tensors[r]) if m is not None: results[r] = m + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(results) def _patch_device_mesh(self) -> None: @@ -949,6 +1230,87 @@ def _unpatch_device_mesh(self) -> None: # pyrefly: ignore [bad-assignment] self._old_get_coordinate = None + def _patch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is None: + self._old_torch_manual_seed = torch.random.manual_seed + torch.random.manual_seed = _LocalRandom.torch_manual_seed + 
torch.manual_seed = _LocalRandom.torch_manual_seed + + if self._old_torch_initial_seed is None: + self._old_torch_initial_seed = torch.random.initial_seed + torch.random.initial_seed = _LocalRandom.torch_initial_seed + torch.initial_seed = _LocalRandom.torch_initial_seed + + def _unpatch_random_functions(self) -> None: + import torch.random + from torch.distributed.tensor import _random as dtensor_random + + if self._old_torch_manual_seed is not None: + torch.random.manual_seed = self._old_torch_manual_seed + torch.manual_seed = self._old_torch_manual_seed + self._old_torch_manual_seed = None + + if self._old_torch_initial_seed is not None: + torch.random.initial_seed = self._old_torch_initial_seed + torch.initial_seed = self._old_torch_initial_seed + self._old_torch_initial_seed = None + + +class _LocalRandom: + """ + Holds implementations of random functionality that must be patched while running + under LocalTensorMode. + """ + + @staticmethod + def torch_manual_seed(seed) -> torch._C.Generator: + """LocalTensor-aware version of torch.random.manual_seed.""" + if ( + (lm := enabled_local_tensor_mode()) + and isinstance(seed, torch.SymInt) + and isinstance(seed.node, LocalIntNode) + ): + from torch.random import _manual_seed_impl + + for rank in sorted(lm.ranks): + rank_seed = seed.node._local_ints[rank] + _manual_seed_impl(rank_seed, update_local_tensor_states=False) + lm._per_rank_rng_states[rank] = _get_rng_state() + return torch.random.default_generator + from torch.random import _manual_seed_impl + + result = _manual_seed_impl(seed, update_local_tensor_states=False) + + if lm is not None and len(lm._per_rank_rng_states) > 0: + cpu_state, cuda_states = _get_rng_state() + for rank in lm.ranks: + lm._per_rank_rng_states[rank] = ( + cpu_state.clone(), + {idx: state.clone() for idx, state in cuda_states.items()}, + ) + + return result + + @staticmethod + def torch_initial_seed(): + """LocalTensor-aware version of torch.random.initial_seed.""" + if lm := 
enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + return torch.random.default_generator.initial_seed() + rank_seeds = {} + + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = torch.random.default_generator.initial_seed() + + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + + return torch.random.default_generator.initial_seed() + class _LocalDeviceMesh: """ @@ -963,7 +1325,7 @@ def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]: # doing this because when submesh is created it is created for a particular # rank (therefore below we are patching get_rank method). We are trying to # limit the invasiveness of local tensor. - lm = local_tensor_mode() + lm = enabled_local_tensor_mode() assert lm is not None, "Unexpectedly not in LocalTensorMode" coords: list[dict[int, int]] = [{} for _ in range(self.ndim)] @@ -1024,6 +1386,22 @@ def local_tensor_mode() -> Optional[LocalTensorMode]: return None +def enabled_local_tensor_mode() -> Optional[LocalTensorMode]: + """ + Returns the current active LocalTensorMode only if it's enabled. + + This is a convenience function that combines the common pattern of checking + if local_tensor_mode() is not None and not disabled. + + Returns: + Optional[LocalTensorMode]: The current LocalTensorMode if active and enabled, else None. 
+ """ + lm = local_tensor_mode() + if lm is not None and not lm._disable: + return lm + return None + + def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]: """ Decorator that ensures a function is executed for each local tensor shard @@ -1048,8 +1426,7 @@ def maybe_run_for_local_tensor(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] - lm = local_tensor_mode() - if lm is None or lm._disable: + if not (lm := enabled_local_tensor_mode()): return func(*args, **kwargs) ret = None with lm.disable(): @@ -1068,6 +1445,73 @@ def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: return lm.disable() if lm is not None else contextlib.nullcontext() +def maybe_enable_local_tracker( + device_type: str, distribute_region_enabled: bool, spec, generator +): + """ + Returns a context manager for LocalTensor-mode RNG tracking if local tensor mode is enabled. + + Args: + device_type: The device type (e.g., "cuda", "cpu") + distribute_region_enabled: Whether distribute region is enabled + spec: The DTensorSpec + generator: Optional torch.Generator + + Returns: + Context manager from local_tracker._distribute_region if local tensor mode is enabled, + otherwise None. + """ + if enabled_local_tensor_mode(): + local_tracker = _LocalOffsetBasedRNGTracker(device_type) + local_tracker.distribute_region_enabled = distribute_region_enabled + return local_tracker._distribute_region(spec, generator) + + return None + + +def get_generator_seed_for_device_type(device_type: str): + """ + Gets the generator seed for a specific device type, handling LocalTensor mode appropriately. 
+ + Args: + device_type: The device type (e.g., "cuda", "cpu") + + Returns: + If in LocalTensor mode with per-rank RNG states: + - Returns int if all ranks have the same seed + - Returns SymInt(LocalIntNode) if ranks have different seeds + Otherwise: + - Returns int seed from the device's RNG state + """ + if lm := enabled_local_tensor_mode(): + if len(lm._per_rank_rng_states) == 0: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + device_module = torch.get_device_module(device_type) + + original_state = _get_rng_state() + + rank_seeds = {} + try: + for rank in sorted(lm.ranks): + _set_rng_state(*lm._per_rank_rng_states[rank]) + rank_seeds[rank] = int( + device_module.get_rng_state()[:8].view(torch.int64).item() + ) + finally: + # restore original state + _set_rng_state(*original_state) + + unique_seeds = set(rank_seeds.values()) + if len(unique_seeds) == 1: + return next(iter(unique_seeds)) + local_int_node = LocalIntNode(rank_seeds) + return torch.SymInt(local_int_node) + else: + device_module = torch.get_device_module(device_type) + return device_module.get_rng_state()[:8].view(torch.int64).item() + + import threading from queue import Queue @@ -1183,3 +1627,114 @@ def current() -> "LocalRunnerMode": global _LOCAL_RUNNER_MODE assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" return _LOCAL_RUNNER_MODE + + +class _LocalPhiloxState: + """ + LocalTensor-aware version of _PhiloxState that manages per-rank RNG states. + This class handles the case where the generator state is a LocalTensor, allowing + different offsets and seeds for different virtual ranks. + + Note: This is designed to be used as a drop-in replacement for _PhiloxState + when working with LocalTensors in the DTensor random ops implementation. 
+ """ + + def __init__(self, state: torch.Tensor): + assert isinstance(state, LocalTensor), ( + "_LocalPhiloxState requires a LocalTensor" + ) + self._local_tensor = state + self._per_rank_states = { + rank: local_state.to("cpu") + for rank, local_state in state._local_tensors.items() + } + + @property + def state(self): + return LocalTensor(self._per_rank_states) # type: ignore[name-defined] + + @property + def offset(self) -> Union[int, SymInt]: + from torch.distributed.tensor._random import _PhiloxState + + offsets = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + offsets[rank] = rank_philox.offset + + if len(set(offsets.values())) == 1: + return next(iter(offsets.values())) + # pyrefly: ignore [bad-argument-type, bad-argument-count] + return SymInt(LocalIntNode(offsets)) + + @offset.setter + def offset(self, offset: Union[int, SymInt]) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(offset, SymInt) and isinstance(offset.node, LocalIntNode): + for rank, state in self._per_rank_states.items(): + rank_offset = offset.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.offset = rank_offset + else: + offset_int = int(offset) if isinstance(offset, SymInt) else offset + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.offset = offset_int + + @property + def seed(self) -> Union[int, SymInt]: + from torch.distributed.tensor._random import _PhiloxState + + seeds = {} + for rank, state in self._per_rank_states.items(): + rank_philox = _PhiloxState(state) + seeds[rank] = rank_philox.seed + + if len(set(seeds.values())) == 1: + return next(iter(seeds.values())) + return SymInt(LocalIntNode(seeds)) + + @seed.setter + def seed(self, seed: Union[int, SymInt]) -> None: + from torch.distributed.tensor._random import _PhiloxState + + if isinstance(seed, SymInt) and isinstance(seed.node, LocalIntNode): + for rank, state in 
self._per_rank_states.items(): + rank_seed = seed.node._local_ints[rank] + rank_philox = _PhiloxState(state) + rank_philox.seed = rank_seed + else: + seed_int = int(seed) if isinstance(seed, SymInt) else seed + for state in self._per_rank_states.values(): + rank_philox = _PhiloxState(state) + rank_philox.seed = seed_int + + def apply_to_local_tensor_mode(self, device_handle) -> None: + """ + Apply per-rank RNG states to the LocalTensorMode's tracked states. + This updates both the device RNG state and the LocalTensorMode's _per_rank_rng_states. + + Args: + device_handle: The device handle to use for setting RNG state (_LocalDeviceHandle) + """ + if not enabled_local_tensor_mode(): + return + + assert hasattr(self, "_per_rank_offsets") + + for rank in sorted(self._per_rank_states.keys()): + offset_value = self._per_rank_offsets[rank] + if isinstance(offset_value, SymInt): + if isinstance(offset_value.node, LocalIntNode): + offset_value = offset_value.node._local_ints[rank] + else: + offset_value = int(offset_value) + + offset_tensor = torch.tensor( + [offset_value], dtype=torch.uint64, device="cpu" + ).view(torch.uint8) + self._per_rank_states[rank][8:] = offset_tensor + + # pyrefly: ignore [bad-argument-type, bad-argument-count] + device_handle.set_rng_state(LocalTensor(self._per_rank_states)) diff --git a/torch/distributed/debug/__init__.py b/torch/distributed/debug/__init__.py new file mode 100644 index 0000000000000..46267a686e86d --- /dev/null +++ b/torch/distributed/debug/__init__.py @@ -0,0 +1,82 @@ +import logging +import multiprocessing +import socket + +# import for registration side effect +import torch.distributed.debug._handlers # noqa: F401 +from torch._C._distributed_c10d import _WorkerServer +from torch.distributed.debug._store import get_rank, tcpstore_client + + +__all__ = [ + "start_debug_server", + "stop_debug_server", +] + +logger: logging.Logger = logging.getLogger(__name__) + +_WORKER_SERVER: _WorkerServer | None = None +_DEBUG_SERVER_PROC: 
multiprocessing.Process | None = None + + +def start_debug_server(port: int = 25999, worker_port: int = 0) -> None: + """ + Start the debug server stack on all workers. The frontend debug server is + only started on rank0 while the per rank worker servers are started on all + ranks. + + This server provides an HTTP frontend that allows for debugging slow and + deadlocked distributed jobs across all ranks simultaneously. This collects + data such as stack traces, FlightRecorder events, and performance profiles. + + WARNING: This is intended to only be used in trusted network environments. + The debug server is not designed to be secure and should not be exposed to + the public internet. See SECURITY.md for more details. + + WARNING: This is an experimental feature and may change at any time. + + Args: + port (int): The port to start the frontend debug server on. + worker_port (int): The port to start the worker server on. Defaults to 0, which + will cause the worker server to bind to an ephemeral port. + """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _WORKER_SERVER is None, "debug server already started" + assert _DEBUG_SERVER_PROC is None, "debug server already started" + + logger.info("Starting debug server on port %d", port) + + store = tcpstore_client() + + _WORKER_SERVER = _WorkerServer("::", worker_port) + + RANK = get_rank() + store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_WORKER_SERVER.port}") + + from torch.distributed.debug._frontend import main + + if RANK == 0: + _DEBUG_SERVER_PROC = multiprocessing.Process( + target=main, args=(port,), daemon=True + ) + _DEBUG_SERVER_PROC.start() + + +def stop_debug_server() -> None: + """ + Shutdown the debug server and stop the frontend debug server process. 
+ """ + global _WORKER_SERVER, _DEBUG_SERVER_PROC + + assert _DEBUG_SERVER_PROC is not None + assert _WORKER_SERVER is not None + + logger.info("Stopping debug server") + + _DEBUG_SERVER_PROC.terminate() + _WORKER_SERVER.shutdown() + _DEBUG_SERVER_PROC.join() + + _WORKER_SERVER = None + _DEBUG_SERVER_PROC = None diff --git a/torch/distributed/debug/_frontend.py b/torch/distributed/debug/_frontend.py new file mode 100644 index 0000000000000..622c41ca8bd64 --- /dev/null +++ b/torch/distributed/debug/_frontend.py @@ -0,0 +1,353 @@ +import json +import logging +import socket +import threading +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import parse_qs, urlparse + +import requests +from jinja2 import DictLoader, Environment + +from torch.distributed.debug._store import get_world_size, tcpstore_client + + +logger: logging.Logger = logging.getLogger(__name__) + + +def fetch_all( + endpoint: str, args: str = "" +) -> tuple[list[str], Iterator[requests.Response]]: + store = tcpstore_client() + keys = [f"rank{r}" for r in range(get_world_size())] + addrs = store.multi_get(keys) + addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs] + + with ThreadPoolExecutor(max_workers=10) as executor: + resps = executor.map(requests.post, addrs) + + return addrs, resps + + +def format_json(blob: str): + parsed = json.loads(blob) + return json.dumps(parsed, indent=2) + + +templates = { + "base.html": """ + + + {% block title %}{% endblock %} - PyTorch Distributed + + + + + + + +
+ {% block header %}{% endblock %} + {% block content %}{% endblock %} +
+ """, + "index.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}Index{% endblock %}

+{% endblock %} +{% block content %} +Hi +{% endblock %} + """, + "raw_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{title}}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ resp.text }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "json_resp.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}{{ title }}{% endblock %}

+{% endblock %} +{% block content %} + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} +
{{ format_json(resp.text) }}
+ {% endif %} + {% endfor %} +{% endblock %} + """, + "profile.html": """ +{% extends "base.html" %} +{% block header %} +

{% block title %}torch.profiler{% endblock %}

+{% endblock %} + +{% block content %} +
+ + + +
+ + + + {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %} +

Rank {{ i }}: {{ addr }}

+ {% if resp.status_code != 200 %} +

Failed to fetch: status={{ resp.status_code }}

+
{{ resp.text }}
+ {% else %} + + + + {% endif %} + {% endfor %} +{% endblock %} + """, +} + + +class _IPv6HTTPServer(ThreadingHTTPServer): + address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore + request_queue_size: int = 1024 + + +class HTTPRequestHandler(BaseHTTPRequestHandler): + frontend: "FrontendServer" + + def do_GET(self): + self.frontend._handle_request(self) + + def get_path(self) -> str: + return urlparse(self.path).path + + def get_query(self) -> dict[str, list[str]]: + return parse_qs(urlparse(self.path).query) + + def get_query_arg( + self, name: str, default: object = None, type: type = str + ) -> object: + query = self.get_query() + if name not in query: + return default + return type(query[name][0]) + + +class FrontendServer: + def __init__(self, port: int): + # Setup templates + loader = DictLoader(templates) + self._jinja_env = Environment(loader=loader, enable_async=True) + self._jinja_env.globals.update( + zip=zip, + format_json=format_json, + enumerate=enumerate, + ) + + # Create routes + self._routes = { + "/": self._handle_index, + "/stacks": self._handle_stacks, + "/fr_trace": self._handle_fr_trace, + "/fr_trace_nccl": self._handle_fr_trace_nccl, + "/profile": self._handle_profiler, + } + + # Create HTTP server + RequestHandlerClass = type( + "HTTPRequestHandler", + (HTTPRequestHandler,), + {"frontend": self}, + ) + + server_address = ("", port) + self._server = _IPv6HTTPServer(server_address, RequestHandlerClass) + + self._thread = threading.Thread( + target=self._serve, + args=(), + daemon=True, + ) + self._thread.start() + + def _serve(self) -> None: + try: + self._server.serve_forever() + except Exception: + logger.exception("got exception in checkpoint server") + + def join(self) -> None: + self._thread.join() + + def _handle_request(self, req: HTTPRequestHandler) -> None: + path = req.get_path() + if path not in self._routes: + req.send_error(404, f"Handler not found: {path}") + return + + handler = self._routes[path] + try: + resp = 
handler(req) + except Exception as e: + logger.exception( + "Exception in checkpoint server when handling %s", + path, + ) + req.send_error(500, str(e)) + return + + req.send_response(200) + req.send_header("Content-type", "text/html") + req.end_headers() + req.wfile.write(resp) + + def _render_template(self, template: str, **kwargs: object) -> bytes: + return self._jinja_env.get_template(template).render(**kwargs).encode() + + def _handle_index(self, req: HTTPRequestHandler) -> bytes: + return self._render_template("index.html") + + def _handle_stacks(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_traceback") + return self._render_template( + "raw_resp.html", title="Stacks", addrs=addrs, resps=resps + ) + + def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("fr_trace_json") + + return self._render_template( + "json_resp.html", + title="FlightRecorder", + addrs=addrs, + resps=resps, + ) + + def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes: + addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true") + + return self._render_template( + "json_resp.html", + title="FlightRecorder NCCL", + addrs=addrs, + resps=resps, + ) + + def _handle_profiler(self, req: HTTPRequestHandler) -> bytes: + duration = req.get_query_arg("duration", default=1.0, type=float) + + addrs, resps = fetch_all("torch_profile", f"duration={duration}") + + return self._render_template("profile.html", addrs=addrs, resps=resps) + + +def main(port: int) -> None: + server = FrontendServer(port=port) + logger.info("Frontend server started on port %d", server._server.server_port) + server.join() diff --git a/torch/distributed/debug/_handlers.py b/torch/distributed/debug/_handlers.py new file mode 100644 index 0000000000000..ba951b7bda075 --- /dev/null +++ b/torch/distributed/debug/_handlers.py @@ -0,0 +1,22 @@ +import tempfile +import time + +from torch._C._distributed_c10d import _register_handler, _Request, 
_Response +from torch.profiler import _ExperimentalConfig, profile + + +def _torch_profile(req: _Request, resp: _Response) -> None: + experimental_config = _ExperimentalConfig( + profile_all_threads=True, + ) + duration = float(req.get_param("duration")) + with profile(record_shapes=True, experimental_config=experimental_config) as prof: + time.sleep(duration) + + with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f: + prof.export_chrome_trace(f.name) + resp.set_content(open(f.name, "rb").read(), "application/json") + resp.set_status(200) + + +_register_handler("torch_profile", _torch_profile) diff --git a/torch/distributed/debug/_store.py b/torch/distributed/debug/_store.py new file mode 100644 index 0000000000000..70c6cd0f3dde1 --- /dev/null +++ b/torch/distributed/debug/_store.py @@ -0,0 +1,24 @@ +import os + +import torch.distributed as dist + + +def get_rank() -> int: + return int(os.environ["RANK"]) + + +def get_world_size() -> int: + return int(os.environ["WORLD_SIZE"]) + + +def tcpstore_client() -> dist.Store: + MASTER_ADDR = os.environ["MASTER_ADDR"] + MASTER_PORT = int(os.environ["MASTER_PORT"]) + + store = dist.TCPStore( + host_name=MASTER_ADDR, + port=MASTER_PORT, + is_master=False, + ) + store = dist.PrefixStore("debug_server", store) + return store diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 794b755b1f64d..2bd7d24cd7d3f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -547,8 +547,12 @@ def foreach_reduce( op=reduce_scatter_op, ) else: - # For single GPU, just copy the input to output (no actual reduce-scatter needed) - reduce_output.copy_(reduce_scatter_input) + # For single GPU, just copy the input to output (no actual reduce-scatter needed), and + # account for a possible gradient_divide_factor. 
+ if gradient_divide_factor is not None: + reduce_output.copy_(reduce_scatter_input / gradient_divide_factor) + else: + reduce_output.copy_(reduce_scatter_input) reduce_scatter_event = reduce_scatter_stream.record_event() post_reduce_stream = reduce_scatter_stream if all_reduce_group is not None: # HSDP or DDP/replicate @@ -721,20 +725,21 @@ def _get_gradient_divide_factors( if all_reduce_group is not None: data_parallel_size *= all_reduce_group.size() - if factor is None: - factor = float(data_parallel_size) - if not overflow_risk and not force_sum_reduction_for_comms: - if factor == data_parallel_size: + if factor is None: # Warning: NCCL ReduceOp.AVG may produce incorrect results with # world size 1. if data_parallel_size == 1: return None, None, ReduceOp.SUM, ReduceOp.SUM return None, None, ReduceOp.AVG, ReduceOp.AVG + if reduce_scatter_group is not None and factor == reduce_scatter_group.size(): + reduce_scatter_op = ReduceOp.AVG else: reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) - return None, None, reduce_scatter_op, ReduceOp.SUM + return None, None, reduce_scatter_op, ReduceOp.SUM + if factor is None: + factor = float(data_parallel_size) pre_factor: Optional[float] if overflow_risk: # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index cbd817a8bde37..630f327add3d7 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -135,7 +135,9 @@ def __init__(self) -> None: self._random_ops = { aten.native_dropout.default, aten.normal_.default, + aten.rand.default, aten.rand_like.default, + aten.randn.default, aten.randn_like.default, aten.randint_like.default, aten.randint_like.low_dtype, diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index f8325c83d55e4..42bf1ebeebf0e 100644 --- a/torch/distributed/tensor/_random.py +++ 
b/torch/distributed/tensor/_random.py @@ -101,6 +101,9 @@ def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: # DTensor no longer maintains a copy of rng state. manual seed on dtensor is the same thing # as manual seed on torch. + # + # torch.manual_seed will handle LocalTensor mode correctly by + # iterating through all ranks if seed is a LocalIntNode. torch.manual_seed(seed) @@ -239,6 +242,16 @@ def _set_device_state(self, state: torch.Tensor): def _distribute_region( self, spec: DTensorSpec, generator: Optional[torch.Generator] = None ): + from torch.distributed._local_tensor import maybe_enable_local_tracker + + if local_tracker_context := maybe_enable_local_tracker( + self._device.type, self.distribute_region_enabled, spec, generator + ): + with local_tracker_context: + yield + return + + # regular (non-LocalTensor) mode if generator is not None: # This is a little hacky, but for any user-passed generator, we store its state under a unique key, # not because we need to keep a copy of it but because its the easiest way to make it work with the diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 21930d81fe092..3ebcf6180d660 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -15,113 +15,105 @@ ) -def _remove_effect_tokens_from_graph_helper( - ep, num_tokens, input_token_names, output_token_names +def _get_custom_obj_for_node(node, inputs_to_lifted_custom_objs, constants): + """Extract the custom object from a node's arguments.""" + custom_obj_node = node + custom_obj_meta = custom_obj_node.meta["val"] # type: ignore[union-attr] + assert isinstance(custom_obj_meta, CustomObjArgument) + + if custom_obj_meta.fake_val: + return custom_obj_meta.fake_val + elif custom_obj_node.name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] + return constants[inputs_to_lifted_custom_objs[custom_obj_node.name]] # type: ignore[union-attr] + else: + 
raise RuntimeError(f"Unable to find custom obj for node {node}") + + +def _replace_with_effects_node( + node, ep, inputs_to_lifted_custom_objs, output_tokens, input_tokens, module ): - inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs - - output_node = None - with_effect_nodes: list[torch.fx.Node] = [] - - # Output node need to check its args against output_token_names (collected from output_spec) - # Therefore, we only need to find the top-levele output node - output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) - for module in ep.graph_module.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue - - for node in module.graph.nodes: - if not (node.op == "call_function" and node.target is with_effects): - continue - - with_effect_nodes.append(node) - - # Remove tokens from outputs - assert output_node is not None - output_args = output_node.args[0] - assert len(output_args) >= num_tokens - out_token_nodes = output_args[:num_tokens] - output_node.args = (tuple(output_args[num_tokens:]),) - for out_token in out_token_nodes: - assert out_token.name in output_token_names - out_token.users.clear() - ep.graph.erase_node(out_token) - - # Replace with_effects(token, func, args) with just func(args) - for node in reversed(with_effect_nodes): - func = node.args[1] - assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) - - if func is torch.ops.higher_order.call_torchbind: - custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] - assert isinstance(custom_obj_meta, CustomObjArgument) - if custom_obj_meta.fake_val: - custom_obj = custom_obj_meta.fake_val - elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] - custom_obj = ep.constants[ - inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr] - ] - else: - raise RuntimeError(f"Unable to find custom obj for node {node}") - schema = _get_schema(func, (custom_obj,) + 
node.args[3:]) - else: - schema = _get_schema(func, node.args[2:]) - - with ep.graph.inserting_before(node): - new_node = ep.graph.call_function(func, node.args[2:], node.kwargs) - for k, v in node.meta.items(): - new_node.meta[k] = v - if k == "unbacked_bindings": - # Remove the extra layer for effect token - old_bindings = new_node.meta[k] - new_bindings = { - k: path[1:] if path else path for k, path in old_bindings.items() - } - new_node.meta[k] = new_bindings - - node.replace_all_uses_with(new_node) - - # Update user getitem nodes - for user in list(new_node.users.keys()): - assert user.target is operator.getitem - # getitem(with_effects, 0) == token - if user.args[1] == 0: - ep.graph.erase_node(user) - - if len(schema.returns) == 1: - # If the function has 1 return then it will just directly return the - # result -- we don't need a getitem. So we can replace all the - # getitem(with_effects, 1) with just the note itself. - for user in list(new_node.users.keys()): - assert user.args[1] == 1 + """Replace a with_effects node with the underlying function call.""" + # Get the input nodes + token_node, func, *node_args = node.args + if token_node.op == "placeholder": + input_tokens.append(token_node) + + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + # Get the schema for the function + if func is torch.ops.higher_order.call_torchbind: + custom_obj = _get_custom_obj_for_node( + node_args[0], inputs_to_lifted_custom_objs, ep.constants + ) + schema = _get_schema(func, [custom_obj] + node_args[1:]) + else: + schema = _get_schema(func, node_args) + + # Create the replacement node + with module.graph.inserting_before(node): + new_node = module.graph.call_function(func, tuple(node_args), node.kwargs) + + # Update getitem nodes that extract outputs from with_effects + for user in list(node.users.keys()): + assert user.target is operator.getitem + # getitem(with_effects, 0) is the token node + if user.args[1] == 0: + for user_user in 
list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) + + # Fix up the getitem nodes based on return count + if len(schema.returns) == 1: + # Single return: replace getitem(with_effects, 1) with the node itself + for user in list(node.users.keys()): + if user.args[1] == 1: user.replace_all_uses_with(new_node) - - new_node.meta["val"] = node.meta["val"][1] - elif len(schema.returns) > 1: - # If the function has more than 1 return then since we got rid of - # the 1st return value (the token), we need to bump all the other - # getitem calls by 1 down - for user in list(new_node.users.keys()): - assert user.args[1] >= 1 - user.args = (user.args[0], user.args[1] - 1) - - new_node.meta["val"] = node.meta["val"][1:] - else: - assert len(schema.returns) == 0 - assert len(new_node.users) == 0 - new_node.meta["val"] = None - - ep.graph.erase_node(node) - - # Remove tokens from inputs - placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] - assert len(placeholders) >= num_tokens - inp_token_nodes = placeholders[:num_tokens] - for inp_token in inp_token_nodes: - assert inp_token.name in input_token_names - ep.graph.erase_node(inp_token) - - ep.graph.eliminate_dead_code() + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # Multiple returns: shift getitem indices down by 1 + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (new_node, user.args[1] - 1) + new_node.meta["val"] = node.meta["val"][1:] + else: + # No returns + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + # Copy metadata from old node to new node + for k, v in node.meta.items(): + new_node.meta[k] = v + if k == "unbacked_bindings": + # Remove the extra layer for effect token + old_bindings = new_node.meta[k] + new_bindings = { + k: path[1:] if path else path for k, path in old_bindings.items() + } + new_node.meta[k] = new_bindings + + +def 
_replace_invoke_subgraph_node(node, module, output_tokens, input_tokens): + """Replace an invoke_subgraph node to remove the token argument.""" + assert node.args[0].op == "get_attr" + submod = getattr(module, node.args[0].target) + if not submod.meta.get("has_with_effects", False): + return + + # Remove token from inputs + subgraph, identifier, token, *operands = node.args + node.args = (subgraph, identifier, *operands) + if token.op == "placeholder": + input_tokens.append(token) + + # Update getitem nodes to account for removed token output + for user in list(node.users.keys()): + if user.args[1] >= 1: + user.args = (node, user.args[1] - 1) + elif user.args[1] == 0: + for user_user in list(user.users.keys()): + if user_user.op == "output": + output_tokens.append(user) def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: @@ -132,6 +124,65 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: This function does an inplace modification on the given ExportedProgram. """ + print("before", ep) + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + # mark submodules with effects as having effects. 
This will be used in the following pass to remove effects from subgraphs + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + with_effect_nodes = [ + node for node in module.graph.nodes if node.target is with_effects + ] + if len(with_effect_nodes) > 0: + module.meta["has_with_effects"] = True + + # Process each module with the replace hook to ensure graph signature is updated + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for _, module in ep.graph_module.named_modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + input_tokens = [] + output_tokens = [] + + # Process with_effects and invoke_subgraph nodes + for node in module.graph.nodes: + if node.target is with_effects: + _replace_with_effects_node( + node, + ep, + inputs_to_lifted_custom_objs, + output_tokens, + input_tokens, + module, + ) + elif node.target is torch.ops.higher_order.invoke_subgraph: + _replace_invoke_subgraph_node( + node, module, output_tokens, input_tokens + ) + + # Remove tokens from the output node + if len(output_tokens) > 0: + output_node = next(reversed(module.graph.find_nodes(op="output"))) + output_args = output_node.args[0] + assert len(output_args) >= len(output_tokens), ( + f"{output_args} output arguments found\n" + f"{output_tokens} output tokens found\n" + f"{module.graph}" + ) + output_node.args = (tuple(output_args[len(output_tokens) :]),) + + module.graph.eliminate_dead_code() + + # Remove tokens from the input placeholders + for node in module.graph.nodes: + if node.op == "placeholder" and node in input_tokens: + module.graph.erase_node(node) + + module.recompile() + num_tokens: int = 0 input_token_names: list[str] = [] new_input_specs: list[InputSpec] = [] @@ -159,9 +210,5 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: assert num_tokens == num_out_tokens - with 
ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): - _remove_effect_tokens_from_graph_helper( - ep, num_tokens, input_token_names, output_token_names - ) - + print("after", ep) return ep diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 52d06a294fac1..6239c5899c233 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -748,11 +748,23 @@ def _unlift_exported_program_lifted_states( ) -> torch.fx.GraphModule: check_guards = check_guards and _ok_to_generate_guards_fn() + source_node_dict = { + node.name: node for node in ep.graph.nodes if node.op != "placeholder" + } + # placeholder node name might change after deepcopy + placeholder_source_node_dict = { + node.target: node for node in ep.graph.nodes if node.op == "placeholder" + } + + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + new_gm.meta.update(ep.graph_module.meta) + ep = copy.copy(ep) + ep._graph_module = new_gm + # TODO T206340015 if ep.verifiers[0].dialect != "TRAINING": ep = _remove_effect_tokens(ep) - new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) forward_arg_names = ( sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None @@ -786,19 +798,13 @@ def _unlift_exported_program_lifted_states( for out_spec in ep.graph_signature.output_specs ] - source_node_dict = { - node.name: node for node in ep.graph.nodes if node.op != "placeholder" - } - # placeholder node name might change after deepcopy - placeholder_source_node_dict = { - node.target: node for node in ep.graph.nodes if node.op == "placeholder" - } for node in new_gm.graph.nodes: source_node = None if node.op == "placeholder": source_node = placeholder_source_node_dict.get(node.target) else: - source_node = source_node_dict.get(node.name) + if node.name in source_node_dict: + source_node = source_node_dict.get(node.name) node.meta["from_node"] 
= [ NodeSource( source_node, diff --git a/torch/fx/node.py b/torch/fx/node.py index 294e15c550235..cb37b6ece75dd 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -753,7 +753,9 @@ def is_impure(self, impure_random: bool = True) -> bool: # between eager and compiled execution, regardless of generator usage return True - return self.target in _side_effectful_functions + from torch._higher_order_ops.effects import has_effects + + return self.target in _side_effectful_functions or has_effects(self.target) # Check if an impure module. if self.op == "call_module": diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index f3400e438a2d3..056a5fcc21fdd 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -9,7 +9,7 @@ from enum import Enum from functools import partial from typing import Any, Optional -from typing_extensions import Self +from typing_extensions import deprecated, Self from warnings import warn import torch @@ -408,6 +408,11 @@ def _memory_profile(self) -> MemoryProfile: ) return MemoryProfile(self.profiler.kineto_results) + @deprecated( + "`export_memory_timeline` is deprecated and will be removed in a future version. " + "Please use `torch.cuda.memory._record_memory_history` and `torch.cuda.memory._export_memory_snapshot` instead.", + category=FutureWarning, + ) def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: """Export memory event information from the profiler collected tree for a given device, and export a timeline plot. There are 3 @@ -429,6 +434,11 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non ``torch.profiler._memory_profiler.Category``. Output: Memory timeline written as gzipped JSON, JSON, or HTML. + + .. deprecated:: + ``export_memory_timeline`` is deprecated and will be removed in a future version. + Please use ``torch.cuda.memory._record_memory_history`` and + ``torch.cuda.memory._export_memory_snapshot`` instead. 
""" # Default to device 0, if unset. Fallback on cpu. if device is None: diff --git a/torch/random.py b/torch/random.py index cf23e52db320e..f86d7349019dc 100644 --- a/torch/random.py +++ b/torch/random.py @@ -39,6 +39,10 @@ def manual_seed(seed) -> torch._C.Generator: is raised. Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`. """ + return _manual_seed_impl(seed, update_local_tensor_states=True) + + +def _manual_seed_impl(seed, update_local_tensor_states) -> torch._C.Generator: seed = int(seed) import torch.cuda diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 6ce7d4b2ca507..1f6c4aece1e80 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -386,7 +386,7 @@ def device_type(self) -> str: @property def backend(self) -> str: - backend = dist.get_default_backend_for_device(DEVICE_TYPE) + backend = dist.get_default_backend_for_device(self.device_type) return backend def init_manual_seed_for_rank(self) -> None: @@ -724,6 +724,9 @@ def setUp(self) -> None: torch.autograd._enable_record_function(False) def tearDown(self) -> None: + from torch.distributed.tensor import _random as random + + random._rng_tracker = None super().tearDown() torch.autograd._enable_record_function(True) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 16877719718af..3d2e4d110b6b2 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -1113,11 +1113,8 @@ def __post_init__(self) -> None: num_leaves = 1 num_children = 0 else: - num_nodes = 1 - num_leaves = 0 - for child in self._children: - num_nodes += child.num_nodes - num_leaves += child.num_leaves + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) num_children = len(self._children) 
object.__setattr__(self, "num_nodes", num_nodes) object.__setattr__(self, "num_leaves", num_leaves) diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 8abb547d500f8..df4bf34db2114 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -249,6 +249,8 @@ def format_sequence(obj): if len(filename) > FRAME_FILENAME_LIMIT: filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):] return f"frame\n{filename}:{obj.f_lineno}" + elif is_cuda_tensor(obj): + return f"object\n{type(obj).__module__}.{type(obj).__name__} ({obj.shape})" else: return f"object\n{type(obj).__module__}.{type(obj).__name__}"