From 56dfd4c74bfa2fb03fbfeb28f7774870248910c4 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Mon, 3 Nov 2025 22:12:22 +0000 Subject: [PATCH 001/651] Add CUDA MXFP4 scaled mm support via. FBGEMM (#166526) Summary: * Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA * Add testing Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526 Approved by: https://github.com/drisspg, https://github.com/ngimel --- aten/src/ATen/CMakeLists.txt | 2 +- aten/src/ATen/native/cuda/ScaledBlas.cpp | 81 +++++++++++++++++++++--- test/test_scaled_matmul_cuda.py | 40 ++++++------ 3 files changed, 91 insertions(+), 32 deletions(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 8b283c417b74b..ae762e1def3ec 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI) if(USE_CUDA) # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build. # If you want to integrate a kernel from FBGEMM into torch, you have to add it here. - set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*") + set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*") file(GLOB_RECURSE fbgemm_genai_native_cuda_cu "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu" "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu") diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 0d2963874abbd..9065d79929360 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -59,6 +59,24 @@ // forward declare class cublasCommonArgs; +#ifndef _WIN32 +namespace fbgemm_gpu { + +// NOTE(slayton58): FBGemm_GPU kernels come from within the FBGemm repo. +// To update supported ops means a submodule bump, which is.. painful. Instead, we +// can simply forward-declare the methods we want to use.. Works at least as a short-term +// thing, but should still be fixed somewhere/somehow. 
+at::Tensor f4f4bf16( + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + std::optional, + bool use_mx); + +} // namespace fbgemm_gpu +#endif + using at::blas::ScalingType; using at::blas::SwizzleType; @@ -1087,26 +1105,47 @@ _scaled_mxfp4_mxfp4( const std::optional& bias, const c10::ScalarType out_dtype, Tensor& out) { -#ifndef USE_ROCM - TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only"); -#endif +#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)) + TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only"); +#else // Restrictions: // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); - auto scale_a_elems = ceil_div(2 * mat_a.size(0), 32) * mat_a.size(1); - auto scale_b_elems = ceil_div(2 * mat_b.size(1), 32) * mat_b.size(0); + // Packed FP4 format means actual-K = 2 * reported-K -- adjust + auto K_multiplier = 2; +#ifdef USE_ROCM + // AMD + auto scale_a_elems = ceil_div(K_multiplier * mat_a.size(0), 32) * mat_a.size(1); + auto scale_b_elems = ceil_div(K_multiplier * mat_b.size(1), 32) * mat_b.size(0); +#else + // NVIDIA + auto scale_a_elems = round_up(mat_a.size(0), 128) * round_up(ceil_div(K_multiplier * mat_a.size(1), 32), 4); + auto scale_b_elems = round_up(mat_b.size(1), 128) * round_up(ceil_div(K_multiplier * mat_b.size(0), 32), 4); +#endif TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); +#ifdef USE_ROCM + // AMD + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)"); + TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)"); +#else + // NVIDIA + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format"); + TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format"); +#endif + TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), "For Blockwise scaling both scales should be contiguous"); TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype); +#ifdef USE_ROCM + // AMD auto scaling_choice_a = ScalingType::BlockWise1x32; auto scaling_choice_b = ScalingType::BlockWise1x32; @@ -1121,11 +1160,30 @@ _scaled_mxfp4_mxfp4( TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half, "Block-wise scaling only supports BFloat16 or Half output types"); -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); #endif return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); +#else + // NVIDIA + // NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor, + // but we have one we need to use. Two clear options are to copy into + // our output (slow), or use a move-assignment-operator (faster). 
+ // However, the compiler can complain about the explicit move preventing + // copy elision because the return from f4f4bf16 is a temporary object. + // So we don't explicitly move, and trust the compiler here... + // In the longer term this should be fixed on the FBGemm side. + out = fbgemm_gpu::f4f4bf16( + mat_a, + mat_b.transpose(-2, -1), + scale_a, + scale_b, + std::nullopt, /* global_scale */ + true /* use_mx */ + ); + + return out; +#endif +#endif } Tensor& @@ -1250,17 +1308,20 @@ _scaled_mm_cuda_v2_out( mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")"); } + // Handle fp4 packed-K dimension + int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1; + TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1], " but got ", bias->numel()); TORCH_CHECK_VALUE( - mat_a.sizes()[1] % 16 == 0, + K_multiplier * mat_a.sizes()[1] % 16 == 0, "Expected trailing dimension of mat1 to be divisible by 16 ", "but got mat1 shape: (", mat_a.sizes()[0], "x", - mat_a.sizes()[1], + K_multiplier * mat_a.sizes()[1], ")."); - TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x", + TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x", mat_b.sizes()[1], ") must be divisible by 16"); // TODO(slayton): Existing checks, not sure if they should really be here. diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 4d88ccd9cc7dd..9738ac4ac6fbf 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -209,42 +209,36 @@ def infer_scale_swizzle(mat, scale): ] == math.ceil(mat.shape[1] // 128): return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE + # if we're checking for nvfp4, need to adjust for packed-K + K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1 # NVFP4 if ( (scale.numel() - == round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4) + == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4) or scale.numel() - == round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)) + == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4)) and mat.dtype == torch.float4_e2m1fn_x2 and scale.dtype == torch.float8_e4m3fn ): return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4 - # MXFP4 w/o swizzle - if ( - (scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1] - or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]) - and mat.dtype == torch.float4_e2m1fn_x2 - and scale.dtype == torch.float8_e8m0fnu - ): - return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE - + # MX formats if not torch.version.hip: - # MXFP8 w/ swizzle + # MX w/swizzle (NVIDIA) if ( (scale.numel() - == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4) + == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4) or scale.numel() - == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)) + == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4)) and scale.dtype == torch.float8_e8m0fnu ): return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4 else: - # MXFP8 w/o swizzle + # MX w/o swizzle (AMD) if ( - (scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1] - or scale.numel() == 
math.ceil(mat.shape[1] // 32) * mat.shape[0]) + (scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1] + or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0]) and scale.dtype == torch.float8_e8m0fnu ): return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE @@ -1868,7 +1862,7 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: (127, 96, 1024), (1025, 128, 96) ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}") - @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"]) + @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None: if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum: raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping") @@ -1882,8 +1876,12 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0): raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping") - fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn - BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32) + fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn + BLOCK_SIZE = 16 if recipe == "nvfp4" else 32 + + if K % BLOCK_SIZE != 0: + raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping") + require_exact_match = True approx_match_sqnr_target = 22.0 @@ -2061,7 +2059,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, B = B.clamp(min=min_val, max=max_val) B = _bfloat16_to_float4_e2m1fn_x2(B) - approx_match_sqnr_target = 15 if torch.version.hip else 15.8 + approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8 C_ref = A_ref @ B_ref.t() From afd50bdd290d1ff8976d8477efb9ad9256705d88 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 4 Nov 2025 16:43:06 +0000 Subject: [PATCH 002/651] [CI] Use smaller amx + avx2 runners for inductor test? (#164989) Results from CI: No failures but generally takes longer, maybe ~20% increase in time? 
But the smaller runner is ~25% of the cost of the current runner, so in terms of cost this is a decrease If the 20% is too much, we can try the 4x larger runners, which are about half the cost of the current runner, so it would probably still result in cost savings with hopefully less impact to time Pull Request resolved: https://github.com/pytorch/pytorch/pull/164989 Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn --- .github/workflows/inductor-unittest.yml | 8 ++++---- .github/workflows/inductor.yml | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 6ab276a57fc4d..3ce917567aec2 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -115,10 +115,10 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, + { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, ]} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 2616141c0dc2a..8a913c3b36a11 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -84,13 +84,13 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, + { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, 
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} build-additional-packages: "vision audio torchao" From 8d4b8ab43033667f66a1180974d8faf9b1b8b93d Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 4 Nov 2025 16:45:22 +0000 Subject: [PATCH 003/651] [ez] Print some more test timing info in the logs (#166447) You can just subtract timestamps, but this makes it easier Pull Request resolved: https://github.com/pytorch/pytorch/pull/166447 Approved by: https://github.com/Skylion007 --- test/run_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/run_test.py b/test/run_test.py index 4b7030d461529..448fbc28751f3 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1826,9 +1826,14 @@ def run_test_module( test_name = test.name # Printing the date here can help diagnose which tests are slow - print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]") + start = time.perf_counter() + print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]") handler = CUSTOM_HANDLERS.get(test_name, run_test) return_code = handler(test, test_directory, options) + end = time.perf_counter() + print_to_stderr( + f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min" + ) assert isinstance(return_code, int) and not isinstance(return_code, bool), ( f"While running {str(test)} got non integer return code {return_code}" ) From 68eb55c4b23babd005267dfd322dc4b070041f58 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Mon, 3 Nov 2025 17:40:23 -0800 Subject: [PATCH 004/651] Add model code stack trace to cuda.memory._snapshot (#166676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We store a mapping between generated fx graph code and original model code stack trace in `fx.traceback._FX_METADATA_REGISTRY`. And we do a post-processing on the memory snapshot to append the original model stack trace information. To achieve this, the biggest change we had to do in `aot_eager` mode is to give each generated fx graph a unique stack trace, i.e. it cannot just be ``. We set co_filename to **pretend** that the code is from `co_filename` file. Now instead of `` in stack trace, we get something like `fx_generated_3a4b5c6d7e8f9a0.py`. `augment_with_fx_traces` arg is added to `torch.cuda.memory._snapshot` and `_dump_snapshot`. When the arg is set to True, a post-processing will run to populate the original model stack trace to the snapshot frames. The new behavior of GraphModule can be controlled by `TORCH_ENRICH_RPOFILER_STACK_TRACE` or `_dynamo.config.enrich_profiler_metadata=True`. 
Alternative: Instead of setting co_filename, we can also do it like below: Note that if we do it this way, we will need to dump the file to make the graph module torch-scriptable. TorchScript requires source access in order to carry out compilation, so we need to make sure original .py files are available. ``` key = filename globals_copy = globals.copy() globals_copy["__file__"] = key globals_copy["__name__"] = key linecache.lazycache(key, globals_copy) exec(compile(src, key, "exec"), globals) ```` Other changes: - Update `MemoryViz.js` to display fx node information and original model code if exist ``` python test/test_fx.py -k test_lineno_map python test/test_fx.py -k test_custom_traceback_raised python test/test_public_bindings.py python test/test_cuda.py -k test_fx_memory python test/test_fx.py -k test_informative_co_filename python test/test_fx.py -k test_autowrap_functions python test/dynamo/test_utils.py -k test_inductor_provenance ``` ```python # Profile with memory snapshot torch.cuda.memory._record_memory_history() with torch._dynamo.config.patch("enrich_profiler_stack_trace", True): compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) result = compiled(torch.randn(10, 10, device="cuda:0")) torch.cuda.memory._dump_snapshot("memory_snapshot.pickle", augment_with_fx_traces=True) torch.cuda.memory._record_memory_history(enabled=None) ``` Screenshot 2025-10-30 at 10 40 44 AM Pull Request resolved: https://github.com/pytorch/pytorch/pull/166676 Approved by: https://github.com/albanD, https://github.com/ezyang --- test/test_cuda.py | 134 +++++++++++++++++++++++ test/test_fx.py | 2 + torch/_dynamo/config.py | 6 ++ torch/cuda/memory.py | 202 +++++++++++++++++++++++++++++++++-- torch/fx/graph.py | 13 ++- torch/fx/graph_module.py | 79 +++++++++++++- torch/fx/traceback.py | 22 ++++ torch/utils/viz/MemoryViz.js | 24 ++++- 8 files changed, 472 insertions(+), 10 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index a7e373da63824..00c3b00d6049c 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -7413,6 +7413,140 @@ def test_graph_external_wait_and_record(self): ) +class TestFXMemoryProfiler(TestCase): + """Tests for memory profiler augmentation with original stack traces.""" + + def collect_frames( + self, augmented_snapshot, collect_device_traces=True, collect_segments=True + ): + """Collects all frames that has node metadata from a memory snapshot.""" + # Collect all frames with FX metadata + fx_frames = [] + + # Check device traces for FX debug fields + if collect_device_traces and "device_traces" in augmented_snapshot: + for trace_list in augmented_snapshot["device_traces"]: + for trace_entry in trace_list: + if isinstance(trace_entry, dict) and "frames" in trace_entry: + for frame in trace_entry["frames"]: + if isinstance(frame, dict): + # Check for FX debug fields + if "fx_node_op" in frame or "fx_node_name" in frame: + fx_frames.append(frame) + + # Check segments/blocks for FX debug fields + if collect_segments and "segments" in augmented_snapshot: + for segment in augmented_snapshot["segments"]: + if "blocks" in segment: + for block in segment["blocks"]: + if "frames" in block: + for frame in block["frames"]: + if isinstance(frame, dict): + if "fx_node_op" in frame or "fx_node_name" in frame: + fx_frames.append(frame) + return fx_frames + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_fx_memory_profiler_augmentation(self): + """Test that memory 
snapshots are augmented with FX debug information.""" + + # Create a simple model + class MLPModule(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + a = self.net1(x) + b = self.relu(a) + c = self.net2(b) + return c + + device = "cuda" + mod = MLPModule(device) + with tempfile.TemporaryDirectory() as tmpdir: + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot( + augment_with_fx_traces=True + ) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + torch.cuda.empty_cache() + + fx_frames = self.collect_frames(augmented_snapshot) + if TEST_WITH_ROCM: + self.assertGreater(len(fx_frames), 0) + else: + self.assertEqual(len(fx_frames), 12) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("a = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("c = self.net2(b)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("b = self.relu(a)", frame["fx_original_trace"]) + + # Test that when we have two graphs with the same src_code, they're not hashed + # to the same metadata + class MLPModule2(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + d = self.net1(x) + e = self.relu(d) + f = self.net2(e) + return f + + mod = MLPModule2(device) + with tempfile.TemporaryDirectory() as tmpdir: + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot( + augment_with_fx_traces=True + ) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + + # avoid collecting segments from previous run for unit test purpose + fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False) + self.assertGreater(len(fx_frames), 0) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("d = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("f = self.net2(e)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("e = self.relu(d)", frame["fx_original_trace"]) + + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_parametrized_tests(TestCompileKernel) diff --git a/test/test_fx.py 
b/test/test_fx.py index 880cc91edc067..d6f33d426aee7 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -771,6 +771,7 @@ def forward(self, a, b): gm = GraphModule(tracer.root, graph) expected = {1: 2, 2: 3, 3: 4, 4: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) + self.assertEqual(gm._prologue_start, 4) # test custom codegen def transform_code(code): @@ -780,6 +781,7 @@ def transform_code(code): gm.recompile() expected = {2: 2, 3: 3, 4: 4, 5: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) + self.assertEqual(gm._prologue_start, 4) def test_graph_unique_names_manual(self): graph: torch.fx.Graph = torch.fx.Graph() diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5858a4584b3dd..0c95408401c79 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -739,6 +739,12 @@ def default_debug_dir_root() -> str: # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 2dfd5f9479499..6834ffb5706a0 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -4,12 +4,14 @@ import collections import contextlib import ctypes +import os import pickle +import re import sys import warnings from inspect import signature -from typing import Any, Literal, Optional, TYPE_CHECKING -from typing_extensions import deprecated +from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict +from typing_extensions import deprecated, NotRequired import torch from torch import _C @@ -29,6 +31,60 @@ from torch.types import Device +# Type definitions for memory profiler +class _Frame(TypedDict): + """Frame information from memory profiler snapshots.""" + + filename: str + line: int + name: str + # Fields added by FX augmentation (optional) + fx_node_op: NotRequired[str] + fx_node_name: NotRequired[str] + fx_node_target: NotRequired[str] + fx_original_trace: NotRequired[str] + + +class _Block(TypedDict): + """Memory block information.""" + + size: int + requested_size: int + address: int + state: str + frames: list[_Frame] + + +class _Segment(TypedDict): + """Memory segment information.""" + + address: int + total_size: int + stream: int + segment_type: str + allocated_size: int + active_size: int + blocks: list[_Block] + + +class _TraceEntry(TypedDict): + """Memory trace entry information.""" + + action: str + addr: NotRequired[int] + frames: list[_Frame] + size: int + stream: int + device_free: NotRequired[int] + + +class _Snapshot(TypedDict): + """Memory snapshot structure.""" + + segments: list[_Segment] + device_traces: NotRequired[list[list[_TraceEntry]]] + + __all__ = [ "caching_allocator_alloc", "caching_allocator_delete", @@ -964,7 +1020,120 @@ def _record_memory_history_impl( _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined] -def _snapshot(device: "Device" = None): +def _augment_frames(frames: list[_Frame]) -> int: + """ + Augment a list of frames with FX debug information. + + Args: + frames: List of frame dictionaries to augment + + Returns: + The count of frames that were augmented. 
+ """ + from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX + + # Regex pattern to match FX generated files + _FX_GENERATED_PATTERN = re.compile( + rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$" + ) + + count = 0 + if not frames: + return count + + for frame in frames: + if "filename" in frame and "line" in frame: + filename = frame["filename"] + lineno = frame["line"] + + # Check if this looks like an FX generated file + if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)): + continue + + # Look up metadata from the global registry + from torch.fx.traceback import _FX_METADATA_REGISTRY + + metadata = _FX_METADATA_REGISTRY.get(filename) + if metadata is None: + continue + + lineno_map = metadata.get("lineno_map", {}) + node_metadata = metadata.get("node_metadata", {}) + prologue_start = metadata.get("prologue_start", 0) + + # Get the node index for this line + node_idx = lineno_map.get(lineno - prologue_start) + + if node_idx is not None and node_idx in node_metadata: + node_info = node_metadata[node_idx] + original_trace = node_info.get("stack_trace") + node_op = node_info.get("op") + node_name = node_info.get("name") + node_target = node_info.get("target") + + # Always add node metadata + frame["fx_node_op"] = node_op + frame["fx_node_name"] = node_name + frame["fx_node_target"] = str(node_target) + + # Add original trace if available + if original_trace: + frame["fx_original_trace"] = original_trace + + count += 1 + + return count + + +def _augment_memory_snapshot_stack_traces( + snapshot: str | _Snapshot, +) -> _Snapshot: + """ + Augment a memory snapshot with original source stack traces from FX metadata. + + IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY) + that is populated during graph module compilation. It must be called in the same + Python process where the FX graphs were compiled. It cannot be used to augment + snapshots loaded from disk in a different process. + + Args: + snapshot: Either a memory snapshot dict or path to a snapshot pickle file + + Returns: + The augmented snapshot dictionary with fx_node_op, fx_node_name, + fx_original_trace, and fx_node_info fields added to frames + """ + + snapshot_dict: _Snapshot + if isinstance(snapshot, str): + # Load the memory snapshot + with open(snapshot, "rb") as f: + snapshot_dict = cast(_Snapshot, pickle.load(f)) + else: + snapshot_dict = snapshot + + # Process stack traces in the snapshot + augmented_count = 0 + + # Process blocks in segments (for regular allocations) + if "segments" in snapshot_dict: + for segment in snapshot_dict["segments"]: + if "blocks" in segment: + for block in segment["blocks"]: + if "frames" in block: + augmented_count += _augment_frames(block["frames"]) + + # Process device traces (for memory history) + if "device_traces" in snapshot_dict: + for trace_list in snapshot_dict["device_traces"]: + for trace_entry in trace_list: + if isinstance(trace_entry, dict) and "frames" in trace_entry: + augmented_count += _augment_frames(trace_entry["frames"]) + + return snapshot_dict + + +def _snapshot(device: "Device" = None, augment_with_fx_traces=False): """Save a snapshot of CUDA memory state at the time it was called. The state is represented as a dictionary with the following structure. 
@@ -1012,6 +1181,11 @@ class Frame(TypedDict): filename: str line: int name: str + # Optional FX debug fields (present when augment_with_fx_traces=True + # and the frame corresponds to FX-generated code) + fx_node_op: str # FX node operation type (e.g., 'call_function', 'output') + fx_node_name: str # FX node name (e.g., 'linear', 'relu_1') + fx_original_trace: str # Original model source code stack trace class TraceEntry(TypedDict): @@ -1041,13 +1215,23 @@ class TraceEntry(TypedDict): device_free: int # only present for OOM, the amount of # memory cuda still reports to be free + Args: + device: Device to capture snapshot for. If None, captures for current device. + augment_with_fx_traces: If True, augment stack trace frames with FX debug information + that maps generated FX code back to original model source code. + This adds fx_node_op, fx_node_name, fx_original_trace, and + fx_node_info fields to Frame objects. Default: False. + Returns: The Snapshot dictionary object """ - return _C._cuda_memorySnapshot(None) + s = _C._cuda_memorySnapshot(None) + if augment_with_fx_traces: + s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type] + return s -def _dump_snapshot(filename="dump_snapshot.pickle"): +def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False): """ Save a pickled version of the `torch.memory._snapshot()` dictionary to a file. @@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"): Args: filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle". + augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information + before dumping. This maps generated FX code stack traces + back to original model source code. Defaults to False. + verbose (bool, optional): If True and augment_with_fx_traces is True, print verbose debug output + during augmentation. Defaults to False. """ - s = _snapshot() + s = _snapshot(augment_with_fx_traces=augment_with_fx_traces) + with open(filename, "wb") as f: pickle.dump(s, f) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fc6f4c5b27021..697b2f4084ca5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -226,8 +226,10 @@ class PythonCode: # Values in global scope during execution of `src_def`. globals: dict[str, Any] # Optional mapping from the forward function's line number to - # node index. + # node index. Line number starts at the prologue (i.e. forward()). 
_lineno_map: Optional[dict[int, Optional[int]]] + # The line number of prologue in fn_code + _prologue_start: int = 0 def _format_target(base: str, target: str) -> str: @@ -854,7 +856,14 @@ def _tensor_annotation(t: torch.Tensor) -> str: {prologue} {code}""" - return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + # The +4 accounts for the empty lines before prologue in fn_code + prologue_start = wrap_stmts.count("\n") + 4 + return PythonCode( + fn_code, + globals_, + _lineno_map=lineno_map, + _prologue_start=prologue_start, + ) # Ideally, we'd like to refactor all of the pytree logic into this codegen diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 159926bc8ba49..297f76732584f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs +import base64 import contextlib import copy +import hashlib import itertools import linecache import os @@ -36,6 +38,7 @@ ] _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" +FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_" # Normal exec loses the source code, however we can work with @@ -61,7 +64,13 @@ def cache(self, src: str, globals: dict[str, Any], co_fields=None): key = self._get_key() if co_fields: - key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + if "co_filename" in co_fields: + # If only co_filename is provided, use it directly as the key + if "co_firstlineno" not in co_fields or "co_name" not in co_fields: + key = co_fields["co_filename"] + else: + # Full co_fields with all three components + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" self.eval_cache[key] = src # Don't mutate globals so that this loader is only used @@ -353,6 +362,36 @@ def _print_readable( return output +def _metadata_hash(code: str, node_metadata: dict) -> str: + """ + Create a content-addressed hash from code and metadata. + + Args: + code: The source code string + lineno_map: Mapping from line numbers to node indices + node_metadata: Metadata for each node + + Returns: + A 51-character base32-encoded hash + """ + import json + + # Create a deterministic string representation of all components + # We use JSON to ensure consistent serialization + hash_data = { + "code": code, + "node_metadata": node_metadata, + } + hashing_str = json.dumps(hash_data).encode("utf-8") + + # [:51] to strip off the "Q====" suffix common to every hash value. 
+ return ( + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51] + .decode("utf-8") + .lower() + ) + + class _WrappedCall: def __init__(self, cls, cls_call): self.cls = cls @@ -825,9 +864,47 @@ def recompile(self) -> PythonCode: python_code = self._graph.python_code(root_module="self") self._code = python_code.src self._lineno_map = python_code._lineno_map + self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + from torch._dynamo import config as dynamo_config + + if dynamo_config.enrich_profiler_metadata: + # Generate metadata and register for profiler augmentation + node_metadata: dict[int, dict[str, Any]] = {} + for i, node in enumerate(self._graph.nodes): + node_metadata[i] = { + "name": node.name, + "op": node.op, + "target": str(node.target), + "stack_trace": node.meta.get("stack_trace", None), + } + + # Generate a content-addressed filename based on hash of code and metadata + # This ensures the same code+metadata always generates the same filename + hash_value = _metadata_hash(self._code, node_metadata) + file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + + filename = f"{file_stem}.py" + + # Only include co_filename to use it directly as the cache key + co_fields = { + "co_filename": filename, + } + + # Store metadata in global in-memory registry + metadata = { + "lineno_map": python_code._lineno_map, + "prologue_start": python_code._prologue_start, + "node_metadata": node_metadata, + } + + # Register metadata in the global registry + from torch.fx.traceback import _register_fx_metadata + + _register_fx_metadata(filename, metadata) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a143119cd78b0..25fb81a5aa016 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -38,6 +38,28 @@ current_replay_node: Optional[Node] = None should_preserve_node_meta = False +# ============================================================================= +# FX Metadata Registry for Memory Profiler +# ============================================================================= +# Global in-memory registry for FX metadata +# Maps module_name -> metadata dict containing lineno_map and node_metadata +_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {} + + +def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None: + """ + Register FX metadata in the global in-memory registry. + + This is called automatically during graph module compilation to store metadata + for later use by memory profiler augmentation. 
+ + Args: + module_name: The module identifier (content-addressed filename) + metadata: Metadata dict containing lineno_map, node_metadata, and source_code + """ + # TODO: add logging to tlparse + _FX_METADATA_REGISTRY[module_name] = metadata + @compatibility(is_backward_compatible=False) class NodeSourceAction(Enum): diff --git a/torch/utils/viz/MemoryViz.js b/torch/utils/viz/MemoryViz.js index 09f8c444f600c..dfeae36cebab7 100644 --- a/torch/utils/viz/MemoryViz.js +++ b/torch/utils/viz/MemoryViz.js @@ -806,7 +806,29 @@ function format_frames(frames) { } const frame_strings = frames .filter(frameFilter) - .map(f => `${f.filename}:${f.line}:${f.name}`); + .map(f => { + let frame_str = `${f.filename}:${f.line}:${f.name}`; + + // Add FX debug information if available + if (f.fx_node_op || f.fx_node_name || f.fx_node_target) { + const fx_parts = []; + if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`); + if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`); + if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`); + frame_str += `\n >> FX: ${fx_parts.join(', ')}`; + } + + if (f.fx_original_trace) { + frame_str += `\n >> Original Model Code:`; + const original_lines = f.fx_original_trace.trim().split('\n'); + // Show all lines of the original trace + for (const line of original_lines) { + frame_str += `\n ${line}`; + } + } + + return frame_str; + }); return elideRepeats(frame_strings).join('\n'); } From d02f68f4840af4ff2431a3015ff8d64aea43e720 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 3 Nov 2025 10:27:22 -0800 Subject: [PATCH 005/651] [BE] Use `[[maybe_unused]]` (#166865) Instead of `(void) foo; // Unused parameter` trick, as this is a C++17 standard feature Will replace further repetitions of the same pattern soon after Pull Request resolved: https://github.com/pytorch/pytorch/pull/166865 Approved by: https://github.com/mikaylagawarecki, https://github.com/Skylion007, https://github.com/janeyx99 --- torch/csrc/stable/stableivalue_conversions.h | 48 +++++++------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index f35ed50d99be4..8004e91b77f8e 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -31,10 +31,8 @@ template struct FromImpl { static StableIValue call( T val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); @@ -75,10 +73,8 @@ template <> struct FromImpl { static StableIValue call( ScalarType val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { switch (val) { case ScalarType::Byte: return from(aoti_torch_dtype_uint8()); @@ -133,10 +129,8 @@ template <> struct FromImpl { static StableIValue call( std::nullopt_t val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { return 
from(nullptr); } }; @@ -190,10 +184,8 @@ template <> struct FromImpl { static StableIValue call( const torch::stable::Tensor& val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { AtenTensorHandle new_ath; TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath)); return from(new_ath); @@ -209,10 +201,8 @@ template struct ToImpl { static T call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { static_assert(std::is_trivially_copyable_v); // T may not have a default constructor. (For example, it might be // c10::Device.) However, std::memcpy implicitly creates a T at the @@ -249,10 +239,8 @@ template <> struct ToImpl { static ScalarType call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { int32_t shim_scalartype = to(val); if (shim_scalartype == aoti_torch_dtype_uint8()) { return ScalarType::Byte; @@ -309,10 +297,8 @@ template <> struct ToImpl { static std::nullopt_t call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { // val should be equivalent to from(nullptr) return std::nullopt; } @@ -350,10 +336,8 @@ template <> struct ToImpl { static torch::stable::Tensor call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { return torch::stable::Tensor(to(val)); } }; From eefa16342c9f322b56c7c0cd6d309c3ed8f0b882 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Tue, 4 Nov 2025 12:59:31 +0000 Subject: [PATCH 006/651] [Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165) Prefer unfused addmm when there is at least a single elemwise/reduction consumer.. 
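To make the intent concrete, here is a minimal sketch (not part of this patch; shapes and the function name are illustrative, and a CUDA device is assumed since the check only applies to GPU inputs) of the kind of graph the new heuristic targets:

```python
# Minimal illustrative sketch, not code from this PR.
import torch

def f(x, w, b):
    # addmm(bias, x, w) whose output feeds a pointwise op (relu): with this
    # change Inductor prefers the unfused form mm(x, w) + bias so the bias add
    # can fuse into the pointwise/reduction epilogue.
    return torch.relu(torch.addmm(b, x, w))

compiled = torch.compile(f)
y = compiled(
    torch.randn(8, 16, device="cuda"),
    torch.randn(16, 32, device="cuda"),
    torch.randn(32, device="cuda"),
)
```
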
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165 Approved by: https://github.com/eellison --- test/inductor/test_padding.py | 7 ++- test/inductor/test_torchinductor.py | 4 +- torch/_inductor/fx_passes/post_grad.py | 8 ++-- torch/_inductor/utils.py | 64 ++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c67bde87a369b..5e599110d29d6 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -500,8 +500,13 @@ def test_LinearAndSoftmax_codegen(self, bias=True): forward_wrapper = wrapper_codes[0] # make sure the load for softmax is aligned + if bias: + # addmm -> mm + bias and bias is fused with softmax + softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)" + else: + softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)" self.assertTrue( - "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper, + softmax_load_str in forward_wrapper, f"forward_wrapper: {forward_wrapper}", ) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 675d912c0c01f..dad2de9bde327 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15280,7 +15280,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_native_layer_norm_relu", + "triton_poi_fused_addmm_native_layer_norm", (torch.randn(4, 4, device=GPU_TYPE),), ), ] @@ -15293,7 +15293,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_LayerNorm_ReLU", + "triton_poi_fused_LayerNorm_Linear_ReLU", (torch.randn(4, 4, device=GPU_TYPE),), ), ] diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index f11817e1d4c51..7d995adec04ef 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -51,8 +51,8 @@ decode_device, get_all_devices, get_gpu_type, + has_uses_tagged_as, is_gpu, - is_pointwise_use, OPTIMUS_EXCLUDE_POST_GRAD, ) from ..virtualized import V @@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match): if not is_gpu(inp.meta["val"].device.type): return False - output = match.output_node() - return all(is_pointwise_use(use) for use in output.users) + return has_uses_tagged_as( + match.output_node(), + (torch.Tag.pointwise, torch.Tag.reduction), + ) @register_graph_pattern( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 13938f6ec1e55..6b34ef28b2c10 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -549,6 +549,70 @@ def is_pointwise_use( return torch.Tag.pointwise in target.tags or is_pointwise_fn(target) +class LogicalConnective(enum.Enum): + OR = enum.auto() + AND = enum.auto() + + +def has_uses( + target: Node, + use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False, + use_aggregate_type: LogicalConnective = LogicalConnective.OR, +) -> bool: + """ + Given a target, explore the uses of `target` by applying `use_selector_fn` + on them, and then aggregate these booleans with the `use_aggregate_type` + logical connective. + + Uses in view ops will follow the views uses. 
+ """ + + def get_use_aggregate_fn( + use_aggregate_type: LogicalConnective, + ) -> Callable[[Iterator[Any]], bool]: + match use_aggregate_type: + case LogicalConnective.AND: + return all + case LogicalConnective.OR: + return any + case _: + return any + + use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type) + + def has_uses_impl(use: Node) -> bool: + if use.op != "call_function": + return False + if not ( + isinstance(use.target, torch._ops.OpOverload) + or use.target is operator.getitem + ): + return False + + target = cast(torch._ops.OpOverload, use.target) + # Process getitem and view + if target is operator.getitem or is_view(target): + return use_aggregate_fn(has_uses_impl(user) for user in use.users) + + return use_selector_fn(target) + + return use_aggregate_fn(has_uses_impl(user) for user in target.users) + + +def has_uses_tagged_as( + target: Node, + use_tags: Collection[torch.Tag], + use_aggregate_type: LogicalConnective = LogicalConnective.OR, +) -> bool: + """ + Is there a use with given tags? + """ + + return has_uses( + target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type + ) + + def gen_gm_and_inputs( target: Any, args: list[Any], kwargs: dict[str, Any] ) -> tuple[GraphModule, list[torch.Tensor]]: From 3144713325de01b478e9b469f546d61903cb570a Mon Sep 17 00:00:00 2001 From: clr Date: Mon, 3 Nov 2025 15:25:04 -0800 Subject: [PATCH 007/651] subproc_pool: Add support for enabling quiesce via a timer (#166467) This adds the capability to subproc pool to enable quiesce via a timer Pull Request resolved: https://github.com/pytorch/pytorch/pull/166467 Approved by: https://github.com/masnesral --- test/inductor/test_compile_worker.py | 20 ++++++++++++++----- .../_inductor/compile_worker/subproc_pool.py | 13 ++++++++++++ torch/_inductor/compile_worker/timer.py | 2 +- torch/_inductor/config.py | 5 +++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 50a389e8663f9..7237d5a01c6b2 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -4,6 +4,7 @@ import tempfile from threading import Event +import torch._inductor.config as config from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, SubprocException, @@ -16,9 +17,12 @@ class TestCompileWorker(TestCase): + def make_pool(self, size): + return SubprocPool(size) + @skipIfWindows(msg="pass_fds not supported on Windows.") def test_basic_jobs(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(operator.add, 100, 1) b = pool.submit(operator.sub, 100, 1) @@ -29,7 +33,7 @@ def test_basic_jobs(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_exception(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(raise_testexc) with self.assertRaisesRegex( @@ -42,7 +46,7 @@ def test_exception(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_crash(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: with self.assertRaises(Exception): a = pool.submit(os._exit, 1) @@ -58,7 +62,7 @@ def test_crash(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_quiesce(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(operator.add, 100, 1) pool.quiesce() @@ -75,7 +79,7 @@ def test_logging(self): os.environ["ROLE_RANK"] = "0" with tempfile.NamedTemporaryFile(delete=True) as temp_log: os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = 
temp_log.name - pool = SubprocPool(2) + pool = self.make_pool(2) try: pool.submit(operator.add, 100, 1) self.assertEqual(os.path.exists(temp_log.name), True) @@ -83,6 +87,12 @@ def test_logging(self): pool.shutdown() +@config.patch("quiesce_async_compile_time", 0.1) +class TestCompileWorkerWithTimer(TestCompileWorker): + def make_pool(self, size): + return SubprocPool(size, quiesce=True) + + class TestTimer(TestCase): def test_basics(self): done = Event() diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 037b0e438adaa..a4114644026ca 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -24,6 +24,7 @@ import torch._thread_safe_fork # noqa: F401 from torch._inductor import config from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.timer import Timer from torch._inductor.compile_worker.tracked_process_pool import ( TrackedProcessPoolExecutor, ) @@ -132,6 +133,7 @@ def __init__( nprocs: int, pickler: Optional[SubprocPickler] = None, kind: SubprocKind = SubprocKind.FORK, + quiesce: bool = False, ) -> None: entry = os.path.join(os.path.dirname(__file__), "__main__.py") self.pickler = pickler or SubprocPickler() @@ -216,6 +218,13 @@ def __init__( "pytorch.wait_counter.subproc_pool.first_job" ).guard() + if quiesce: + self.timer: Optional[Timer] = Timer( + config.quiesce_async_compile_time, self.quiesce + ) + else: + self.timer = None + # Start thread last to ensure all member variables are initialized # before any access. self.read_thread.start() @@ -288,6 +297,8 @@ def _read_thread(self) -> None: with self.futures_lock: if not self.running: return + if self.timer: + self.timer.record_call() if isinstance(result, _SubprocExceptionInfo): # An exception occurred in the submitted job self.pending_futures[job_id].set_exception( @@ -322,6 +333,8 @@ def shutdown(self) -> None: with self.write_lock: if not self.running: return + if self.timer: + self.timer.quit() self.running = False self.running_waitcounter.__exit__() _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index d4b0c0dc9e281..7cfeb4217e26b 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -17,7 +17,7 @@ def __init__( self.background_thread: Optional[Thread] = None self.last_called: Optional[float] = None self.duration = duration - self.sleep_time = 60 + self.sleep_time = duration / 2 self.call = call self.exit = False diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index b78ade758f80b..08cc2b2bd861a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -964,6 +964,11 @@ def decide_compile_threads() -> int: default=False, ) +# Time in seconds to wait before quiescing +quiesce_async_compile_time: int = Config( + default=60, +) + # Whether or not to enable statically launching CUDA kernels # compiled by triton (instead of using triton's own launcher) use_static_cuda_launcher: bool = static_cuda_launcher_default() From 527b1109a8a8d8ae9e1c76c057468aacb302ed84 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 4 Nov 2025 07:36:39 -0800 Subject: [PATCH 008/651] Delete deprecated fp32 precision warnings (#166956) The deprecation warning led to warning spamming in PyTorch APIs, like torch.compile. 
This is not how a deprecation warning should go: if we add a deprecation warning, we'd better update our built-in APIs to prevent warning spam. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166956 Approved by: https://github.com/albanD --- aten/src/ATen/Context.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index a354b41912406..6bc321887502d 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP() #endif namespace at { -namespace { - /* These const variables defined the fp32 precisions for different backend We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 @@ -41,16 +39,6 @@ namespace { ->rnn */ - C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){ - TORCH_WARN_ONCE( - "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' " - "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, " - "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see " - "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices" - ); - } -} // namespace - Float32Backend str2backend(const std::string& name) { if (name == "generic") return Float32Backend::GENERIC; @@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional op) const { } else { return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32; } - warn_deprecated_fp32_precision_api(); return allow_tf32_cudnn; } @@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) { setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE); setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE); allow_tf32_cudnn = b; - warn_deprecated_fp32_precision_api(); } void Context::setSDPPriorityOrder(const std::vector& order) { @@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const { "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ", "We suggest only using the new API to set the TF32 flag. See also: ", "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); - warn_deprecated_fp32_precision_api(); return allow_tf32_new; } @@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const { "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ", "We suggest only using the new API for matmul precision. 
See also: ", "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); - warn_deprecated_fp32_precision_api(); return float32_matmul_precision; } @@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) void Context::setFloat32MatmulPrecision(const std::string &s) { auto match = [this](const std::string & s_) { - warn_deprecated_fp32_precision_api(); // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention if (s_ == "highest") { float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; From 53f75cd5ba933148b21e4b1763a1a0790b0f3744 Mon Sep 17 00:00:00 2001 From: Wenlin Chong Date: Tue, 4 Nov 2025 18:18:34 +0000 Subject: [PATCH 009/651] Fixed some syntax errors in SECURITY.md file. (#166718) Fixed some syntax errors in SECURITY.md file including PyTorch's capitalization problems, some grammatical inconsistencies, etc Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718 Approved by: https://github.com/mikaylagawarecki --- SECURITY.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index ed8228af36724..375f94547941f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,7 +1,7 @@ # Security Policy - [**Reporting a Vulnerability**](#reporting-a-vulnerability) - - [**Using Pytorch Securely**](#using-pytorch-securely) + - [**Using PyTorch Securely**](#using-pytorch-securely) - [Untrusted models](#untrusted-models) - [TorchScript models](#torchscript-models) - [Untrusted inputs](#untrusted-inputs) @@ -10,28 +10,28 @@ - [**CI/CD security principles**](#cicd-security-principles) ## Reporting Security Issues -Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch. +Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch. However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new -All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. +All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported: https://www.facebook.com/whitehat -## Using Pytorch Securely -**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. 
In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). +## Using PyTorch Securely +**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). ### Untrusted models Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources]. **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). -**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. +**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs. @@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de ### TorchScript models -TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. +TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. ### Untrusted inputs during training and prediction @@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. 
Some ### Data privacy -**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: -- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment) -- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits). +**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: +- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment) +- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits). ### Using distributed features From 496277a8ffcb29c9976fe93f91ab8232e29764b9 Mon Sep 17 00:00:00 2001 From: amdfaa <107946068+amdfaa@users.noreply.github.com> Date: Tue, 4 Nov 2025 18:44:21 +0000 Subject: [PATCH 010/651] [ROCm][CI] Lower runner check gpu count for distributed jobs (#166961) This is a PR to temporarily relieve the queueing that is caused by an mi250 node outage. See this ticket for more information: https://github.com/pytorch/pytorch/issues/166866 It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961 Approved by: https://github.com/jeffdaily --- .github/workflows/_rocm-test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 43ed76a63cc67..608aeba53e6d8 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -97,8 +97,8 @@ jobs: shell: bash run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') - if [[ $ngpu -lt 4 ]]; then - echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs" + if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus. 
+ echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs" exit 1 fi From 1d3f5e19da068ec1340db041b7105b287a513578 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 4 Nov 2025 18:46:43 +0000 Subject: [PATCH 011/651] [cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922) Fix and regression test for https://github.com/pytorch/pytorch/issues/165801 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922 Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/Skylion007, https://github.com/drisspg Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Co-authored-by: Andrey Talman --- .ci/docker/common/install_cuda.sh | 2 +- .ci/pytorch/smoke_test/smoke_test.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index fe2f9ae3185a3..fe0cb8cc79c4f 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -129,7 +129,7 @@ function install_129 { } function install_128 { - CUDNN_VERSION=9.8.0.87 + CUDNN_VERSION=9.10.2.21 echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index 675d58a3e283d..3642f29684cf0 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -272,6 +272,18 @@ def smoke_test_cuda( torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") + torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion() + print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}") + torch_cudnn_runtime_version = tuple( + [int(x) for x in torch_cudnn_version.split(".")] + ) + if torch_cudnn_runtime_version != torch_cudnn_compile_version: + raise RuntimeError( + "cuDNN runtime version doesn't match comple version. 
" + f"Loaded: {torch_cudnn_runtime_version} " + f"Expected: {torch_cudnn_compile_version}" + ) + if sys.platform in ["linux", "linux2"]: torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) print(f"Torch nccl; version: {torch_nccl_version}") From a5f3035aafd5113dd7641a95a3e919d4a4c8781f Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 4 Nov 2025 10:46:38 -0800 Subject: [PATCH 012/651] More pyrefly local errors (#166976) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976 Approved by: https://github.com/maggiemoss, https://github.com/Skylion007 --- torch/_higher_order_ops/triton_kernel_wrap.py | 1 + torch/cuda/__init__.py | 2 +- torch/cuda/memory.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 8ffab37699422..0e398897a7eab 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -498,6 +498,7 @@ def get_signature_value(idx: int, arg: Any) -> str: # pyrefly: ignore # missing-attribute codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() + # pyrefly: ignore[missing-argument,bad-argument-type] ttir_module = src.make_ir(options, codegen_fns, module_map, context) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index dff869742df56..23d297b6d95e0 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None): "nvidia-ml-py does not seem to be installed or it can't be imported." # pyrefly: ignore [invalid-inheritance] ) from _PYNVML_ERR - # pyrefly: ignore [import-error] + # pyrefly: ignore [import-error,missing-module-attribute] from pynvml import NVMLError_DriverNotLoaded try: diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 6834ffb5706a0..a1decc20cc9a8 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -828,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str: import pynvml # type: ignore[import] except ModuleNotFoundError: return "pynvml module not found, please install nvidia-ml-py" - # pyrefly: ignore [import-error] + # pyrefly: ignore [import-error,missing-module-attribute] from pynvml import NVMLError_DriverNotLoaded try: From 52ea135f77f2469a8c15f2051260584ddd7c3bb8 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 3 Nov 2025 07:19:37 -0800 Subject: [PATCH 013/651] [BE] Delete Python-3.9 stdlib definitions from torch.package (#166768) And simplify the entire function to just assert and return Pull Request resolved: https://github.com/pytorch/pytorch/pull/166768 Approved by: https://github.com/cyyever, https://github.com/atalman --- torch/package/_stdlib.py | 229 +-------------------------------------- 1 file changed, 2 insertions(+), 227 deletions(-) diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index 57a51ac41cfd9..e07b20a83cc6d 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -17,230 +17,5 @@ def is_stdlib_module(module: str) -> bool: def _get_stdlib_modules(): - if sys.version_info.major == 3: # noqa: UP036 - if sys.version_info.minor == 9: - return stdlib3_9 - if sys.version_info.minor >= 10: # noqa: YTT204 - return sys.stdlib_module_names # type: ignore[attr-defined] - elif sys.version_info.major > 3: # noqa: UP036 - return sys.stdlib_module_names # type: 
ignore[attr-defined] - - raise RuntimeError(f"Unsupported Python version: {sys.version_info}") - - -stdlib3_9 = { - "_thread", - "abc", - "aifc", - "argparse", - "array", - "ast", - "asynchat", - "asyncio", - "asyncore", - "atexit", - "audioop", - "base64", - "bdb", - "binascii", - "binhex", - "bisect", - "builtins", - "bz2", - "cProfile", - "calendar", - "cgi", - "cgitb", - "chunk", - "cmath", - "cmd", - "code", - "codecs", - "codeop", - "collections", - "colorsys", - "compileall", - "concurrent", - "configparser", - "contextlib", - "contextvars", - "copy", - "copyreg", - "crypt", - "csv", - "ctypes", - "curses", - "dataclasses", - "datetime", - "dbm", - "decimal", - "difflib", - "dis", - "distutils", - "doctest", - "email", - "encodings", - "ensurepip", - "enum", - "errno", - "faulthandler", - "fcntl", - "filecmp", - "fileinput", - "fnmatch", - "formatter", - "fractions", - "ftplib", - "functools", - "gc", - "getopt", - "getpass", - "gettext", - "glob", - "graphlib", - "grp", - "gzip", - "hashlib", - "heapq", - "hmac", - "html", - "http", - "imaplib", - "imghdr", - "imp", - "importlib", - "inspect", - "io", - "ipaddress", - "itertools", - "json", - "keyword", - "lib2to3", - "linecache", - "locale", - "logging", - "lzma", - "mailbox", - "mailcap", - "marshal", - "math", - "mimetypes", - "mmap", - "modulefinder", - "msilib", - "msvcrt", - "multiprocessing", - "netrc", - "nis", - "nntplib", - "ntpath", - "numbers", - "operator", - "optparse", - "os", - "ossaudiodev", - "parser", - "pathlib", - "pdb", - "pickle", - "pickletools", - "pipes", - "pkgutil", - "platform", - "plistlib", - "poplib", - "posix", - "posixpath", - "pprint", - "profile", - "pstats", - "pty", - "pwd", - "py_compile", - "pyclbr", - "pydoc", - "queue", - "quopri", - "random", - "re", - "readline", - "reprlib", - "resource", - "rlcompleter", - "runpy", - "sched", - "secrets", - "select", - "selectors", - "shelve", - "shlex", - "shutil", - "signal", - "site", - "smtpd", - "smtplib", - "sndhdr", - "socket", - "socketserver", - "spwd", - "sqlite3", - "sre", - "sre_compile", - "sre_constants", - "sre_parse", - "ssl", - "stat", - "statistics", - "string", - "stringprep", - "struct", - "subprocess", - "sunau", - "symbol", - "symtable", - "sys", - "sysconfig", - "syslog", - "tabnanny", - "tarfile", - "telnetlib", - "tempfile", - "termios", - "test", - "textwrap", - "threading", - "time", - "timeit", - "tkinter", - "token", - "tokenize", - "trace", - "traceback", - "tracemalloc", - "tty", - "turtle", - "turtledemo", - "types", - "typing", - "unicodedata", - "unittest", - "urllib", - "uu", - "uuid", - "venv", - "warnings", - "wave", - "weakref", - "webbrowser", - "winreg", - "winsound", - "wsgiref", - "xdrlib", - "xml", - "xmlrpc", - "zipapp", - "zipfile", - "zipimport", - "zlib", - "zoneinfo", -} + assert sys.version_info >= (3, 10) + return sys.stdlib_module_names From cef98ae5cbb484483b8cfe1b720e74fa10c7e720 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 4 Nov 2025 07:41:25 -0800 Subject: [PATCH 014/651] [aotd] Compiled saved tensor hooks context (#166887) Draft to expose compiled saved tensor hook context to selectively apply them. Exposing node, fw_graph, bw_graph. 
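A minimal sketch of a pack/unpack hook pair that consults the exposed context (keys "_fw_graph", "_bw_graph", "_node" as added in this change). The size-based quantization policy is a hypothetical example, and the hooks are installed the usual way, e.g. with torch.autograd.graph.saved_tensors_hooks(pack, unpack) around the compiled region:

    import torch
    from torch._functorch._aot_autograd.graph_compile import (
        _get_saved_tensor_hook_context,
    )

    def pack(x):
        ctx = _get_saved_tensor_hook_context()
        if ctx is None:
            # No compiled hook context (e.g. eager execution): save as-is.
            return x
        node = ctx["_node"]          # fx.Node being saved for backward
        fw_graph = ctx["_fw_graph"]  # forward graph the node belongs to
        bw_graph = ctx["_bw_graph"]  # backward graph that will consume it
        # Hypothetical policy using the graph info: only quantize large
        # floating-point activations; everything else is saved untouched.
        if x.dtype.is_floating_point and x.numel() >= 1024:
            return (x.dtype, x.to(torch.float8_e5m2))
        return x

    def unpack(packed):
        if isinstance(packed, tuple):
            dtype, t = packed
            return t.to(dtype)
        return packed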
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166887 Approved by: https://github.com/bdhirsh --- test/functorch/test_aotdispatch.py | 15 ++++ .../_functorch/_aot_autograd/graph_compile.py | 86 ++++++++++++++----- 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index fba7a96288caf..b0dd1ff8fa75d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -167,6 +167,14 @@ def _pack_fp8_wrap(x): if not x.dtype.is_floating_point: return x + if type(x) is not torch.Tensor: + # Check only during compilation + # Test calls hooks to get reference output + ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context() + assert ctx["_fw_graph"] is not None + assert ctx["_bw_graph"] is not None + assert ctx["_node"] is not None + return (x.dtype, x.to(torch.float8_e5m2)) @@ -176,6 +184,13 @@ def _unpack_fp8_wrap(x): return x dtype, tensor = x + if type(tensor) is not torch.Tensor: + # Check only during compilation + # Test calls hooks to get reference output + ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context() + assert ctx["_fw_graph"] is not None + assert ctx["_bw_graph"] is not None + assert ctx["_node"] is not None return tensor.to(dtype) diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 60ee3bc2973b1..b11eb87dc1720 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -25,6 +25,9 @@ if TYPE_CHECKING: from collections.abc import Sequence +import threading +from contextlib import contextmanager + import torch import torch.utils._pytree as pytree import torch.utils.dlpack @@ -97,6 +100,43 @@ ) +_thread_local = threading.local() + + +# Saved tensor hooks context +# Compiled saved tensor hooks are convenient way to inline some logic in the graphs +# for saved nodes from forward to backward. (E.g. activations quantization) +# In base implementation user does not have any additional information about saved value +# in the hook, except FakeTensor shape, dtype, device etc. +# _get_saved_tensor_hook_context gives additional graph information about that saved value, +# that can be used to make a decisions which pack/unpack to apply for particular saved value. +# This allows user to reuse saved tensors hooks api to apply selective pack/unpack in +# graph aware way. +# Alternative to this will be making user to write a custom pass that mucks with forward outputs, +# backward input metadata, which requires significantly more effort. +# +# As for now in context we expose forward graph, backward graph and current saved node, +# which contains node.meta with additional information about that fx.Node. +# Warning: This API may change without backward compatibility. 
+@contextmanager +def _saved_tensor_hook_context(state: dict[str, Any]): + previous_state = getattr(_thread_local, "state", None) + try: + _thread_local.state = state + yield + finally: + # Clean up: restore previous state or remove attribute + if previous_state is not None: + _thread_local.state = previous_state + else: + if hasattr(_thread_local, "state"): + delattr(_thread_local, "state") + + +def _get_saved_tensor_hook_context() -> dict[str, Any] | None: + return getattr(_thread_local, "state", None) + + zip = strict_zip log = logging.getLogger(__name__) @@ -1097,7 +1137,11 @@ def _gen_unused_name(candidate: str): if not isinstance(val, torch.Tensor): continue - pack_out_val = pack_hook_gm(val) + def _get_extra_info() -> dict[str, Any]: + return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved} + + with _saved_tensor_hook_context(_get_extra_info()): + pack_out_val = pack_hook_gm(val) requires_sc_handling = any( is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) @@ -1109,16 +1153,17 @@ def _gen_unused_name(candidate: str): " in the pack hook, and reconstructing the subclass in the unpack hook" ) - pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) - pack_g = pack_gm.graph - maybe_log_graph( - pack_gm, - f"saved_tensors_pack_hook {saved.name}", - aot_config, - lambda: f"aot_saved_tensors_hooks_pack {saved.name}", - structured_logs, - ) - pack_out_val = pack_gm(val) + with _saved_tensor_hook_context(_get_extra_info()): + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) # Install pack hook graph as eiplogue of fw_module. # Saved tensor output becomes input of pack hook graph. @@ -1188,15 +1233,16 @@ def _gen_unused_name(candidate: str): # Install unpack hook graph as a prologue of backward graph # Saved tensors inputs are replaced with packed tensors and packed sym scalars. # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. - unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) - unpack_g = unpack_gm.graph - maybe_log_graph( - unpack_gm, - f"saved_tensors_unpack_hook {saved.name}", - aot_config, - lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", - structured_logs, - ) + with _saved_tensor_hook_context(_get_extra_info()): + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) def find_saved_in_bw_inputs(bw_inputs): for n in bw_inputs: From d77c24caac4b42c56b4fa6a156ce85fb4907643e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Nov 2025 20:13:33 +0000 Subject: [PATCH 015/651] Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#165036)" This reverts commit 0e1a88904f4a5e30634b196678b56e1d6ec074f5. 
Reverted https://github.com/pytorch/pytorch/pull/165036 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19059329909/job/54439919668) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/0e1a88904f4a5e30634b196678b56e1d6ec074f5) ([comment](https://github.com/pytorch/pytorch/pull/165036#issuecomment-3487846555)) --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 - setup.py | 34 -- test/inductor/test_cutedsl_grouped_mm.py | 154 -------- torch/_inductor/config.py | 4 - torch/_inductor/kernel/mm_common.py | 7 - torch/_inductor/kernel/mm_grouped.py | 93 ++--- .../templates/cutedsl_mm_grouped.py.jinja | 333 ------------------ .../_inductor/template_heuristics/cutedsl.py | 141 -------- torch/_inductor/utils.py | 71 ---- 10 files changed, 33 insertions(+), 807 deletions(-) delete mode 100644 test/inductor/test_cutedsl_grouped_mm.py delete mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja delete mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9ae2578758939..26996b5a32d56 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index 3b4323051073a..d1b3b17445dac 100644 --- a/.gitignore +++ b/.gitignore @@ -127,7 +127,6 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py -torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index dd8a52cbeb7c7..31e78d0245d93 100644 --- a/setup.py +++ b/setup.py @@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") -def mirror_inductor_external_kernels() -> None: - """ - Copy external kernels into Inductor so they are importable. - """ - paths = [ - ( - CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", - CWD - / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", - ), - ] - for new_path, orig_path in paths: - # Create the dirs involved in new_path if they don't exist - if not new_path.exists(): - new_path.parent.mkdir(parents=True, exist_ok=True) - - # Copy the files from the orig location to the new location - if orig_path.is_file(): - shutil.copyfile(orig_path, new_path) - continue - if orig_path.is_dir(): - if new_path.exists(): - # copytree fails if the tree exists already, so remove it. 
- shutil.rmtree(new_path) - shutil.copytree(orig_path, new_path) - continue - raise RuntimeError( - "Check the file paths in `mirror_inductor_external_kernels()`" - ) - - # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1647,8 +1616,6 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() - mirror_inductor_external_kernels() - ( ext_modules, cmdclass, @@ -1682,7 +1649,6 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", - "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py deleted file mode 100644 index c26def3a54099..0000000000000 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ /dev/null @@ -1,154 +0,0 @@ -# Owner(s): ["module: inductor"] - - -import unittest - -import torch -from torch import Tensor -from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch -from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch._inductor.utils import ensure_cute_available -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -@unittest.skipIf( - not (ensure_cute_available() and is_datacenter_blackwell_arch()), - "CuTeDSL library or Blackwell device not available", -) -@instantiate_parametrized_tests -class TestCuTeDSLGroupedGemm(InductorTestCase): - def _get_inputs( - self, - group_size: int, - M_hint: int, - K: int, - N: int, - device: str, - dtype: torch.dtype, - alignment: int = 16, - ) -> tuple[Tensor, Tensor, Tensor]: - # --- Random, tile-aligned M sizes --- - M_sizes = ( - torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) - * alignment - ) - - M_total = torch.sum(M_sizes).item() - - # --- Construct input tensors --- - A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 - B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 - - # --- Build offsets (no leading zero, strictly increasing) --- - offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) - - return (A, B, offsets) - - @parametrize("group_size", (2, 8)) - @parametrize("M_hint", (256, 1024)) - @parametrize("K", (64, 128)) - @parametrize("N", (128, 256)) - def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): - device = "cuda" - dtype = torch.bfloat16 - - A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # Eager execution - c_eager = grouped_gemm_fn(A, B, offsets) - - # Test with Cute backend - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) - @parametrize("layout_B", ("contiguous", "broadcasted")) - def 
test_grouped_gemm_assorted_layouts( - self, - layout_A: str, - layout_B: str, - ): - device = "cuda" - dtype = torch.bfloat16 - - G, K, N = 8, 64, 128 - M_sizes = [128] * G - sum_M = sum(M_sizes) - offsets = torch.tensor( - [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device - ) - - A_base = torch.randn(sum_M, K, device=device, dtype=dtype) - A = A_base - - if layout_A == "offset": - # allocate bigger buffer than needed, use nonzero storage offset - storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) - offset = 128 # skip first 128 elements - A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) - elif layout_A == "padded": - # simulate row pitch > K (row_stride = K + pad) - row_pitch = K + 8 - storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) - A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) - elif layout_A == "view": - A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) - A = A_storage.view(sum_M, K) - assert A._base is not None - assert A.shape == (sum_M, K) - - B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 - - if layout_B == "broadcasted": - # Broadcast B across groups (zero stride along G) - B = B[0].expand(G, K, N) - assert B.stride(0) == 0 - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # --- eager --- - c_eager = grouped_gemm_fn(A, B, offsets) - - # --- compiled (CUTE backend) --- - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 08cc2b2bd861a..457f86fe7a77e 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -546,10 +546,6 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] -cutedsl_enable_autotuning: bool = ( - os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" -) - # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index eb22b95af2afc..b95073e769f31 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from functools import partial -from pathlib import Path from typing import Any import torch @@ -14,7 +12,6 @@ from .. 
import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import load_template log = logging.getLogger(__name__) @@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True - - -_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" -load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 0a44b728a5a93..881c14fd43d0d 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-defs import logging -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters -from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -19,25 +18,19 @@ TritonTemplate, ) from ..utils import ( - ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, - use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, - load_kernel_template, persistent_grouped_mm_grid, ) -if ensure_cute_available(): - from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs - log = logging.getLogger(__name__) aten = torch.ops.aten @@ -520,11 +513,6 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) -cutedsl_grouped_mm_template = CuteDSLTemplate( - name="grouped_gemm_cutedsl", - source=load_kernel_template("cutedsl_mm_grouped"), -) - def grouped_mm_args( mat1: TensorBox, @@ -726,44 +714,43 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
- if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False - if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -801,22 +788,6 @@ def _tuned_grouped_mm_common( **config.kwargs, ) - if use_blackwell_cutedsl_grouped_mm( - mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result - ): - for config in get_groupgemm_configs(): - kwargs = dict( - ACC_DTYPE="cutlass.Float32", - ) - - cutedsl_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - **asdict(config), - ) - input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja deleted file mode 100644 index 989f297c5f80f..0000000000000 --- a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja +++ /dev/null @@ -1,333 +0,0 @@ -import functools -from torch._inductor.runtime.runtime_utils import ceildiv -from cutlass.utils import TensorMapUpdateMode -{{gen_defines()}} -# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- -from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( - GroupedGemmKernel, -) - - -# Note about caching: -# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor -# maintains its own local caching system. 
At this stage, all compile-time -# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel -# name itself ({{kernel_name}}) are permanently baked into the file, so they -# do not need to be included in any cache key. -# -# The caching mechanism is split into two levels: -# -# 1. prep_cache -# Caches the compiled executor for build_group_ptrs_from_bases(). This -# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, -# and can therefore be safely reused across runs with different group -# partitioning (`offs`). -# -# 2. gemm_cache -# Caches the compiled Grouped GEMM executor. Its key extends the prep -# cache key with hardware- and grid-specific parameters: -# (prep_cache_key, max_active_clusters, total_num_clusters). -# This is necessary because different `offs` tensors can change the -# per-group problem sizes and thus alter `total_num_clusters`, which in -# turn changes the grid shape and persistent scheduler configuration. -# Kernels compiled for one grid cannot be safely reused for another. -# -# -# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, -# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, -# despite depending only on the GPU type. We cache this function to mitigate -# redundant recompiles even when shape/stride/dtype cache misses force kernel -# regeneration. A follow-up study will investigate the root cause. - -prep_cache = {} -gemm_cache = {} - - -@functools.lru_cache -def get_hardware_info(): - hw = cutlass.utils.HardwareInfo() - sm_count = hw.get_max_active_clusters(1) - max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) - - return (sm_count, max_active_clusters) - - -def get_prep_cache_key(input_a, input_b, output): - """ - Returns a tuple key for caching the preprocessing kernel executor based on kernel name, - shapes, strides, and dtypes of input/output tensors. - """ - return ( - tuple(input_a.shape), - tuple(input_a.stride()), - input_a.dtype, - tuple(input_b.shape), - tuple(input_b.stride()), - input_b.dtype, - tuple(output.shape), - tuple(output.stride()), - output.dtype, - ) - - -def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): - """ - Returns a tuple key for caching the gemm kernel executor by extending the - prep cache key with hardware- and grid-specific parameters. 
- """ - return ( - prep_cache_key, - max_active_clusters, - total_num_clusters, - ) - - -@cute.kernel -def build_group_ptrs_from_bases_kernel( - base_A_u64: cutlass.Int64, # device addr of input_a (bytes) - base_B_u64: cutlass.Int64, # device addr of input_b (bytes) - base_C_u64: cutlass.Int64, # device addr of Output (bytes) - offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Int32, # bytes - # -------- STRIDES (in ELEMENTS) -------- - stride_A_m_elems: cutlass.Constexpr, # A.stride(0) - stride_A_k_elems: cutlass.Constexpr, # A.stride(1) - stride_B0_elems: cutlass.Constexpr, # B.stride(0) - stride_Bk_elems: cutlass.Constexpr, # B.stride(1) - stride_Bn_elems: cutlass.Constexpr, # B.stride(2) - stride_C_m_elems: cutlass.Constexpr, # C.stride(0) - stride_C_n_elems: cutlass.Constexpr, # C.stride(1) - # -------- OUTPUTS -------- - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) - out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) - out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] -): - tidx, _, _ = cute.arch.thread_idx() - g = tidx - - m_beg_i32 = 0 - if g > 0: - m_beg_i32 = offs[g - 1] - m_end_i32 = offs[g] - m_g_i32 = m_end_i32 - m_beg_i32 - - a_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) - ) - c_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) - ) - b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) - - # ---- pointers ---- - out_ptrs[g, 0] = base_A_u64 + a_byte_off - out_ptrs[g, 1] = base_B_u64 + b_byte_off - out_ptrs[g, 2] = base_C_u64 + c_byte_off - - # ---- (m, n, k, 1) ---- - out_problem[g, 0] = m_g_i32 - out_problem[g, 1] = N - out_problem[g, 2] = K - out_problem[g, 3] = cutlass.Int32(1) - - # ---- strides ---- - out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) - out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) - out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) - out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) - out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) - out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) - - -@cute.jit -def launch_build_group_ptrs_from_bases( - base_A_u64: cutlass.Int64, - base_B_u64: cutlass.Int64, - base_C_u64: cutlass.Int64, - offs: cute.Tensor, - G: cutlass.Constexpr, - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Constexpr, - stride_A_m_elems: cutlass.Constexpr, - stride_A_k_elems: cutlass.Constexpr, - stride_B0_elems: cutlass.Constexpr, - stride_Bk_elems: cutlass.Constexpr, - stride_Bn_elems: cutlass.Constexpr, - stride_C_m_elems: cutlass.Constexpr, - stride_C_n_elems: cutlass.Constexpr, - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 - out_problem: cute.Tensor, # [G,4] cutlass.Int32 - out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 - stream: cuda.CUstream, -): - build_group_ptrs_from_bases_kernel( - base_A_u64, - base_B_u64, - base_C_u64, - offs, - K, - N, - sizeof_element, - stride_A_m_elems, - stride_A_k_elems, - stride_B0_elems, - stride_Bk_elems, - stride_Bn_elems, - stride_C_m_elems, - stride_C_n_elems, - out_ptrs, - out_problem, - out_strides_abc, - ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) - - -{{def_kernel("input_a", "input_b", "input_a_offs")}} - stream = cuda.CUstream(stream) - - input_b = input_b.transpose(1, 2) - - sumM, K = input_a.shape - G, N, Kb = input_b.shape - - dev = 
input_a.device - - base_A_u64 = int(input_a.data_ptr()) - base_B_u64 = int(input_b.data_ptr()) - base_C_u64 = int({{get_output()}}.data_ptr()) - - ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) - probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) - strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) - ptrs = from_dlpack(ptrs_t) - probs = from_dlpack(probs_t) - strides = from_dlpack(strides_t) - - prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) - prep_executor = prep_cache.get(prep_cache_key) - - if prep_executor is None: - sizeof_element = int(input_a.element_size()) - sA_m, sA_k = map(int, input_a.stride()) - sB_0, sB_n, sB_k = map(int, input_b.stride()) - sC_m, sC_n = map(int, {{get_output()}}.stride()) - - prep_executor = cute.compile( - launch_build_group_ptrs_from_bases, - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - G=int(G), - K=int(K), - N=int(N), - sizeof_element=sizeof_element, - stride_A_m_elems=sA_m, - stride_A_k_elems=sA_k, - stride_B0_elems=sB_0, - stride_Bk_elems=sB_k, - stride_Bn_elems=sB_n, - stride_C_m_elems=sC_m, - stride_C_n_elems=sC_n, - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - prep_cache[prep_cache_key] = prep_executor - - prep_executor( - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - # --- Tensormap workspace per SM --- - num_tensormap_buffers, max_active_clusters = get_hardware_info() - tensormap_shape = ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ) - tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) - tensormap_workspace = from_dlpack(tensormap_workspace_t) - - # --- Total clusters --- - def compute_total_num_clusters( - problem_sizes_mnkl, - cluster_tile_shape_mn, - ): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) - ) - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - # Compute cluster tile shape - def compute_cluster_tile_shape( - mma_tiler_mn, - cluster_shape_mn, - use_2cta_instrs, - ): - cta_tile_shape_mn = list(mma_tiler_mn) - if use_2cta_instrs: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cluster_tile_shape_mn = compute_cluster_tile_shape( - (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) - ) - - total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) - - gemm_cache_key = get_gemm_cache_key( - prep_cache_key, max_active_clusters, total_num_clusters - ) - gemm_executor = gemm_cache.get(gemm_cache_key) - - if gemm_executor is None: - grouped_gemm = GroupedGemmKernel( - acc_dtype=ACC_DTYPE, - use_2cta_instrs=USE_2_CTA, - mma_tiler_mn=(TILE_M, TILE_N), - cluster_shape_mn=(CLUSTER_M, CLUSTER_N), - tensormap_update_mode=TENSORMAP_UPDATE_MODE, - ) - - gemm_executor = cute.compile( - grouped_gemm, - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - G, - probs, - strides, - ptrs, - total_num_clusters, - tensormap_workspace, - 
max_active_clusters, - stream, - ) - - gemm_cache[gemm_cache_key] = gemm_executor - - gemm_executor( - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - probs, - strides, - ptrs, - tensormap_workspace, - stream, - ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py deleted file mode 100644 index db337b9d8a271..0000000000000 --- a/torch/_inductor/template_heuristics/cutedsl.py +++ /dev/null @@ -1,141 +0,0 @@ -from dataclasses import dataclass -from enum import auto, Enum -from itertools import product - -import torch._inductor.config as config - - -class TensorMapUpdateMode(Enum): - """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" - - SMEM = auto() - GMEM = auto() - - -@dataclass(frozen=True) -class CuTeGemmConfig: - TILE_M: int = 128 - TILE_N: int = 192 - CLUSTER_M: int = 2 - CLUSTER_N: int = 1 - USE_2_CTA: bool = False - TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM - - -def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - For information regarding valid config sets, see: - https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py - """ - - # Tile_n is always the same regardless of 2cta - tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] - - # Valid clusters - clusters_no_2cta = [ - (1, 1), - (1, 2), - (1, 4), - (1, 8), - (1, 16), - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - clusters_2cta = [ - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - - configs: list[CuTeGemmConfig] = [] - - for use_2cta, cluster_set, tile_m_range in [ - (False, clusters_no_2cta, [64, 128]), - (True, clusters_2cta, [128, 256]), - ]: - for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( - [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], - tile_m_range, - tile_n_vals, - cluster_set, - ): - configs.append( - CuTeGemmConfig( - tile_m, - tile_n, - cluster_m, - cluster_n, - USE_2_CTA=use_2cta, - TENSORMAP_UPDATE_MODE=tensormap_update_mode, - ) - ) - - return configs - - -def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
- """ - - config_tuples = [ - (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), - (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), - (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), - (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), - (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), - (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - ] - - return [CuTeGemmConfig(*args) for args in config_tuples] - - -def get_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - - Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures - or unstable results. By default, autotuning is disabled and we return only - a single baseline config. - """ - if ( - config.cutedsl_enable_autotuning - and config.max_autotune_gemm_search_space == "EXHAUSTIVE" - ): - return get_exhaustive_groupgemm_configs() - elif config.cutedsl_enable_autotuning: - return get_default_groupgemm_configs() - else: - return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 6b34ef28b2c10..2cf915d9e61de 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,77 +1975,6 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() -@functools.lru_cache(maxsize=1) -def ensure_cute_available() -> bool: - """Check if CuTeDSL is importable; cache the result for reuse. - - Call ensure_cute_available.cache_clear() after installing CuTeDSL - in the same interpreter to retry the import. - """ - try: - return importlib.util.find_spec("cutlass.cute") is not None - except ImportError: - return False - - -def use_blackwell_cutedsl_grouped_mm( - mat_a: Any, - mat_b: Any, - layout: Layout, - a_is_2d: bool, - b_is_2d: bool, - offs: Optional[Any], - bias: Optional[Any], - scale_result: Optional[Any], -) -> bool: - """ - Returns True if we can use the blackwell kernel for grouped mm. - Required conditions: - 1. CuTeDSL is available - 2. We are on a blackwell arch - 3. The dtype is bf16 - 4. Max autotune or max autotune gemm is enabled - 6. A, B, and the output are 16B aligned - 7. We are not using dynamic shapes - 8. A is 2d - 9. B is 3d - 10. Offsets are provided - 11. 
Bias and Scale are not provided - """ - if not ensure_cute_available(): - return False - - from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch - - if not is_gpu(layout.device.type) and is_datacenter_blackwell_arch(): - return False - - layout_dtypes = [torch.bfloat16] - if not _use_template_for_gpu(layout, layout_dtypes): - return False - - if not (config.max_autotune or config.max_autotune_gemm): - return False - - # Checks for 16B ptr and stride alignment - if not can_use_tma(mat_a, mat_b, output_layout=layout): - return False - - if any(is_dynamic(x) for x in [mat_a, mat_b]): - return False - - if not a_is_2d or b_is_2d: - return False - - if offs is None: - return False - - if bias is not None or scale_result is not None: - return False - - return True - - def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From 397d9fe2aea0dc60ea19ffddf6ac750420362867 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 4 Nov 2025 00:19:13 -0800 Subject: [PATCH 016/651] [inductor] coordesc not tune XBLOCK for mix-order-reduction (#166669) For mix-order reduction, we current force XBLOCK to be 1 to simplify codegen. Don't tune it in CDT. Differential Revision: [](https://our.internmc.facebook.com/intern/diff/) Differential Revision: [D86224689](https://our.internmc.facebook.com/intern/diff/D86224689) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166669 Approved by: https://github.com/jansel, https://github.com/mlazos, https://github.com/eellison, https://github.com/v0i0 --- test/inductor/test_mix_order_reduction.py | 16 ++++++++++++++++ .../runtime/coordinate_descent_tuner.py | 8 +++++++- torch/_inductor/runtime/triton_heuristics.py | 8 ++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 230a2514b9171..0dcc37ee359d8 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -117,6 +117,22 @@ def outer_red(): metrics.codegen_mix_order_reduction, ) + @inductor_config.patch(coordinate_descent_tuning=True) + def test_XBLOCK_coordest_tuning(self): + """ + We should skip XBLOCK coordinate descent tuning for + mix order reduction. 
+ """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x): + return x.sum(dim=-1), x.sum(dim=0) + + x = torch.randn(32768, 256, dtype=torch.float, device=GPU_TYPE) + self.check_numeric(f, (x,)) + self.assertEqual(metrics.codegen_mix_order_reduction, 1) + @inductor_config.patch(unroll_reductions_threshold=1) def test_3layer_split_reduction(self): """ diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 341475ef1d6fb..7ea22bdcddf0b 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -5,6 +5,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING +from torch.utils._ordered_set import OrderedSet + from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -54,6 +56,7 @@ def __init__( name="unknown", size_hints=None, inductor_meta=None, + frozen_fields=None, ): self.is_mm = is_mm # we will tune num_stages for mm @@ -66,6 +69,9 @@ def __init__( self.name = name self.size_hints = size_hints self.inductor_meta = inductor_meta or {} + self.frozen_fields: OrderedSet[str] = ( + OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet() + ) def get_config_max(self, prefix: str) -> int: max_block = TRITON_MAX_BLOCK[prefix.upper()] @@ -117,7 +123,7 @@ def tunable_fields(self): out.append("num_stages") out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul - return out + return [f for f in out if f not in self.frozen_fields] def value_too_large(self, name: str, val: int) -> bool: block_suffix = "BLOCK" diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index fe6788fb21e91..cb43d55bc86b3 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -336,6 +336,7 @@ def __init__( name=self.fn.__name__, size_hints=size_hints, inductor_meta=self.inductor_meta, + frozen_fields=self.get_coordesc_frozen_fields(), ) self.filename = filename @@ -365,6 +366,13 @@ def __init__( # Mode for launch grid calculation self.grid_mode: Literal["python", "cpp"] = "python" + def get_coordesc_frozen_fields(self) -> OrderedSet[str]: + out: OrderedSet[str] = OrderedSet() + if self.inductor_meta.get("RSPLIT_SIZE"): + # We fix XBLOCK for mix order reduction + out.add("XBLOCK") + return out + def is_statically_launchable(self): """ Checks if every compiled kernel is statically launchable, which From 3283eaa5ba901b518fe971e3a35434982034e061 Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev Date: Tue, 4 Nov 2025 20:33:56 +0000 Subject: [PATCH 017/651] Upload test stats for trunk/sha tag (#166916) Noticed that workflow runs for `trunk/{sha}` tags (issued by autorevert) don't populate test_run_s3 Clickhouse table. This PR is addressing this by changing the gate condition to upload tests stats. see https://github.com/pytorch/pytorch/actions/runs/19054297956/job/54421254448#step:8:23 as an evidence that HEAD_BRANCH is correctly populated for trunk tags. 
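Concretely, the new gate only uploads the full test_run dataset for trusted refs on the main repo; the calls below mirror the unit tests added in this patch:

    from tools.stats.upload_test_stats import should_upload_full_test_run

    should_upload_full_test_run("main", "pytorch/pytorch")               # True
    should_upload_full_test_run("trunk/" + "a" * 40, "pytorch/pytorch")  # True
    should_upload_full_test_run("trunk/12345", "pytorch/pytorch")        # False (not a 40-hex sha)
    should_upload_full_test_run("main", "someone/fork")                  # False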
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166916 Approved by: https://github.com/huydhn, https://github.com/clee2000 --- tools/stats/upload_test_stats.py | 17 ++++++++++++++++- tools/test/test_upload_gate.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tools/test/test_upload_gate.py diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index b5802e8032419..6c0232c5e5a17 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -2,6 +2,7 @@ import argparse import os +import re import sys import xml.etree.ElementTree as ET from multiprocessing import cpu_count, Pool @@ -19,6 +20,19 @@ ) +def should_upload_full_test_run(head_branch: str | None, head_repository: str) -> bool: + """Return True if we should upload the full test_run dataset. + + Rules: + - Only for the main repository (pytorch/pytorch) + - If head_branch is 'main', or a tag of form 'trunk/{40-hex-sha}' + """ + is_trunk_tag = bool(re.fullmatch(r"trunk/[0-9a-fA-F]{40}", (head_branch or ""))) + return head_repository == "pytorch/pytorch" and ( + head_branch == "main" or is_trunk_tag + ) + + def parse_xml_report( tag: str, report: Path, @@ -287,7 +301,8 @@ def init_value(test_case: dict[str, Any]) -> dict[str, Any]: remove_nan_inf(failed_tests_cases), ) - if args.head_branch == "main" and args.head_repository == "pytorch/pytorch": + # Upload full test_run only for trusted refs (main or trunk/{sha} tags) + if should_upload_full_test_run(args.head_branch, args.head_repository): # For jobs on main branch, upload everything. upload_workflow_stats_to_s3( args.workflow_run_id, diff --git a/tools/test/test_upload_gate.py b/tools/test/test_upload_gate.py new file mode 100644 index 0000000000000..7d9a2e5fe3b0b --- /dev/null +++ b/tools/test/test_upload_gate.py @@ -0,0 +1,28 @@ +import unittest + +from tools.stats.upload_test_stats import should_upload_full_test_run + + +class TestUploadGate(unittest.TestCase): + def test_main_branch_on_pytorch_repo(self) -> None: + self.assertTrue(should_upload_full_test_run("main", "pytorch/pytorch")) + + def test_trunk_tag_valid_sha_on_pytorch_repo(self) -> None: + sha = "a" * 40 + self.assertTrue(should_upload_full_test_run(f"trunk/{sha}", "pytorch/pytorch")) + + def test_trunk_tag_invalid_sha_on_pytorch_repo(self) -> None: + # Not 40 hex chars + self.assertFalse(should_upload_full_test_run("trunk/12345", "pytorch/pytorch")) + + def test_non_main_branch_on_pytorch_repo(self) -> None: + self.assertFalse( + should_upload_full_test_run("feature-branch", "pytorch/pytorch") + ) + + def test_main_branch_on_fork_repo(self) -> None: + self.assertFalse(should_upload_full_test_run("main", "someone/fork")) + + +if __name__ == "__main__": + unittest.main() From b4e4ee81d386db922d8f63359f9870eff1f44052 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Tue, 4 Nov 2025 20:34:11 +0000 Subject: [PATCH 018/651] Update triton to 3.5.1 release (#166968) This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968 Approved by: https://github.com/Lucaskabela, https://github.com/njriasan --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 10f1207e60e6c..7aab8bed1c108 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ 
b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd +bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 1545d966571dc..d5c0c99142898 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.0 +3.5.1 From 2bba37309bc8996fc6a190592e5ad9aac53761c9 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 4 Nov 2025 09:52:20 -0800 Subject: [PATCH 019/651] [inductor] runtime estimations disable use_nccl_estimator by default (#166973) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166973 Approved by: https://github.com/eellison, https://github.com/jathu --- torch/_inductor/comm_analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index afa569ff97da2..61af576772c16 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -359,7 +359,8 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: def estimate_nccl_collective_runtime_from_fx_node( fx_node: torch.fx.Node, override_size: Optional[int] = None, - use_nccl_estimator: bool = True, + # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix. + use_nccl_estimator: bool = False, ) -> float: """ Returns estimated NCCL collective runtime in nanoseconds (ns). From 871d0cd19651ce569fbc1b5dbc28195f8ae78315 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 4 Nov 2025 10:46:42 -0800 Subject: [PATCH 020/651] If USE_CUDA=1 is set, do not fallback to no CUDA (#166982) So many times i build pytorch only to notice chef nuked my nvcc and i wasted 30m building a cpu version, lets hard error fast Pull Request resolved: https://github.com/pytorch/pytorch/pull/166982 Approved by: https://github.com/malfet ghstack dependencies: #166976 --- CMakeLists.txt | 10 ++++++++++ cmake/public/cuda.cmake | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2bbb8797b78cd..86f43f58817ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,7 +234,17 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_LSAN "Use Leak Sanitizer" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) + +# Track whether USE_CUDA was explicitly set by the user (before option() is called) +# If USE_CUDA is already defined in cache, it means user explicitly set it +if(DEFINED CACHE{USE_CUDA}) + set(_USE_CUDA_EXPLICITLY_SET TRUE) +else() + set(_USE_CUDA_EXPLICITLY_SET FALSE) +endif() + option(USE_CUDA "Use CUDA" ON) + option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 218c50a69c6fb..bc8855d23e61f 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -28,6 +28,15 @@ endif() # Find CUDA. find_package(CUDA) if(NOT CUDA_FOUND) + # If user explicitly set USE_CUDA=1, error out instead of falling back + if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA) + message(FATAL_ERROR + "PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. " + "Please check your CUDA installation, ensure CUDA toolkit is installed, " + "and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. " + "If you want to build without CUDA, please set USE_CUDA=0.") + endif() + message(WARNING "PyTorch: CUDA cannot be found. 
Depending on whether you are building " "PyTorch or a PyTorch dependent library, the next warning / error will " From 4e1bd1673855356402ffcee4d254129f7848f402 Mon Sep 17 00:00:00 2001 From: "Colin L. Rice" Date: Tue, 4 Nov 2025 10:05:48 -0800 Subject: [PATCH 021/651] inductor: Switch quiesce to use timer based implementation. (#166581) Major change is to switch to a timer based implementation. Additionally, we get rid of the context manager for turning of the compile pool. We still have the warmup calls. Note that this only modifies the async_compile methods, the fx pool is left running. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166581 Approved by: https://github.com/masnesral ghstack dependencies: #166467 --- torch/_dynamo/convert_frame.py | 2 -- .../_aot_autograd/runtime_wrappers.py | 3 -- torch/_inductor/async_compile.py | 31 ++----------------- 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 875f640194e42..4439c7dc09efe 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1285,7 +1285,6 @@ def _compile( # in the case of normal and exception code paths convert_frame_box: Optional[ConvertFrameBox] = None, ) -> ConvertFrameReturn: - from torch._inductor.async_compile import async_compile_pool_manager from torch.fx.experimental.validator import ( BisectValidationException, ValidationException, @@ -1479,7 +1478,6 @@ def count_args(code: CodeType) -> int: with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), - async_compile_pool_manager(), chromium_event_timed( "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True ), diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 4846f1ca74edb..86202e2cd319d 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -2365,8 +2365,6 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): - from torch._inductor.async_compile import async_compile_pool_manager - # compiled autograd reimplements this function at proxy_call_aot_backward assert not backward_state_indices, ( "BackwardState requires CompiledAutograd" @@ -2446,7 +2444,6 @@ def _backward_impl(ctx, all_args): with ( tracing(saved_context), compile_context(saved_compile_context), - async_compile_pool_manager(), context(), track_graph_compiling(aot_config, "backward"), metrics_context, diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index ac0d60bdebd71..a2c80002eb928 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -2,7 +2,6 @@ from __future__ import annotations import atexit -import contextlib import functools import json import logging @@ -230,18 +229,6 @@ def remove_future(kernel_src: str) -> None: del CompiledTritonKernels._cache[key] -@contextlib.contextmanager -def async_compile_pool_manager(): - """ - Context manager to quiesce the subproc pool at the end of compilation, i.e., - when dynamo is done. - """ - try: - yield - finally: - AsyncCompile.quiesce() - - class AsyncCompile: """ Utilities to compile in thread pools or subprocess pools (in the case of Triton). 
@@ -277,7 +264,9 @@ def process_pool() -> AnyPool: pool: AnyPool if config.worker_start_method == "subprocess": # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(get_compile_threads()) + pool = SubprocPool( + get_compile_threads(), quiesce=config.quiesce_async_compile_pool + ) else: if config.worker_start_method == "spawn": # Avoid creating pools in the spawned subprocs themselves: @@ -333,20 +322,6 @@ def use_process_pool(cls): cls._ready_future = cls.process_pool().submit(cls._get_ready) return cls._ready_future.done() - @classmethod - def quiesce(cls) -> None: - """ - If using a SubprocPool, signal the sidecar process to shut down its - ProcessPoolExecutor. - """ - # Don't inadvertently create a process pool if it doesn't already exist: - if not cls.process_pool.cache_info().currsize: - return - if config.quiesce_async_compile_pool: - pool = cls.process_pool() - if isinstance(pool, SubprocPool): - pool.quiesce() - @classmethod def wakeup(cls) -> None: """ From 2673f8b00705d9dd537f2bfcce6a5a1dbf4b2a31 Mon Sep 17 00:00:00 2001 From: Parshant Sharma Date: Tue, 4 Nov 2025 21:06:55 +0000 Subject: [PATCH 022/651] Fix torch.linalg.eig inductor stride mismatch (#162484) Fixes #159445 ### Summary - Fixed a stride layout issue in the `torch.linalg.eig` meta kernel that prevented successful compilation with the inductor backend. The meta kernel was producing incorrect row-major strides. - LAPACK/BLAS libraries (underlying implementation) expect column-major layout Pull Request resolved: https://github.com/pytorch/pytorch/pull/162484 Approved by: https://github.com/isuruf --- test/inductor/test_torchinductor.py | 16 ++++++++++++++++ .../test_torchinductor_codegen_dynamic_shapes.py | 3 +++ torch/_meta_registrations.py | 4 ++++ 3 files changed, 23 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index dad2de9bde327..ed8993a1c9a39 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5876,6 +5876,22 @@ def fn(x, y): reference_in_float=False, ) + @skipIfMPS + def test_linalg_eig_stride_consistency(self): + def fn(x): + eigenvals, eigenvecs = torch.linalg.eig(x) + return eigenvecs + + x = torch.randn(5, 5, device=self.device, dtype=torch.float32) + + self.common( + fn, + [x], + exact_stride=True, + exact_dtype=True, + check_lowp=False, + ) + def test_view_as_complex(self): class Repro(torch.nn.Module): def __init__(self) -> None: diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 2244af38f635a..e73f82ab64911 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -159,6 +159,9 @@ def run(*ex, **kwargs): # "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_linalg_eig_stride_consistency_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu") + ), "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index f84b77e630bf3..fe0492ff19c1c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1021,6 +1021,10 @@ def 
meta_linalg_eig(input: Tensor): ) values = input.new_empty(input.shape[:-1], dtype=complex_dtype) vectors = input.new_empty(input.shape, dtype=complex_dtype) + is_cuda = device_hint(input) == "cuda" + vectors.as_strided_( + input.shape, make_contiguous_strides_for(input.shape, row_major=is_cuda) + ) return values, vectors From 7f0e9321360cb13563a11bf9c720464c3dbf1ece Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 4 Nov 2025 10:42:01 -0800 Subject: [PATCH 023/651] [dynamo] don't use LocalSource for temp variables created by side_effects (#166917) Fixes https://github.com/pytorch/pytorch/issues/166900 Implementation notes: - I tried to disallow guard generation before side effect application in order to futureproof improper guard generation. However, this was not feasible since it is possible to realize lazy VTs while generating side effects (e.g. realizing a constant variable that is used in a deque update). - `codegen_save_tempvars` now generates `TempLocalSource` for create temporary variables now, so that they won't get confused with `LocalSource` - we should error out when we attempt to create guards for `TempLocalSource`. I considered using `SyntheticLocalSource`, but that has additional `subguards_allowed` behavior that we may not want to have for temp variables. - We moved the guard installation for constant user-defined pytree objects from `as_python_constant` to `__init__`. Objects created outside the compile-region will be guarded, while objects created inside the compile-region will not be guarded. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166917 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 21 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 8 ++++---- torch/_dynamo/source.py | 17 +++++++++++++++++ torch/_dynamo/variables/user_defined.py | 16 ++++++++++------ torch/_guards.py | 1 + 5 files changed, 53 insertions(+), 10 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 169f43ce0a077..b8727208a5bfa 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -13219,6 +13219,27 @@ def mapper(x): self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) + def test_pytree_register_constant_with_side_effect(self): + class Foo: + pass + + class Bar: + def __eq__(self, other): + return super().__eq__(other) + + def __hash__(self): + return 0 + + python_pytree.register_constant(Bar) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x, obj): + obj.attr = {3: Bar()} + return x + 1 + + inp = torch.ones(3) + self.assertEqual(fn(inp, Foo()), inp + 1) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index bd38e9295a05a..688a05f26ae64 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -42,7 +42,7 @@ ) from .codegen import PyCodegen from .exc import SideEffectsError, unimplemented_v2 -from .source import GlobalSource, LocalCellSource, LocalSource, Source +from .source import GlobalSource, LocalCellSource, Source, TempLocalSource from .utils import is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( AttributeMutation, @@ -704,7 +704,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: ) cg.extend_output(create_call_function(0, False)) cg.add_cache(var) - var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + var.source = TempLocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif 
var.source is None: # pyrefly: ignore [bad-assignment] var.source = LocalCellSource(var.local_name) @@ -729,7 +729,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: # `add_cache` generates STORE and consumes TOS, but we never # cleared it. TODO move this call into `add_cache` cg.clear_tos() - var.source = LocalSource(cg.tempvars[var]) + var.source = TempLocalSource(cg.tempvars[var]) elif isinstance(var, variables.AutogradFunctionContextVariable): unimplemented_v2( gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", @@ -764,7 +764,7 @@ def load_new_method() -> None: cg.extend_output(create_call_function(1 + len(var.init_args), False)) # type: ignore[attr-defined] cg.add_cache(var) - var.source = LocalSource(cg.tempvars[var]) + var.source = TempLocalSource(cg.tempvars[var]) for ctx, args in self.save_for_backward: cg(ctx.source) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 8edd8f7540e31..5be6b8ccbf41d 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -151,6 +151,23 @@ def name(self) -> str: return f"L[{repr(self.local_name)}]" +@dataclasses.dataclass(frozen=True) +class TempLocalSource(Source): + # like LocalSource, but cannot be guarded on + local_name: str + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self) -> GuardSource: + return GuardSource.TEMP_LOCAL + + def name(self) -> str: + raise NotImplementedError( + "Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub." + ) + + @dataclasses.dataclass(frozen=True) class SyntheticLocalSource(Source): local_name: str diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 707ad7b3d9d18..085b5e0c648c5 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -968,6 +968,12 @@ def __init__( # rid of these workarounds here and in `GetAttrVariable`. self.attrs_directly_modifed_on_dict = set() + import torch.utils._pytree as pytree + + self.is_pytree_constant_class = pytree.is_constant_class(self.value_type) + if pytree.is_constant_class(self.value_type) and self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + def __str__(self) -> str: inner = self.value_type.__name__ if inner in [ @@ -989,12 +995,10 @@ def python_type(self): return self.value_type def as_python_constant(self): - import torch.utils._pytree as pytree - - if pytree.is_constant_class(self.value_type): - if self.source is not None: - install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) - return self.value + if self.is_pytree_constant_class and self.source: + # NOTE pytree constants created in the torch.compile region will + # NOT be guarded (even though they have a source set) + return self.value # TODO else try reconstructing the object by, e.g., leveraging side # effects and `as_python_constant`. 
return super().as_python_constant() diff --git a/torch/_guards.py b/torch/_guards.py index bac59965a3aef..b321c5f968b16 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -145,6 +145,7 @@ class GuardSource(enum.Enum): GLOBAL_UNSPECIALIZED_NN_MODULE = 13 LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14 GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15 + TEMP_LOCAL = 16 def is_fsdp_module(self) -> bool: return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) From ed45c5f38df6aa419c67d139d932c2c94404223a Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 4 Nov 2025 09:14:26 -0800 Subject: [PATCH 024/651] Avoid DDE in narrow with unbacked start (#166361) Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice. The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, for that case we shall pass dim_size instead of start+length Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361 Approved by: https://github.com/aorenste --- aten/src/ATen/native/TensorShape.cpp | 38 +++++++++++++++--- c10/core/SymBool.cpp | 14 +++++++ c10/core/SymBool.h | 6 +++ test/export/test_export.py | 31 +++++++++----- test/test_dynamic_shapes.py | 51 ++++++++++++++++++++++++ test/test_torchfuzz_repros.py | 5 ++- torch/_inductor/codegen/wrapper.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 19 ++++++++- torch/utils/_sympy/printers.py | 36 +++++++++++++++++ 9 files changed, 184 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6df7761d822db..6136a6aa8c520 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,5 +1,6 @@ #include #include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1710,11 +1711,14 @@ Tensor narrow_symint( "], but got ", start, ")") - if (start < 0) { - start = start + cur_size; - } + // Bounds check without converting start: + // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + + // length <= 0 + // - If start >= 0: need start + length <= cur_size + auto end = start + length; TORCH_SYM_CHECK( - start.sym_le(cur_size - length), + (start.sym_lt(0).sym_and((end).sym_le(0))) + .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), "start (", start, ") + length (", @@ -1722,7 +1726,31 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - return at::slice_symint(self, dim, start, start + length, 1); + + if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) { + return at::slice_symint(self, dim, start, end, 1); + } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) { + // Avoid the complex symbolic expressions path for non-unbacked. + return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1); + } else { + // Cannot statically determine the condition due to unbacked. + // This is an interesting situation; when start is negative and + // start + length == 0, slice and narrow do different things. + // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to + // pass curr_size instead of 0. Otherwise, they would do the same thing. + // This says at runtime: if start < 0 and end == 0, then pass curr_size + // instead of 0. 
+ + auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); + auto result = + at::slice_symint(self, dim, start, end + use_different * cur_size, 1); + + // Ensure slice allocated unbacked size is specialized to length. + SymInt new_size = result.sym_size(dim); + TORCH_SYM_CHECK(new_size.sym_eq(length), "") + + return result; + } } // This overload exists purely for XLA, because they wanted to pass in diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index d804eb9d27409..48c407b8b069c 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace c10 { @@ -111,4 +112,17 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } +SymInt SymBool::toSymInt() const { + // If concrete bool, return concrete SymInt + if (auto ma = maybe_as_bool()) { + return SymInt(*ma ? 1 : 0); + } + + // Symbolic case: use sym_ite to convert bool to int (0 or 1) + auto node = toSymNodeImpl(); + auto one_node = node->wrap_int(1); + auto zero_node = node->wrap_int(0); + return SymInt(node->sym_ite(one_node, zero_node)); +} + } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index d5d509e239b1d..a27a28a5bf8a3 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,6 +12,8 @@ namespace c10 { +class SymInt; + class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -80,6 +82,10 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } + // Convert SymBool to SymInt (0 or 1) + // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless + SymInt toSymInt() const; + bool is_heap_allocated() const { return ptr_; } diff --git a/test/export/test_export.py b/test/export/test_export.py index 3908f03b11e55..cdc18b1d4c564 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,26 +6093,19 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) - # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_tensorsplit(torch.nn.Module): @@ -6166,7 +6159,12 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + if y.item() < 0: + return ( + torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() + ) + else: + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6196,7 +6194,18 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + if box.content < 0: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) + else: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) with self.assertRaisesRegex( error_type, diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 
fb1d22805d50a..b63e0427c26c3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,6 +4401,57 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_narrow_unbacked_start(self): + def func(x, start, length): + # unbacked start + u0 = start.item() + return torch.narrow(x, 0, u0, length) + + compiled_func = torch.compile(func, fullgraph=True, backend="inductor") + + x = torch.tensor([1, 2, 3, 4, 5, 6]) + + # Test cases: (start, length) + test_cases = [ + # Negative starts + (-2, 2), # Start from second-to-last element + (-1, 1), # Start from last element + (-3, 3), # Start from third-to-last element + (-6, 2), # Start from beginning (negative) + (-4, 1), # Start from fourth-to-last element + # Positive starts + (0, 2), # Start from beginning + (1, 3), # Start from second element + (2, 2), # Start from third element + (4, 2), # Start near end + # Edge cases + (0, 6), # Full tensor + (0, 1), # Single element from start + (5, 1), # Single element from end + ] + + for start_val, length in test_cases: + with self.subTest(start=start_val, length=length): + start = torch.tensor([start_val]) + + # Test with compiled function + result_compiled = compiled_func(x, start, length) + + # Test with eager function (expected behavior) + result_eager = func(x, start, length) + + # Compare results + self.assertEqual(result_compiled, result_eager) + + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_narrow_unbacked_start_cpp_wrapper(self): + """Test narrow with unbacked start with cpp_wrapper""" + self.test_narrow_unbacked_start() + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 3b864aae4f477..84a00430420cf 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -16,6 +16,10 @@ from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +# Skip all tests in this file if CUDA is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + class TestFuzzerCompileIssues(TestCase): """Test cases for fuzzer-discovered eager/compile divergence issues.""" @@ -257,7 +261,6 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #163971") def test_fuzzer_issue_163971(self): torch.manual_seed(0) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e629d9c7bdebd..947166cf216cd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,7 +2063,8 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - return f"{pos} if {x} >= 0 else {neg}" + x_cond = self.codegen_sizevar(x) + return f"{pos} if {x_cond} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index aeccdfbe000db..693d25aea6130 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,6 +547,7 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. 
# @@ -602,7 +603,23 @@ def rebind_unbacked( if u1.node.hint is not None: continue - raw_u1 = u1.node.expr + # unbacked symbols bindings might be replaced to other backed or + # unbacked replacements. + # + # Example: + # u = x.item() + # torch._check(u == 5) + # + # The safest approach is to retrieve raw_u1 from u1.node._expr + # and perform the rebinding on the original unbacked symbol, + # even if it’s no longer directly referenced. + # + # In other words, we should always rebind the original symbol + # before any replacements are applied. + # u0 -> u0 == s1 + raw_u1 = u1.node._expr + + # TODO Do we still need this logic below? # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 526443577b3f8..915d0e5461f1e 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,6 +306,24 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary expressions + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: e1 if c1 else (e2 if c2 else (... else eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self._print(expr_i) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self._print(cond_i) + if result is None: + result = expr_str + else: + result = f"({expr_str} if {cond_str} else {result})" + return result if result else "0" + class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -327,6 +345,24 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary operators + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) + if result is None: + result = expr_str + else: + result = f"{cond_str} ? {expr_str} : {result}" + return f"({result})" if result else "0" + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) From cdca63db8c0f30a6fcc181784411bbd8913aa1db Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Tue, 4 Nov 2025 21:28:14 +0000 Subject: [PATCH 025/651] Fix quoting in pytest_cache.py invocations (#166955) Especially the job identifier can contain spaces so needs to be quoted Fixes e.g. 
https://github.com/pytorch/pytorch/actions/runs/19063797853/job/54449422160#step:15:52 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166955 Approved by: https://github.com/Skylion007 --- .github/actions/pytest-cache-download/action.yml | 12 ++++++------ .github/actions/pytest-cache-upload/action.yml | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/actions/pytest-cache-download/action.yml b/.github/actions/pytest-cache-download/action.yml index 1406f962c4ca8..3f51f6a5525bc 100644 --- a/.github/actions/pytest-cache-download/action.yml +++ b/.github/actions/pytest-cache-download/action.yml @@ -38,9 +38,9 @@ runs: run: | python3 .github/scripts/pytest_cache.py \ --download \ - --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \ - --pr_identifier $GITHUB_REF \ - --job_identifier $JOB_IDENTIFIER \ - --temp_dir $RUNNER_TEMP \ - --repo $REPO \ - --bucket $BUCKET \ + --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ + --pr_identifier "$GITHUB_REF" \ + --job_identifier "$JOB_IDENTIFIER" \ + --temp_dir "$RUNNER_TEMP" \ + --repo "$REPO" \ + --bucket "$BUCKET" \ diff --git a/.github/actions/pytest-cache-upload/action.yml b/.github/actions/pytest-cache-upload/action.yml index 2652d019075f7..9fbb63a760f27 100644 --- a/.github/actions/pytest-cache-upload/action.yml +++ b/.github/actions/pytest-cache-upload/action.yml @@ -47,11 +47,11 @@ runs: run: | python3 .github/scripts/pytest_cache.py \ --upload \ - --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \ - --pr_identifier $GITHUB_REF \ - --job_identifier $JOB_IDENTIFIER \ - --sha $SHA \ - --test_config $TEST_CONFIG \ - --shard $SHARD \ - --repo $REPO \ - --temp_dir $RUNNER_TEMP \ + --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ + --pr_identifier "$GITHUB_REF" \ + --job_identifier "$JOB_IDENTIFIER" \ + --sha "$SHA" \ + --test_config "$TEST_CONFIG" \ + --shard "$SHARD" \ + --repo "$REPO" \ + --temp_dir "$RUNNER_TEMP" \ From a64c7d740428010d700b4bcd395af8a7b2d5c21f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Tue, 4 Nov 2025 21:30:43 +0000 Subject: [PATCH 026/651] [DebugMode] output, tensor id annotations for DebugMode (#165076) Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)` Example output for `test_debug_mode_mm`, with both enabled: ``` torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$12: f32[8, 32]| S(0) aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) redistribute_input(t$4: f32[1, 32], trace: S(0)->R) _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0) -> t$6: f32[8, 32] _c10d_functional::wait_tensor(t$7: f32[8, 32]) -> t$8: f32[8, 32] aten::mm(t$9: f32[1, 8], t$10: f32[8, 32]) -> t$11: f32[1, 32] (dt$13: f32[8, 32]| S(0)) -> dt$17: f32[]| P aten::sum(dt$14: f32[8, 32]| S(0)) aten::sum(t$15: f32[1, 32]) -> t$16: f32[]""" ``` Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic. 
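A minimal usage sketch with plain (non-DTensor) tensors, just to show the new flags; the DTensor path is exercised in the updated `test_debug_mode_mm` below:

```python
import torch
from torch.utils._debug_mode import DebugMode

x, y = torch.randn(8, 8), torch.randn(8, 32)
with DebugMode(record_torchfunction=True, record_ids=True, record_output=True) as debug_mode:
    torch.mm(x, y).sum()

# Calls are now annotated with tensor ids and outputs, along the lines of:
#   aten::mm(t$0: f32[8, 8], t$1: f32[8, 32]) -> t$2: f32[8, 32]
print(debug_mode.debug_string())
```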
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076 Approved by: https://github.com/zpcore --- .../tensor/debug/test_debug_mode.py | 22 +-- torch/utils/_debug_mode.py | 126 +++++++++++++++--- 2 files changed, 117 insertions(+), 31 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 07442f34c8946..9acfcb15804e5 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -42,22 +42,24 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_ids=True, record_output=True + ) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) - aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) + aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) - _c10d_functional::wait_tensor(t: f32[8, 32]) - aten::mm(t: f32[1, 8], t: f32[8, 32]) - (dt: f32[8, 32]| S(0)) - aten::sum(dt: f32[8, 32]| S(0)) - aten::sum(t: f32[1, 32])""", + redistribute_input(t$2: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] + _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] + aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + aten::sum(dt$6: f32[8, 32]| S(0)) + aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 09435aa07e68b..5e24ce086e1aa 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,6 +2,7 @@ import contextlib import functools import traceback +import weakref from typing import Any, Callable, Optional, TYPE_CHECKING import torch @@ -14,6 +15,7 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -56,29 +58,48 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -def _tensor_debug_string(tensor, attributes) -> str: +class TensorIdTracker: + def __init__(self): + self.tensor_memo: dict[WeakIdRef, int] = {} + self.next_tensor_id = 0 + + def _id(self, tensor) -> int: + with torch._C._DisablePythonDispatcher(): + o = WeakIdRef(tensor) + + def del_memo(): + self.tensor_memo.pop(o, None) + + weakref.finalize(tensor, del_memo) + if o not in self.tensor_memo: + self.tensor_memo[o] = self.next_tensor_id + self.next_tensor_id += 1 + return self.tensor_memo[o] + + +def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - + id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" if isinstance(tensor, 
torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft: {tensor_debug_str}" + return f"ft{id_str}: {tensor_debug_str}" else: - return f"t: {tensor_debug_str}" + return f"t{id_str}: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes) -> str: +def _arg_to_str(arg, attributes, tensor_memo=None) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes) + return _tensor_debug_string(x, attributes, tensor_memo) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -144,8 +165,11 @@ def __init__( # results from dispatch hooks self.record = record self.log = log + self.output_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -153,6 +177,18 @@ def stringify_args(self, attributes: list[str]) -> None: "Subclasses must implement stringify_args(), even if no-op" ) + def stringify_output( + self, + output: Any, + attributes: list[str], + tensor_memo: Optional[TensorIdTracker] = None, + ) -> None: + """Store stringified version of call output in self.output_str""" + if tree_all(lambda x: x is None, output): + return + output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) + self.output_str = f" -> {str(output_str)}" + def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -179,11 +215,16 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.args_str = ", ".join( + _arg_to_str(arg, attributes, tensor_memo) for arg in self.args + ) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes, tensor_memo)}" + for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -215,6 +256,8 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" + if self.output_str: + base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -247,8 +290,10 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes)}" + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -263,7 +308,11 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) placement_str = f"{src_placement_str} -> 
{dst_placement_str}" - return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + + base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + if self.output_str: + base_str += self.output_str + return base_str def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -288,7 +337,9 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -341,6 +392,8 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, + record_output=False, + record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -378,8 +431,24 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace + # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) + self.record_output: bool = record_output + + # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. + self.record_ids: bool = record_ids + + self.reset() + + def reset(self): self.operators = [] self.call_depth = 0 + self._tensor_memo = TensorIdTracker() + self._output_info: dict[int, object] = {} + + def _track_op_output(self, op_index, result): + """Assign IDs to output tensors and store in output_info""" + # self._track_tensor_ids(result) + self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the “eager” backend @@ -390,20 +459,35 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args(self.record_tensor_attributes) + call.stringify_args( + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) self.operators.append(call) + def _record_call_output(self, call, output): + if not self.record_output: + return + call.stringify_output( + output, + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - self._record_call( - _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) + call = _OpCall( + func, args, kwargs, self.call_depth, stack=self.record_stack_trace ) + self._record_call(call) try: self.call_depth += 1 - return func(*args, **kwargs) + result = func(*args, **kwargs) + self._record_call_output(call, result) + return result finally: self.call_depth -= 1 @@ -445,13 +529,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: + self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.operators = [] - self.call_depth = 0 + self.reset() if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) From e8052f2f99de1fb7284e38082ff5714e17cd9562 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 4 Nov 2025 11:20:38 -0800 Subject: [PATCH 027/651] Add model code stack trace to torch.profile (#166677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
```python python test/test_fx.py -k profiler ``` Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen. We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace. `map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry. One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove. `aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True. Screenshot 2025-10-31 at 4 40 52 PM Example code gen'd. ``` def forward(self, args_list): args_iter = iter(args_list) arg0_1 = next(args_iter) arg1_1 = next(args_iter) args_list.clear() _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__() repeated_subgraph0 = self.repeated_subgraph0 _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__() invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None _rf_invoke_subgraph.__exit__(None, None, None) _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__() getitem = invoke_subgraph[0]; invoke_subgraph = None _rf_getitem.__exit__(None, None, None) return (getitem,) _rf.__exit__(None, None, None) def forward(self, arg0_1, arg1_1): _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__() _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__() mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None _rf_mul.__exit__(None, None, None) _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__() sin = torch.ops.aten.sin.default(mul); mul = None _rf_sin.__exit__(None, None, None) _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__() add = torch.ops.aten.add.Tensor(sin, 5); sin = None _rf_add.__exit__(None, None, None) return (add,) _rf.__exit__(None, None, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677 Approved by: https://github.com/ezyang ghstack dependencies: #166676 --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ++++++++++++++++++ torch/autograd/profiler_util.py | 40 ++++ torch/fx/graph.py | 23 +++ torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +++++++++++++++- 6 files changed, 425 insertions(+), 5 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977ee..12f6ba2228db8 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 
'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index d6f33d426aee7..c16c42805b921 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,6 +75,12 @@ ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +207,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. + + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4187,6 +4223,150 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. 
+ """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. + """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. 
+ """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e61..4b8a6d221b4e0 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. + """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 697b2f4084ca5..fd6835d2b301b 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -798,6 +799,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. 
It depends # on delete_unused_values to append one @@ -807,8 +812,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1760,6 +1779,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1827,6 +1847,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1839,6 +1860,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1849,6 +1871,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 297f76732584f..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,14 +861,18 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -885,7 +889,6 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" - filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -905,6 +908,13 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff 
--git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c9..47df87ce1678d 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + +def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = 
timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier == ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name From e020fb3431371ea335a0d5db5094810c9f1e104d Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 4 Nov 2025 22:09:24 +0000 Subject: [PATCH 028/651] [Minor][Inductor] move some combo kernel log from warning to debug (#166993) Combo kernel warns for long reduction and large pointwise. This becomes too spammy for users such as vLLM. This PR moves these logs from warn to debug. I validated the spammy log is removed on llama-3.1-8B. 
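For anyone who still needs these partition decisions while debugging, the messages are not gone, only demoted to debug verbosity. A minimal sketch of how to surface them again, assuming the combo-kernel logger is covered by the standard `torch._logging` / `TORCH_LOGS` machinery:

```python
# Opt back in to the demoted messages by raising inductor logging to DEBUG.
import logging

import torch._logging

torch._logging.set_logs(inductor=logging.DEBUG)
# Shell equivalent (same assumption): TORCH_LOGS="+inductor" python my_script.py
```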
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166993 Approved by: https://github.com/zou3519, https://github.com/eellison --- torch/_inductor/codegen/triton_combo_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index e86753348c6b1..3e58e95ef9e9c 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -98,7 +98,7 @@ def _default_custom_combo_kernel_horizontal_partition( ] short_reduction = [n for n in reduction if n not in long_reduction] if long_reduction: - log.warning( + log.debug( "ComboKernels: %d long reduction nodes are separated", len(long_reduction), ) @@ -112,7 +112,7 @@ def _default_custom_combo_kernel_horizontal_partition( ] if large_pointwise: # TODO benchmark the performance when large pointwise nodes combining with others - log.warning( + log.debug( "ComboKernels: %d large pointwise nodes are separated", len(large_pointwise), ) From 81038fd3268074a43d0b8fc4de9cf22a6d71a896 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Nov 2025 22:26:35 +0000 Subject: [PATCH 029/651] Revert "Add model code stack trace to torch.profile (#166677)" This reverts commit e8052f2f99de1fb7284e38082ff5714e17cd9562. Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/malfet due to Broke lint, please rebase, we've moved from mypy to pyrefly ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3488219996)) --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ------------------ torch/autograd/profiler_util.py | 40 ---- torch/fx/graph.py | 23 --- torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +--------------- 6 files changed, 5 insertions(+), 425 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 12f6ba2228db8..a404e15a977ee 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool 
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index c16c42805b921..d6f33d426aee7 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,12 +75,6 @@ ) from torch.testing._internal.jit_utils import JitTestCase -import json -import tempfile -from torch.profiler import profile, ProfilerActivity -from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace -from torch.autograd.profiler_util import _canonicalize_profiler_events - try: from torchvision import models as torchvision_models @@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor): print(x) -def _enrich_profiler_traces(prof): - """ - Helper function to extract and augment profiler events with stack traces. - - Args: - prof: A torch.profiler.profile object - - Returns: - A string representing enriched events - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: - trace_file = f.name - prof.export_chrome_trace(trace_file) - - with open(trace_file) as f: - trace_data = json.load(f) - - map_recorded_events_to_aten_ops_with_stack_trace( - trace_data - ) - - events = [] - for event in trace_data["traceEvents"]: - if "args" in event and "stack_trace" in event["args"]: - events.append(event) - - actual_traces = _canonicalize_profiler_events(events) - return actual_traces - - class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4223,150 +4187,6 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_stack_trace_augmentation(self): - """ - Test that map_recorded_events_to_aten_ops_with_stack_trace correctly - augments profiler events with stack traces from FX metadata registry. 
- """ - - # Simple test model - class TestModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(16, 10) - - def forward(self, x): - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - return x - - model = TestModel().cuda() - - # Compile the model - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda")) - - # Profile with the compiled model - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - - self.assertExpectedInline(actual_traces, """\ -event=aten::t node=t stack_trace=x = self.linear1(x) -event=aten::transpose node=t stack_trace=x = self.linear1(x) -event=aten::as_strided node=t stack_trace=x = self.linear1(x) -event=aten::addmm node=addmm stack_trace=x = self.linear1(x) -event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) -event=aten::relu node=relu stack_trace=x = self.relu(x) -event=aten::clamp_min node=relu stack_trace=x = self.relu(x) -event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) -event=aten::t node=t_1 stack_trace=x = self.linear2(x) -event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) -event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) -event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) -event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_multiple_modules(self): - """ - Test that multiple compiled modules under the same profiler session - have their events correctly augmented with stack traces. - """ - - class ModelA(torch.nn.Module): - def forward(self, x): - return x + 1 - - class ModelB(torch.nn.Module): - def forward(self, x): - return x - 1 - - model_a = ModelA().cuda() - model_b = ModelB().cuda() - - # Compile both models - compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) - compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_a(torch.randn(10, 10, device="cuda")) - _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - # Profile both models in the same session - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result_a = compiled_a(torch.randn(10, 10, device="cuda")) - result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::add node=add stack_trace=return x + 1 -event=cudaLaunchKernel node=add stack_trace=return x + 1 -event=aten::sub node=sub stack_trace=return x - 1 -event=cudaLaunchKernel node=sub stack_trace=return x - 1""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_nested_graph_modules(self): - """ - Test that nested graph modules (e.g., graph modules calling subgraphs) - have their events correctly augmented with stack traces. 
- """ - - # Model with nested structure - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.c = 5 - - @torch.compiler.nested_compile_region - def forward(self, x, y): - m = torch.mul(x, y) - s = m.sin() - a = s + self.c - return a - - model = Mod().cuda() - - # Compile the model (this may create nested graph modules) - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - # Profile - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::mul node=mul stack_trace=m = torch.mul(x, y) -event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) -event=aten::sin node=sin stack_trace=s = m.sin() -event=cudaLaunchKernel node=sin stack_trace=s = m.sin() -event=aten::add node=add stack_trace=a = s + self.c -event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" - ) - def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 4b8a6d221b4e0..b2d6530049e61 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,43 +1224,3 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) - - -# Collect all events with stack traces and format them canonically -def _canonicalize_profiler_events(events): - """ - Extract and format all events with stack traces in a canonical way - for deterministic testing. - """ - events_with_traces = [] - - for event in events: - # Extract relevant fields - event_name = event.get("name", "") - node_name = event["args"].get("node_name", "") - stack_trace = event["args"].get("stack_trace", "") - - # Get the last non-empty line of the stack trace - lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] - stack_trace = lines[-1] if lines else "" - - events_with_traces.append( - { - "event_name": event_name[:20], - "node_name": node_name, - "stack_trace": stack_trace, - "start_time": event.get("ts", 0), - } - ) - - # Sort by node_name for deterministic ordering - events_with_traces.sort(key=lambda x: x["start_time"]) - - # Format as a string - lines = [] - for evt in events_with_traces: - lines.append( - f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" - ) - - return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fd6835d2b301b..697b2f4084ca5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,7 +443,6 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -799,10 +798,6 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") - if record_func: - body.append( - "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" - ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. 
It depends # on delete_unused_values to append one @@ -812,22 +807,8 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") - do_record = record_func and node.op in ( - "call_function", - "call_method", - "call_module", - ) - if do_record: - # The double hash ## convention is used by post-processing to find the fx markers - body.append( - f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" - ) emit_node(node) delete_unused_values(node) - if do_record: - body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") - if record_func: - body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1779,7 +1760,6 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1847,7 +1827,6 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def _python_code( @@ -1860,7 +1839,6 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1871,7 +1849,6 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8360c96630d6c..297f76732584f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,18 +861,14 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - - from torch._dynamo import config as dynamo_config - - python_code = self._graph.python_code( - root_module="self", record_func=dynamo_config.enrich_profiler_metadata - ) + python_code = self._graph.python_code(root_module="self") self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -889,6 +885,7 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -908,13 +905,6 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) - # Replace the placeholder in generated code with actual filename - # The double hash ## convention is used by post-processing to find the fx markers - self._code = self._code.replace( - "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", - f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", - ) - cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff 
--git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 47df87ce1678d..2c6e06b2cb3c9 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import Any, Literal, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None: with profile(): pass - - -@dataclass -class TimelineEvent: - """Represents an event in the profiler timeline.""" - - timestamp: int - event_type: Literal["start", "end", "regular"] - marker_type: Optional[Literal["filename", "node"]] - identifier: Optional[str | int] - event: dict[str, Any] - - -@dataclass -class ContextStackEntry: - """Represents a context (filename or node) in the stack.""" - - context_type: Literal["filename", "node"] - identifier: str | int - metadata: Optional[dict] - tid: Optional[int] = None # Thread ID associated with this context - - -def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): - """ - Maps recorded profiler events to their corresponding fx nodes and adds stack traces. - - Builds a timeline of all events (regular ops and FX markers for filenames/nodes), - sorts by timestamp, then processes chronologically while maintaining a context stack of active - filename/node scopes. Regular events are augmented with stack traces and node names from the - innermost active context. Runtime is O(n log n) for n events. - - Args: - traced_data: Json of profiler events from Chrome trace - - Returns: - Dict mapping recorded event names to their aten operations with added stack traces - """ - from torch.fx.traceback import _FX_METADATA_REGISTRY - - trace_events = traced_data.get("traceEvents", []) - - # Create event timeline - event_timeline: list[TimelineEvent] = [] - - def is_fx_marker_event(event): - return ( - event.get("cat") == "cpu_op" - and event.get("name", "").startswith("## ") - and event.get("name", "").endswith(" ##") - ) - - def append_fx_marker_event(event_type, identifier, event): - start_ts = event["ts"] - end_ts = start_ts + event["dur"] - event_timeline.append( - TimelineEvent(start_ts, "start", event_type, identifier, event) - ) - event_timeline.append( - TimelineEvent(end_ts, "end", event_type, identifier, event) - ) - - for event in trace_events: - if "ts" not in event or "dur" not in event: - continue - - if is_fx_marker_event(event): - content = event["name"][3:-3] - - if content.endswith(".py"): - append_fx_marker_event("filename", content, event) - else: - try: - node_index = int(content) - except ValueError: - pass - append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] - - else: - # Regular event that needs augmentation - start_ts = event["ts"] - event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) - - # Sort by timestamp - event_timeline.sort(key=lambda x: x.timestamp) - - # Process events in chronological order with a stack - context_stack: list[ContextStackEntry] = [] - - # Invariant: all start event has a corresponding end event - for timeline_event in event_timeline: - match timeline_event.event_type: - case "start": - assert timeline_event.identifier is not None - - if timeline_event.marker_type == "filename": - assert isinstance(timeline_event.identifier, str) - # Push filename context - query metadata registry on-demand - metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) - tid = 
timeline_event.event.get("tid") - context_stack.append( - ContextStackEntry( - "filename", timeline_event.identifier, metadata, tid - ) - ) - elif timeline_event.marker_type == "node": - # Find the current filename from stack - current_file_metadata = None - tid = timeline_event.event.get("tid") - for ctx_entry in reversed(context_stack): - if ( - ctx_entry.context_type == "filename" - and ctx_entry.tid == tid - ): - current_file_metadata = ctx_entry.metadata - break - - if current_file_metadata: - node_metadata = current_file_metadata.get("node_metadata", {}) - if timeline_event.identifier in node_metadata: - node_meta: Optional[dict] = node_metadata[ - timeline_event.identifier - ] - context_stack.append( - ContextStackEntry( - "node", timeline_event.identifier, node_meta, tid - ) - ) - - case "end": - # Pop from stack - search backwards to find matching context - for i in range(len(context_stack) - 1, -1, -1): - ctx_entry = context_stack[i] - if ( - timeline_event.marker_type == ctx_entry.context_type - and timeline_event.identifier == ctx_entry.identifier - ): - context_stack.pop(i) - break - - case "regular": - # Apply metadata from current context stack - # Find the most specific context (node takes precedence over filename) - # Only augment events with the same tid as the file/node event matched - current_stack_trace = None - current_node_name = None - event_tid = timeline_event.event.get("tid") - - for ctx_entry in reversed(context_stack): - # Only apply metadata from contexts with matching tid - if ctx_entry.tid == event_tid: - if ctx_entry.context_type == "node" and ctx_entry.metadata: - current_stack_trace = ctx_entry.metadata.get( - "stack_trace", "No model stack trace available" - ) - current_node_name = ctx_entry.metadata.get("name", "") - # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes - # if nodes are nested, e.g. 
in nested graph modules - break - - # Augment the event - if current_stack_trace or current_node_name: - args = timeline_event.event.setdefault("args", {}) - if current_stack_trace: - args["stack_trace"] = current_stack_trace - if current_node_name: - args["node_name"] = current_node_name From d7e2d0ad301b5d0db049bf5d2a2fc7ff9c89c58c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Sat, 1 Nov 2025 16:37:39 -0700 Subject: [PATCH 030/651] make narrow_tensor_symint DDE-free (#166379) https://github.com/pytorch/pytorch/issues/158081 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379 Approved by: https://github.com/Lucaskabela ghstack dependencies: #166361 --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- test/functorch/test_aotdispatch.py | 2 +- test/test_dynamic_shapes.py | 13 +++++++++++++ test/test_proxy_tensor.py | 1 - 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6136a6aa8c520..b3fff5a4bb42f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1764,8 +1764,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - int64_t st = start.item(); - return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); + c10::SymInt st = start.item().toSymInt(); + return at::narrow_symint(self, dim, std::move(st), std::move(length)); } std:: diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index b0dd1ff8fa75d..6cae42d8929da 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8126,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - xfail("narrow"), + skip("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b63e0427c26c3..d3f9e415ff944 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4452,6 +4452,19 @@ def test_narrow_unbacked_start_cpp_wrapper(self): """Test narrow with unbacked start with cpp_wrapper""" self.test_narrow_unbacked_start() + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_narrow_with_tensor_start(self): + @torch.compile(backend="inductor", fullgraph=True) + def f(x, start, end): + return torch.narrow(x, 0, start, end) + + x = torch.tensor( + [False], device="cuda:0" if torch.cuda.is_available() else "cpu" + ) + start = torch.tensor(0) + res = f(x, start, 0) + self.assertEqual(res.shape, torch.Size([0])) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index b76895a0a91f3..0487995a2d1c5 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,7 +1987,6 @@ def f(t): } only_fake_tensor_failures = { - xfail('narrow'), xfail('tensor_split'), } From c1e91bd4c3bef209d43896d18abbf638bed356a9 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 4 Nov 2025 22:55:26 +0000 Subject: [PATCH 031/651] [export] Codemod unittests to use new graph capture API (#166957) Summary: as title. 
Test Plan: pytest test/functorch/test_aot_joint_with_descriptors.py pytest test/higher_order_ops/test_local_map.py Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166957 Approved by: https://github.com/angelayi, https://github.com/yushangdi --- .../test_aot_joint_with_descriptors.py | 32 +++++++++---------- test/higher_order_ops/test_local_map.py | 21 ++---------- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py index 7949d2bb46cbf..13277fccaea11 100644 --- a/test/functorch/test_aot_joint_with_descriptors.py +++ b/test/functorch/test_aot_joint_with_descriptors.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.utils._pytree as pytree from torch._decomp import decomposition_table -from torch._dynamo.functional_export import _dynamo_graph_capture_for_export +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.testing import normalize_gm from torch._functorch._aot_autograd.descriptors import ( BufferAOTInput, @@ -48,17 +48,13 @@ def graph_capture(model, inputs, with_export): gm = model - fake_mode = None + tracing_context = None if with_export: - with ( - torch._dynamo.config.patch(install_free_tensors=True), - fx_traceback.preserve_node_meta(), - ): - # TODO: switch to use the official graph_capture API once it is ready - gm = _dynamo_graph_capture_for_export(model)(*inputs) - fake_mode = gm.meta.get("fake_mode", None) - - with tracing(TracingContext(fake_mode)): + with fx_traceback.preserve_node_meta(): + gm = dynamo_graph_capture_for_export(model)(*inputs) + tracing_context = gm.meta.get("tracing_context", None) + + with tracing(tracing_context): with ExitStack() as stack: joint_with_descriptors = aot_export_joint_with_descriptors( stack, @@ -325,7 +321,7 @@ def forward(self, x, *, scale): inputs = (torch.randn(4, 3),) kwargs = {"scale": torch.tensor(2.0)} - gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs) + gm = dynamo_graph_capture_for_export(model)(*inputs, **kwargs) with ExitStack() as stack: # Export joint with descriptors @@ -356,8 +352,8 @@ def forward( primals, tangents, ): - primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight') - primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias') + primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight') + primals_2: "f32[2]" # ParamAOTInput(target='linear.bias') primals_3: "f32[4, 3]" # PlainAOTInput(idx=0) primals_4: "f32[]" # PlainAOTInput(idx=1) tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0)) @@ -379,8 +375,8 @@ def forward( transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None return pytree.tree_unflatten([ mul_2, # PlainAOTOutput(idx=0) - transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight')) - as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias')) + transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight')) + as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias')) None, # None None, # None ], self._out_spec)""", @@ -1063,9 +1059,11 @@ def forward(self, x): str(custom_metadata), """\ ('call_function', 'new_empty', {'pp_stage': 0}) +('get_attr', '_tensor_constant0', {'pp_stage': 0}) ('call_function', 'index_put', {'pp_stage': 0}) ('call_function', 'slice_2', {'pp_stage': 0}) ('call_function', 'slice_backward', {'pp_stage': 0}) 
+('get_attr', '_tensor_constant0_1', {'pp_stage': 0}) ('call_function', 'index', {'pp_stage': 0})""", ) @@ -1082,7 +1080,7 @@ def forward(self, x): model = SimpleLinear() inputs = (torch.randn(4, 3),) - gm = _dynamo_graph_capture_for_export(model)(*inputs) + gm = dynamo_graph_capture_for_export(model)(*inputs) fake_mode = gm.meta.get("fake_mode", None) with tracing(TracingContext(fake_mode)): diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 5f37d8e1768d6..9d2870d3b5fdd 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -15,6 +15,7 @@ import torch.fx.traceback as fx_traceback import torch.nn.functional as F from torch import nn +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable from torch._functorch.aot_autograd import aot_export_joint_with_descriptors from torch._subclasses.fake_tensor import FakeTensorMode @@ -51,24 +52,6 @@ def enable_local_map_wrapping(): yield -def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module: - from torch._dynamo.functional_export import _dynamo_graph_capture_for_export - from torch.export._trace import _restore_state_dict - - """ - Thin wrapper around graph capture output that restores the - original calling convention and attribute fqn. TODO: - 1) Use bytecode for calling convention instead of pytree for more - seamless UX. - 2) Attach guards - 3) Be more careful about tensor constants names. - """ - with torch._dynamo.config.patch(install_free_tensors=True): - gm = _dynamo_graph_capture_for_export(model)(*inputs) - _restore_state_dict(model, gm) - return gm - - def ap_style_initial_capture( model: torch.nn.Module, inputs_fn: Callable ) -> torch.nn.Module: @@ -90,7 +73,7 @@ def ap_style_initial_capture( enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), ): - torch_ir_with_fqn = _export(model, inputs) + torch_ir_with_fqn = dynamo_graph_capture_for_export(model)(*inputs) unused = ExitStack() joint_with_descriptors = aot_export_joint_with_descriptors( unused, From a96728d1885548cbf696d1c40fc990d1cbe1699b Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Tue, 4 Nov 2025 23:35:59 +0000 Subject: [PATCH 032/651] Clarify safety of CUDA graph memory pool sharing across graphs that are replayed in arbtirary order. (#166975) Some users at pytorch conference were asking me about whether it is safe to share a memory pool among cuda graphs that never run concurrently, but may run in arbitrary order, if they don't depend upon each other's output. Even though your capture order doesn't match replay order in this situation, this is safe. However, our documents confusingly said this wasn't allowed. This update is intended to help with that. Since vLLM essentially depends upon this behavior, I call it out specifically. 
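To make the documented pattern concrete, a minimal sketch (assuming a CUDA device; the tensor names and shapes are illustrative only, and neither graph consumes the other's output):

```python
import torch

pool = torch.cuda.graph_pool_handle()
x = torch.randn(8, device="cuda")
y = torch.randn(8, device="cuda")

g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    out1 = x * 2
with torch.cuda.graph(g2, pool=pool):
    out2 = y + 1

# Replay order need not match capture order because the graphs are independent.
# clone() each output before replaying the other graph, since a replay that
# shares the pool may clobber it.
g2.replay()
res2 = out2.clone()
g1.replay()
res1 = out1.clone()
```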
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166975 Approved by: https://github.com/eellison, https://github.com/BoyuanFeng --- docs/source/notes/cuda.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index c7d3a93f73523..caabeb399c722 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -1720,6 +1720,16 @@ and can be used to share memory across graphs as shown:: g1.replay() g2.replay() +It's also safe to share a memory pool across separate graphs that do not depend +on each other's outputs, provided they never run concurrently. +Be aware that replaying one graph can clobber another graph's outputs when +they share a pool, unless :meth:`~torch.Tensor.clone` is called on the outputs +beforehand. +This pattern is frequently used in inference servers that accept variable batch +sizes at runtime. +vLLM is a notable example; see `here `__ +and `here `__. + With :func:`torch.cuda.make_graphed_callables`, if you want to graph several callables and you know they'll always run in the same order (and never concurrently) pass them as a tuple in the same order they'll run in the live workload, and From 0cd809f60c79bb808f2736fa4ac5f602f63caf8f Mon Sep 17 00:00:00 2001 From: Jason Xie Date: Tue, 4 Nov 2025 23:47:11 +0000 Subject: [PATCH 033/651] [inductor][AMD] Filter out invalid Triton Configs for MI350X _scaled_mm (#166442) Summary: Mirrors change done in D81180838 but for inductor. Without this change, running _scaled_mm on MI350X accelerator would crash. Test Plan: HIP_VISIBLE_DEVICES=7 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run mode/opt-amd-gpu -m rocm70 -c fbcode.rocm_arch=mi350 scripts/jchunx/gemm:scaled_mm_microbench -- --csv_file /home/jchunx/scripts/fp8_shapes.csv --backend triton,aten --fast_accum=true 2>&1 | tee ~/logs/scaled_mm.log Reviewed By: bilal Differential Revision: D85694383 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166442 Approved by: https://github.com/bilal --- torch/_inductor/template_heuristics/triton.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 61616d81c2878..8cbbf5073d5ef 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -1946,6 +1946,29 @@ def _valid(self, kernel_inputs: KernelInputs) -> bool: return False return True + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter out bad configs for specific hardware. + On AMD MI350X (GFX 9.5+), skip configs with BLOCK_K<=64 due to lack of corresponding MFMA instructions. 
+ """ + + def should_skip_mi350x_config(config: BaseConfig) -> bool: + """Skip config if BLOCK_K<=64 on MI350X (GFX 9.5+)""" + try: + return ( + config.block_k <= 64 + and torch.version.hip is not None + and torch.cuda.get_device_capability() >= (9, 5) + ) + except RuntimeError: + # If no HIP GPUs are available, we can't check device capability + # so we don't skip any configs + return False + + filtered_configs = [c for c in configs if not should_skip_mi350x_config(c)] + return super()._filter_configs(filtered_configs) + # Scaled TMA-specific mixin for scaled MM templates with TMA class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): From 661b63966341d9569829c6ec6799be0757db1a6c Mon Sep 17 00:00:00 2001 From: "Xiangyang (Mark) Guo" Date: Tue, 4 Nov 2025 23:47:12 +0000 Subject: [PATCH 034/651] use_cpp_bmm_template supports more use cases (#165469) Summary: In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous, but the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. This diff specifically checks for contiguity within the 2D matrix of each batch, and enables more uses for cpp bmm template. Differential Revision: D84561331 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165469 Approved by: https://github.com/desertfire --- test/inductor/test_cpu_select_algorithm.py | 26 ++++++++++++++++++++++ torch/_inductor/utils.py | 18 ++++++++++++--- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 4e1c48496ebc5..ca520ab66bcc2 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -2697,6 +2697,32 @@ def forward(self, x): self.common(mod, (u,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("bs", (5,)) + @parametrize("Mdim", (16,)) + @parametrize("Kdim", (32,)) + @parametrize("Ndim", (64,)) + @dtypes(torch.float) + def test_bmm_with_broadcasted_mat1(self, bs, Mdim, Kdim, Ndim, dtype): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, w): + assert x.dim() == 2, f"Expected x to be 2D, got {x.dim()}D" + x_expanded = x.unsqueeze(0).expand(bs, -1, -1) + return x_expanded @ w + + counters.clear() + u = torch.randn(Mdim, Kdim).to(dtype=dtype) + v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype) + mod = M().to(dtype=dtype).eval() + with verify(dtype) as (atol, rtol): + self.common(mod, (u, v), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + @patches @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 2cf915d9e61de..3f8652882af79 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2231,9 +2231,21 @@ def use_cpp_bmm_template( assert isinstance(mat1.layout, Layout) - return ( - use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) - and mat1.layout.is_contiguous() + # In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous. + # But the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. + # So here we specifically check for contiguity within the 2D matrix of each batch. 
+ mat1_size = mat1.layout.size + mat1_stride = mat1.layout.stride + mat1_each_batch_is_contiguous = ( + _use_template_for_cpu(layout) + and mat1.get_dtype() == torch.float32 + and (len(mat1_size) == 3) + and (len(mat1_stride) == 3) + and (mat1_stride[1] == mat1_size[2]) + and (mat1_stride[2] == 1) + ) + return use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) and ( + mat1.layout.is_contiguous() or mat1_each_batch_is_contiguous ) From 4b12c0344d0b1a6536a2659c4e498c805efdc1f1 Mon Sep 17 00:00:00 2001 From: Karhou Tam Date: Tue, 4 Nov 2025 23:53:56 +0000 Subject: [PATCH 035/651] Add default `.github/copilot-instructions.md` and item in `.gitignore` for allowing local changes (#166864) Fixes [#166850](https://github.com/pytorch/pytorch/issues/166850) - Create a default `.github/copilot-instructions.md` file (used Claude Sonnet 4.5 in Copilot). - Add `.github/copilot-instructions.md` to the `.gitignore` file. The prompt used is below, which is preset by Copilot: ``` Analyze this codebase to generate or update `.github/copilot-instructions.md` for guiding AI coding agents. Focus on discovering the essential knowledge that would help an AI agents be immediately productive in this codebase. Consider aspects like: - The "big picture" architecture that requires reading multiple files to understand - major components, service boundaries, data flows, and the "why" behind structural decisions - Critical developer workflows (builds, tests, debugging) especially commands that aren't obvious from file inspection alone - Project-specific conventions and patterns that differ from common practices - Integration points, external dependencies, and cross-component communication patterns Source existing AI conventions from `**/{.github/copilot-instructions.md,AGENT.md,AGENTS.md,CLAUDE.md,.cursorrules,.windsurfrules,.clinerules,.cursor/rules/**,.windsurf/rules/**,.clinerules/**,README.md}` (do one glob search). Guidelines (read more at https://aka.ms/vscode-instructions-docs): - If `.github/copilot-instructions.md` exists, merge intelligently - preserve valuable content while updating outdated sections - Write concise, actionable instructions (~20-50 lines) using markdown structure - Include specific examples from the codebase when describing patterns - Avoid generic advice ("write tests", "handle errors") - focus on THIS project's specific approaches - Document only discoverable patterns, not aspirational practices - Reference key files/directories that exemplify important patterns Update `.github/copilot-instructions.md` for the user, then ask for feedback on any unclear or incomplete sections to iterate. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166864 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .github/copilot-instructions.md | 125 ++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 .github/copilot-instructions.md diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000000..06c3f32abd5e1 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,125 @@ +# PyTorch Copilot Instructions + +This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively. 
+ +## Architecture Overview + +### Core Components + +- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality +- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd + - `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse) + - `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry +- **torch/** - Python bindings and public API + - `torch/csrc/` - C++ Python bindings (hand-written and generated) + - `torch/csrc/autograd/` - Reverse-mode automatic differentiation + - `torch/csrc/jit/` - TorchScript JIT compiler +- **torchgen/** - Code generation tooling that reads `native_functions.yaml` +- **tools/** - Build scripts, autograd derivatives, code generation + +### The Code Generation Workflow + +**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file: +1. Declares operator signatures, variants (function/method), and dispatch behavior +2. Gets processed by `torchgen/` to generate C++/Python bindings +3. Produces headers in `build/aten/src/ATen/` during compilation + +Example entry structure: +```yaml +- func: my_op(Tensor self, Scalar alpha=1) -> Tensor + variants: function, method + dispatch: + CPU: my_op_cpu + CUDA: my_op_cuda +``` + +After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`). + +## Development Workflows + +### Building from Source + +**Never run `setup.py` directly** - use pip with editable install: +```bash +python -m pip install --no-build-isolation -v -e . +``` + +Speed up builds: +- `DEBUG=1` - Debug symbols with `-g -O0` +- `USE_CUDA=0` - Skip CUDA compilation +- `BUILD_TEST=0` - Skip C++ test binaries +- Install `ninja` (`pip install ninja`) for faster builds +- Use `ccache` for incremental compilation caching + +Rebuild specific targets: `(cd build && ninja )` + +### Testing + +**Critical**: DO NOT run entire test suites. Run specific tests only: +```bash +python test/test_torch.py TestTorch.test_specific_case +``` + +**Test structure**: All tests use `torch.testing._internal.common_utils`: +```python +from torch.testing._internal.common_utils import run_tests, TestCase + +class TestFeature(TestCase): + def test_something(self): + # Use self.assertEqual for tensor comparisons + pass + +if __name__ == "__main__": + run_tests() +``` + +**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file. 
+ +### Linting + +Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes) + +## Project-Specific Conventions + +### Memory and Storage +- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs) +- CUDA device info lives in storage objects + +### Python-C++ Integration (`torch/csrc/`) +- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors +- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr` +- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion + +### Dispatch System +- PyTorch uses operator dispatch to route calls to backend-specific kernels +- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops +- See `aten/src/ATen/native/README.md` for dispatch keyword guidance + +## Git Workflow (AI Agent Specific) + +When preparing PRs from this environment: +```bash +git stash -u +git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch +git stash pop +# Resolve conflicts if necessary +``` + +## Common Gotchas + +1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml` +2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI +3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux) +4. **No internet access** - DO NOT attempt to install dependencies during development + +## Key Files Reference + +- `AGENTS.md` - Instructions specific to AI coding agents +- `CONTRIBUTING.md` - Comprehensive human contributor guide +- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript) +- `aten/src/ATen/native/README.md` - Operator implementation guide +- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd + +## Performance Debugging + +Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation. From 7eefcfb1db5995739a2614f368594cb266d33173 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 4 Nov 2025 23:54:15 +0000 Subject: [PATCH 036/651] [BE][Typing][Dynamo] Type torch/_dynamo/variables/ctx_manager.py (#166878) Provides type coverage to torch/_dynamo/variables/ctx_manager.py Coverage report: `mypy torch/_dynamo/variables/ctx_manager.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 1541 lines and 144 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/166878 Approved by: https://github.com/Skylion007 --- torch/_C/_functorch.pyi | 2 + torch/_dynamo/polyfills/pytree.py | 16 +- torch/_dynamo/symbolic_convert.py | 7 +- torch/_dynamo/variables/ctx_manager.py | 571 ++++++++++++++-------- torch/_dynamo/variables/streams.py | 16 +- torch/_dynamo/variables/torch.py | 1 + torch/_dynamo/variables/torch_function.py | 3 +- 7 files changed, 387 insertions(+), 229 deletions(-) diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index c23240e13170a..a35befcad392d 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -5,6 +5,8 @@ from torch import Tensor # Defined in torch/csrc/functorch/init.cpp +def set_inplace_requires_grad_allowed(allowed: bool) -> None: ... +def get_inplace_requires_grad_allowed() -> bool: ... def _set_dynamic_layer_keys_included(included: bool) -> None: ... def get_unwrapped(tensor: Tensor) -> Tensor: ... 
def is_batchedtensor(tensor: Tensor) -> bool: ... diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index f9bdc0cce4a00..d86fe054b2ebc 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -64,7 +64,7 @@ def _(*args: Any, **kwargs: Any) -> bool: del __func del __name - @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] def tree_is_leaf( tree: PyTree, /, @@ -79,7 +79,7 @@ def tree_is_leaf( return True return False - @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) + @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] def tree_iter( tree: PyTree, /, @@ -110,7 +110,7 @@ def tree_iter( __all__ += ["tree_iter"] - @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] def tree_leaves( tree: PyTree, /, @@ -451,7 +451,7 @@ def treespec_dict( dict, metadata, entries, - unflatten_func, + unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -507,7 +507,7 @@ def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: type(node), metadata, entries, - unflatten_func, + unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) # type: ignore[arg-type] @@ -557,7 +557,7 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: __all__ += ["tree_unflatten"] - @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map( func: Callable[..., Any], tree: PyTree, @@ -578,7 +578,7 @@ def tree_map( __all__ += ["tree_map"] - @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map_( func: Callable[..., Any], tree: PyTree, @@ -600,7 +600,7 @@ def tree_map_( __all__ += ["tree_map_"] - _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr] + _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined] @substitute_in_graph( # type: ignore[arg-type] _none_unflatten, diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 9d0d87c5f8a06..53ec0ee412849 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -434,12 +434,15 @@ def resume_fn(self) -> ReenterWith: else: return ReenterWith(self.stack_index - 1) - def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None: + def exit( + self, tx: InstructionTranslatorBase, is_graph_break: bool + ) -> VariableTracker | None: assert self.with_context is not None if ( is_graph_break and self.with_context.exit_on_graph_break() ) or not is_graph_break: return self.with_context.exit(tx) # type: ignore[arg-type] + return None class SpeculationLogDivergence(AssertionError): @@ -3860,7 +3863,7 @@ def enter_ctx( else: self.block_stack.append(BlockStackEntry(inst, target, len(self.stack))) - return ctx.enter(self) + return ctx.enter(self) # type: ignore[arg-type] @staticmethod def unsupported_ctx_graph_break(ctx: VariableTracker) -> NoReturn: diff --git a/torch/_dynamo/variables/ctx_manager.py 
b/torch/_dynamo/variables/ctx_manager.py index 0502c58a78420..4eac189b65fdd 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This file contains a collection of context manager classes used by Dynamo for tracking and managing various PyTorch runtime states during graph compilation. These context @@ -23,8 +21,9 @@ import inspect import sys import warnings +from collections.abc import Callable, Sequence from contextlib import ExitStack -from typing import TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union import torch._C from torch._guards import Guard @@ -67,35 +66,43 @@ class ContextWrappingVariable(VariableTracker): *VariableTracker._nonvar_fields, } - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, target_values: Any, initial_values: Optional[Any] = None, **kwargs: Any + ) -> None: super().__init__(**kwargs) self.target_values = target_values self.initial_values = initial_values - def enter(self, tx): - self._call_func(tx, self.target_values) + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + if hasattr(self, "_call_func"): + self._call_func(tx, self.target_values) self.set_cleanup_hook(tx) return variables.ConstantVariable.create(None) - def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): + def set_cleanup_hook( + self, tx: "InstructionTranslator", fn: Optional[Callable[..., Any]] = None + ) -> None: if fn is None: - def fn(): - self._call_func(tx, self.initial_values) + def fn() -> None: + if hasattr(self, "_call_func"): + self._call_func(tx, self.initial_values) - self.cleanup_fn = fn + self.cleanup_fn: Optional[Callable[..., Any]] = fn tx.output.add_cleanup_hook(self.cleanup) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() return variables.ConstantVariable.create(None) - def reconstruct_type(self, codegen: "PyCodegen"): + def reconstruct_type(self, codegen: "PyCodegen") -> None: codegen( AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: self.reconstruct_type(codegen)) target_values = self.target_values if not target_values: @@ -103,18 +110,18 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.extend_output([codegen.create_load_const(val) for val in target_values]) codegen.extend_output(create_call_function(len(target_values), False)) - def module_name(self): + def module_name(self) -> str: raise NotImplementedError("module_name called on base") - def fn_name(self): + def fn_name(self) -> str: raise NotImplementedError("fn_name called on base") def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert len(args) == 1 assert isinstance( args[0], @@ -128,28 +135,27 @@ def call_function( if isinstance(args[0], NestedUserFunctionVariable): return WrappedNestedUserFunctionVariable(args[0], self) - - if isinstance(args[0], SkipFunctionVariable): + elif isinstance(args[0], SkipFunctionVariable): return WrappedSkipFunctionVariable(args[0], self) - - if isinstance(args[0], UserMethodVariable): + elif 
isinstance(args[0], UserMethodVariable): return WrappedUserMethodVariable(args[0], self) - - if isinstance(args[0], UserFunctionVariable): + elif isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + else: + raise AssertionError("Unexpected arg type") - def supports_graph_breaks(self): + def supports_graph_breaks(self) -> bool: return True - def exit_on_graph_break(self): + def exit_on_graph_break(self) -> bool: return True - def cleanup(self): + def cleanup(self) -> None: if self.cleanup_fn is not None: self.cleanup_fn() self.cleanup_fn = None - def cleanup_assert(self): + def cleanup_assert(self) -> None: assert self.cleanup_fn, "multiple exits?" self.cleanup() @@ -157,7 +163,7 @@ def cleanup_assert(self): class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are # python constants. Which might not always be the case here. - def __init__(self, cm_obj, **kwargs) -> None: + def __init__(self, cm_obj: ContextManager[Any], **kwargs: Any) -> None: assert cm_obj is not None super().__init__( value=cm_obj, @@ -166,44 +172,46 @@ def __init__(self, cm_obj, **kwargs) -> None: ) self.cm_obj = cm_obj - def module_name(self): + def module_name(self) -> str: return self.cm_obj.__module__ - def fn_name(self): + def fn_name(self) -> str: return type(self.cm_obj).__name__ - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: source = None if self.source is None else AttrSource(self.source, "__enter__") return variables.UserMethodVariable( - self.cm_obj.__enter__.__func__, + self.cm_obj.__enter__.__func__, # type: ignore[attr-defined] self, source=source, ).call_function(tx, [], {}) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: source = None if self.source is None else AttrSource(self.source, "__exit__") x = variables.UserMethodVariable( - self.cm_obj.__exit__.__func__, + self.cm_obj.__exit__.__func__, # type: ignore[attr-defined] self, source=source, - ).call_function(tx, args, {}) + ).call_function(tx, list(args), {}) tx.active_generic_context_managers.pop() return x - def supports_graph_breaks(self): + def supports_graph_breaks(self) -> bool: return False - def exit_on_graph_break(self): + def exit_on_graph_break(self) -> bool: return True class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): - def __init__(self, ctx_manager_vt, mod): + def __init__(self, ctx_manager_vt: ContextWrappingVariable, mod: Any) -> None: self.cm_vt = ctx_manager_vt self.mod = mod # We don't call super().__init__() because we're delegating most methods to cm_vt - def enter(self, tx: "InstructionTranslator"): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: # Custom enter implementation with side effects self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize() @@ -212,7 +220,9 @@ def enter(self, tx: "InstructionTranslator"): tx.output.side_effects.ignore_mutations_on(self.old_buffer_var) return self.cm_vt.enter(tx) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # Custom exit implementation with side effects x = self.cm_vt.exit(tx, *args) tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var) @@ -220,7 +230,7 @@ def exit(self, tx: "InstructionTranslator", *args): return x # Forward all other method 
calls to self.cm_vt - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This will be called for any attribute not explicitly defined in this class return getattr(self.cm_vt, name) @@ -229,14 +239,16 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requires grad""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "GradInplaceRequiresGradCtxManagerVariable": return GradInplaceRequiresGradCtxManagerVariable( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: [enabled] = self.target_values self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() torch._C._functorch.set_inplace_requires_grad_allowed(enabled) @@ -254,7 +266,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -269,14 +283,16 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): """represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "TemporarilyPopInterpreterStackCtxManagerVariable": return TemporarilyPopInterpreterStackCtxManagerVariable( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.saved = torch._C._functorch.pop_dynamic_layer_stack() self.set_cleanup_hook( tx, @@ -290,7 +306,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -309,10 +327,12 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a jvp # call from eager that calls the compiled function, as the jvp levels # may be different. 
- _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "JvpIncrementNestingCtxManagerVariable": var = JvpIncrementNestingCtxManagerVariable( target_values=None, initial_values=None, @@ -320,7 +340,7 @@ def create(tx: "InstructionTranslator", **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting() self.set_cleanup_hook( @@ -334,7 +354,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(jvp_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} @@ -346,14 +368,16 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable): """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "SetFwdGradEnabledContextManager": return SetFwdGradEnabledContextManager( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: [mode] = self.target_values self.prev_state = torch._C._is_fwd_grad_enabled() torch._C._set_fwd_grad_enabled(mode) @@ -369,7 +393,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -383,17 +409,17 @@ def exit(self, tx: "InstructionTranslator", *args): class DualLevelContextManager(ContextWrappingVariable): """Represents torch.autograd.forward_ad.dual_level ctx manager""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create(tx: "InstructionTranslator", **kwargs: Any) -> "DualLevelContextManager": return DualLevelContextManager( target_values=None, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) self.new_level = torch.autograd.forward_ad.enter_dual_level() self.set_cleanup_hook( @@ -407,7 +433,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(self.new_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -426,10 +454,12 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a grad # call from eager that calls the compiled function, as the grad levels # may be different. 
- _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "GradIncrementNestingCtxManagerVariable": var = GradIncrementNestingCtxManagerVariable( target_values=None, initial_values=None, @@ -437,7 +467,7 @@ def create(tx: "InstructionTranslator", **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) grad_level = torch._C._functorch._grad_increment_nesting() self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) @@ -449,7 +479,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(grad_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", torch._C._functorch._grad_decrement_nesting, (), {} @@ -461,19 +493,29 @@ class CatchWarningsCtxManagerVariable(ContextWrappingVariable): """Delay a call to warnings.catch_warnings""" @staticmethod - def create(tx: "InstructionTranslator", catch_warnings_args): + def create( + tx: "InstructionTranslator", catch_warnings_args: dict[str, VariableTracker] + ) -> "CatchWarningsCtxManagerVariable": return CatchWarningsCtxManagerVariable( catch_warnings_args=catch_warnings_args, target_values=None, initial_values=None, ) - def __init__(self, catch_warnings_args, **kwargs) -> None: + def __init__( + self, + catch_warnings_args: dict[str, VariableTracker], + target_values: Optional[Any] = None, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: assert isinstance(catch_warnings_args, dict), catch_warnings_args - super().__init__(**kwargs) + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) self.catch_warnings_args = catch_warnings_args - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: kwargs = { k: v.as_python_constant() for k, v in self.catch_warnings_args.items() } @@ -481,7 +523,7 @@ def enter(self, tx): self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None)) return variables.ConstantVariable.create(ctx_val.__enter__()) - def reconstruct(self, cg): + def reconstruct(self, cg: "PyCodegen") -> None: cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings")) cg.foreach(self.catch_warnings_args.values()) keys = tuple(self.catch_warnings_args.keys()) @@ -496,10 +538,14 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a vmap # call from eager that calls the compiled function, as the vmap levels # may be different. 
- _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", + target_values: Sequence[VariableTracker], + **kwargs: Any, + ) -> "VmapIncrementNestingCtxManagerVariable": var = VmapIncrementNestingCtxManagerVariable( target_values=target_values, initial_values=None, @@ -507,7 +553,7 @@ def create(tx: "InstructionTranslator", target_values, **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) batch_size, randomness = self.target_values if isinstance(batch_size, variables.SymNodeVariable): @@ -527,7 +573,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(vmap_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -541,10 +589,15 @@ def exit(self, tx: "InstructionTranslator", *args): class GradModeVariable(ContextWrappingVariable): """represents torch.{no_grad,enable_grad,set_grad_mode}()""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs): + def create( + tx: "InstructionTranslator", + target_value: Any, + initialized: bool = False, + **kwargs: Any, + ) -> "GradModeVariable": var = GradModeVariable( target_values=[target_value], initial_values=[torch.is_grad_enabled()], @@ -555,31 +608,37 @@ def create(tx: "InstructionTranslator", target_value, initialized=False, **kwarg return var def __init__( - self, target_values, initial_values=None, initialized=True, **kwargs + self, + target_values: Any, + initial_values: Optional[Sequence[bool]] = None, + initialized: bool = True, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self._call_func(tx, self.initial_values) return variables.ConstantVariable.create(None) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ): + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self._call_func(tx, self.initial_values) # undo eager initialization return super().call_function(tx, args, kwargs) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: assert len(values) == 1 value = values[0] # Coalesce grad mode mutations @@ -589,16 +648,18 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._set_grad_enabled(value) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "set_grad_enabled" class 
InferenceModeVariable(ContextWrappingVariable): @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: Any, **kwargs: Any + ) -> "InferenceModeVariable": var = InferenceModeVariable( [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs ) @@ -606,9 +667,9 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): def __init__( self, - target_values, - initial_values=None, - **kwargs, + target_values: Any, + initial_values: Optional[bool] = None, + **kwargs: Any, ) -> None: if initial_values is None: # This must be called here since function defaults are evaluated at import time @@ -616,9 +677,10 @@ def __init__( super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", @@ -626,8 +688,9 @@ def exit(self, tx: "InstructionTranslator", *args): (self.proxy,), {}, ) + return variables.ConstantVariable.create(None) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: disabled_inference_mode_forcibly = False if ( torch._dynamo.config.fake_tensor_disable_inference_mode @@ -642,7 +705,7 @@ def enter(self, tx): else: ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) - def cleanup_hook(): + def cleanup_hook() -> None: if disabled_inference_mode_forcibly: torch._C._set_grad_enabled(prior) else: @@ -655,11 +718,12 @@ def cleanup_hook(): (*self.target_values,), {}, ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "inference_mode" @@ -667,7 +731,9 @@ class CUDADeviceVariable(ContextWrappingVariable): """represents torch.cuda.device""" @staticmethod - def create(tx: "InstructionTranslator", device, **kwargs): + def create( + tx: "InstructionTranslator", device: Any, **kwargs: Any + ) -> "CUDADeviceVariable": var = CUDADeviceVariable( target_values=[torch.cuda._get_device_index(device, optional=True)], initial_values=None, @@ -677,16 +743,17 @@ def create(tx: "InstructionTranslator", device, **kwargs): def __init__( self, - target_values, - initial_values=None, - **kwargs, + target_values: Any, + initial_values: Optional[Any] = None, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", @@ -696,7 +763,7 @@ def exit(self, tx: "InstructionTranslator", *args): ) return variables.ConstantVariable.create(False) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: prev_idx = torch.cuda._exchange_device(*self.target_values) self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx)) self.proxy = tx.output.create_node( @@ -705,21 +772,24 @@ def enter(self, tx): (*self.target_values,), {}, ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.cuda" - def fn_name(self): + def fn_name(self) -> str: return "device" 
class TorchFunctionDisableVariable(ContextWrappingVariable): """represents whether torch function overrides are enabled or not""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "TorchFunctionDisableVariable": var = TorchFunctionDisableVariable( target_values=[], initial_values=[], @@ -728,10 +798,14 @@ def create(tx: "InstructionTranslator", **kwargs): return var def __init__( - self, target_values, initial_values=None, only_subclass=True, **kwargs + self, + target_values: Sized, + initial_values: Optional[Sized] = None, + only_subclass: bool = True, + **kwargs: Any, ) -> None: assert len(target_values) == 0 - assert len(initial_values) == 0 + assert initial_values is not None and len(initial_values) == 0 from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() @@ -748,10 +822,14 @@ def __init__( ) install_guard(self._guards_singleton) - def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): - if fn is None: + def set_cleanup_hook( + self, + tx: "InstructionTranslator", + cleanup_fn: Optional[Callable[..., Any]] = None, + ) -> None: + if cleanup_fn is None: - def fn(): + def cleanup_fn() -> None: tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( self.initial_torch_function_subclass_enabled ) @@ -760,19 +838,19 @@ def fn(): self.initial_torch_function_subclass_enabled ) - self.cleanup_fn = fn + self.cleanup_fn = cleanup_fn tx.output.add_cleanup_hook(self.cleanup) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sized) -> None: assert len(values) == 0 tx.symbolic_torch_function_state.torch_function_subclass_enabled = False if not self.only_subclass: tx.symbolic_torch_function_state.torch_function_mode_enabled = False - def module_name(self): + def module_name(self) -> str: return "torch._C" - def fn_name(self): + def fn_name(self) -> str: if self.only_subclass: return "DisableTorchFunctionSubclass" return "DisableTorchFunction" @@ -782,11 +860,14 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable): """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" _guards_singleton = Guard( - GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS + GlobalStateSource(), + GuardBuilder.DETERMINISTIC_ALGORITHMS, # type: ignore[arg-type] ) @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: bool, **kwargs: Any + ) -> "DeterministicAlgorithmsVariable": var = DeterministicAlgorithmsVariable( target_values=[target_value], initial_values=[torch.are_deterministic_algorithms_enabled()], @@ -796,16 +877,21 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): var.set_cleanup_hook(tx) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[bool], + initial_values: Optional[Sequence[bool]] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return 
variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: assert len(values) == 1 value = values[0] tx.output.create_node( @@ -813,10 +899,10 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._set_deterministic_algorithms(value) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "use_deterministic_algorithms" @@ -824,7 +910,9 @@ class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): """represents torch.autograd.graph.disable_saved_tensors_hook.""" @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: Optional[str], **kwargs: Any + ) -> "DisabledSavedTensorsHooksVariable": var = DisabledSavedTensorsHooksVariable( target_values=[target_value], initial_values=[ @@ -836,15 +924,22 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): var.set_cleanup_hook(tx) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[Optional[str]], + initial_values: Optional[Sequence[Optional[str]]] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func( + self, tx: "InstructionTranslator", values: Sequence[Optional[str]] + ) -> None: assert len(values) == 1 value = values[0] if value is not None: @@ -865,16 +960,20 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._autograd._saved_tensors_hooks_enable() - def module_name(self): + def module_name(self) -> str: return "torch.autograd.graph" - def fn_name(self): + def fn_name(self) -> str: return "disable_saved_tensors_hooks" class AutocastModeVariable(ContextWrappingVariable): @staticmethod - def create(func, args, kwargs): + def create( + func: torch.amp.autocast_mode.autocast, + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> "AutocastModeVariable": assert func in [ torch.amp.autocast_mode.autocast, torch.cuda.amp.autocast, @@ -905,30 +1004,37 @@ def create(func, args, kwargs): var = AutocastModeVariable(target_values, initial_values=None, **kwargs) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[Any], + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", torch.amp._exit_autocast, (self.proxy,), {} ) return variables.ConstantVariable.create(None) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: ctx = torch.amp._enter_autocast(*self.target_values) self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) self.proxy = tx.output.create_node( "call_function", torch.amp._enter_autocast, (*self.target_values,), {} ) + return variables.ConstantVariable.create(None) 
- def module_name(self): + def module_name(self) -> str: return "torch.amp.autocast_mode" - def fn_name(self): + def fn_name(self) -> str: return "autocast" @@ -937,20 +1043,22 @@ class NullContextVariable(ContextWrappingVariable): This class represents Python contextlib.nullcontext. """ - def __init__(self, target_values=None, **kwargs) -> None: + def __init__(self, target_values: Optional[Any] = None, **kwargs: Any) -> None: super().__init__(target_values=target_values, **kwargs) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: none = variables.ConstantVariable.create(None) return self.target_values if self.target_values else none - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "contextlib" - def fn_name(self): + def fn_name(self) -> str: return "nullcontext" @@ -963,22 +1071,24 @@ class ProfilerContextVariable(ContextWrappingVariable): than `None`, per implementation of the torch objects. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(target_values=None, **kwargs) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return self - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "contextlib" - def fn_name(self): + def fn_name(self) -> str: return "nullcontext" - def reconstruct(self, cg): + def reconstruct(self, cg: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.profiler object escaped from compiled region", context=str(self), @@ -995,27 +1105,37 @@ class PreserveVersionContextVariable(ContextWrappingVariable): """ @staticmethod - def _create_lambda_from_tensors(tx, tensors): + def _create_lambda_from_tensors( + tx: "InstructionTranslator", + tensors: VariableTracker, + ) -> "PreserveVersionContextVariable": if isinstance(tensors, variables.TensorVariable): versions = variables.TupleVariable( [x.var_getattr(tx, "_version") for x in [tensors]] ) - tensors = variables.TupleVariable([tensors]) + tensors_tuple = variables.TupleVariable([tensors]) else: + assert isinstance(tensors, variables.TupleVariable) versions = variables.TupleVariable( [x.var_getattr(tx, "_version") for x in tensors.items] ) - return PreserveVersionContextVariable(tensors, versions) + tensors_tuple = tensors + return PreserveVersionContextVariable(tensors_tuple, versions) @staticmethod - def constructor(tx): + def constructor(tx: "InstructionTranslator") -> VariableTracker: return variables.LambdaVariable( lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors( tx, tensors ) ) - def __init__(self, tensors, prev_versions, **kwargs) -> None: + def __init__( + self, + tensors: VariableTracker, + prev_versions: VariableTracker, + **kwargs: Any, + ) -> None: kwargs.setdefault("target_values", None) super().__init__(**kwargs) self.tensors = tensors @@ -1028,17 +1148,19 @@ def __init__(self, tensors, prev_versions, **kwargs) -> None: ): self.prev_versions = variables.TupleVariable([self.prev_versions]) - def enter(self, tx): - pass + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) - 
def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: from ..tensor_version_op import _unsafe_set_version_counter return variables.TorchInGraphFunctionVariable( _unsafe_set_version_counter ).call_function(tx, [self.tensors, self.prev_versions], {}) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), @@ -1053,10 +1175,15 @@ def reconstruct(self, codegen: "PyCodegen"): class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs): + def create( + tx: "InstructionTranslator", + param_group_var: Any, + target_value: Any, + **kwargs: Any, + ) -> "FSDPParamGroupUseTrainingStateVariable": var = FSDPParamGroupUseTrainingStateVariable( param_group_var=param_group_var, target_values=[target_value], @@ -1066,7 +1193,11 @@ def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs) return var def __init__( - self, param_group_var, target_values, initial_values=None, **kwargs + self, + param_group_var: Any, + target_values: Sequence[Any], + initial_values: Optional[Sequence[Any]] = None, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs @@ -1074,24 +1205,27 @@ def __init__( self.param_group_var = param_group_var install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): - self._call_func(tx, self.initial_values) + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # type: ignore[arg-type] return variables.ConstantVariable.create(None) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ): - self._call_func(tx, self.initial_values) # undo eager initialization + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # undo eager initialization + self._call_func(tx, self.initial_values) # type: ignore[arg-type] return super().call_function(tx, args, kwargs) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[Any]) -> None: assert len(values) == 1 value = values[0] if self.param_group_var.value._training_state != value: @@ -1106,10 +1240,10 @@ def _call_func(self, tx: "InstructionTranslator", values): ) self.param_group_var.value._training_state = value - def module_name(self): + def module_name(self) -> str: return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup" - def fn_name(self): + def fn_name(self) -> str: return "use_training_state" @@ -1117,7 +1251,12 @@ class SDPAKernelVariable(ContextWrappingVariable): """represents torch.nn.attention.sdpa_kernel""" @staticmethod - def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): + def 
create( + tx: "InstructionTranslator", + backends: Any, + set_priority: bool = False, + **kwargs: Any, + ) -> "SDPAKernelVariable": if isinstance(backends, torch.nn.attention.SDPBackend): backends = [backends] var = SDPAKernelVariable( @@ -1131,9 +1270,9 @@ def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): def __init__( self, target_values: list[torch.nn.attention.SDPBackend], - initial_values=None, + initial_values: Any = None, set_priority: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs @@ -1141,7 +1280,10 @@ def __init__( self.set_priority = set_priority @staticmethod - def _backends_to_nodes(tx, backends): + def _backends_to_nodes( + tx: "InstructionTranslator", + backends: list[Any], + ) -> list[Any]: # convert to/from string in order to bake the backend into FX graph nodes = [ tx.output.create_node( @@ -1154,7 +1296,7 @@ def _backends_to_nodes(tx, backends): ] return nodes - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends( with_priority=self.set_priority ) @@ -1176,7 +1318,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() arg = self._backends_to_nodes(tx, self.prev_backends) tx.output.create_node( @@ -1187,12 +1331,12 @@ def exit(self, tx: "InstructionTranslator", *args): ) return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.nn.attention" # use a private version of sdpa_kernel that accepts variadic arguments # since dynamo reconstructs the contents of target_values one-by-one - def fn_name(self): + def fn_name(self) -> str: return "_sdpa_kernel_variadic" @@ -1206,12 +1350,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable): __exit__ method (instead of tracing). """ - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, target_values: Any, initial_values: Any = None, **kwargs: Any + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - def enter(self, tx, *args): + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # Run the annotation ctx manager in eager. Also ensure that # preserve_node_meta context manager is setup. This is important to pass # on the metadata to the create_proxy nodes. 
@@ -1221,13 +1369,13 @@ def enter(self, tx, *args): self.set_cleanup_hook(tx, lambda: stack.close()) return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.fx.traceback" - def fn_name(self): + def fn_name(self) -> str: return "annotate" - def reconstruct_type(self, codegen: "PyCodegen"): + def reconstruct_type(self, codegen: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.fx.traceback.annotate escaped from compiled region", context=str(self), @@ -1243,50 +1391,52 @@ class DynamoConfigPatchVariable(ContextWrappingVariable): # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness # (though it may affect tracing behavior) - def __init__(self, target_values, **kwargs) -> None: - target_values = tuple(target_values.items()) - super().__init__(target_values=(target_values,), initial_values=None, **kwargs) - self.initial_values = {} - for key, _ in target_values: - self.initial_values[key] = torch._dynamo.config.__getattr__(key) - self.initial_values = (tuple(self.initial_values.items()),) - - def _call_func(self, tx: "InstructionTranslator", values): + def __init__(self, target_values: dict[str, Any], **kwargs: Any) -> None: + target_values_tuple = tuple(target_values.items()) + super().__init__( + target_values=(target_values_tuple,), initial_values=None, **kwargs + ) + initial_values_dict = {} + for key, _ in target_values_tuple: + initial_values_dict[key] = torch._dynamo.config.__getattr__(key) # type: ignore[attr-defined] + self.initial_values = (tuple(initial_values_dict.items()),) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: assert len(values) == 1 value = values[0] # manually patch dynamo config for key, val in value: - torch._dynamo.config.__setattr__(key, val) + torch._dynamo.config.__setattr__(key, val) # type: ignore[attr-defined] # No need to keep track of global side effects because # dynamo will properly restore this context manager for # unsupported instructions and continuation functions. # Dynamo config also should not affect the semantics of the compiled graph. 
- def module_name(self): + def module_name(self) -> str: return "torch._dynamo" - def fn_name(self): + def fn_name(self) -> str: return "patch_dynamo_config" class ErrorOnGraphBreakVariable(ContextWrappingVariable): """represents torch._dynamo.error_on_graph_break""" - def __init__(self, error_on_graph_break, **kwargs) -> None: + def __init__(self, error_on_graph_break: bool, **kwargs: Any) -> None: super().__init__( target_values=(error_on_graph_break,), initial_values=(_get_error_on_graph_break(),), **kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: assert len(values) == 1 _set_error_on_graph_break(values[0]) - def module_name(self): + def module_name(self) -> str: return "torch._dynamo" - def fn_name(self): + def fn_name(self) -> str: return "error_on_graph_break" @@ -1294,7 +1444,7 @@ class WithEnterFunctionVariable(VariableTracker): def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.ctx = ctx @@ -1302,16 +1452,17 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert not args assert not kwargs # NOTE: we assume that the instruction immediately after the current CALL instruction # is the first instruction of the block. + # pyrefly: ignore [bad-argument-type] return tx.enter_ctx(self.ctx, tx.current_instruction) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: try: type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}" except NotImplementedError: @@ -1339,8 +1490,8 @@ class WithExitFunctionVariable(VariableTracker): def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], - target, - **kwargs, + target: Any, + **kwargs: Any, ) -> None: super().__init__(**kwargs) assert isinstance( @@ -1352,27 +1503,29 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert not kwargs return self.ctx.exit(tx, *args) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. 
- self.ctx.reconstruct_type(codegen) + self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined] if codegen.tx.output.partial_convert: if sys.version_info >= (3, 11): codegen.append_output(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): codegen.append_output(create_instruction("SWAP", arg=2)) + # We rely on classes subtyping `GenericContextWrappingVariable` + # to implement these fns and have these attributes codegen.extend_output( - [codegen.create_load_const(val) for val in self.ctx.target_values] + [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type] ) codegen.extend_output( - create_call_function(len(self.ctx.target_values), False) + create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type] ) codegen.append_output(create_setup_with(self.target)) codegen.append_output(create_instruction("POP_TOP")) diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fbc0eed3a99ff..c353181eb8029 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -116,11 +116,7 @@ def create( **kwargs, ) - def __init__( - self, - stream: Optional["StreamVariable"], - **kwargs: dict[str, Any], - ) -> None: + def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None: self.stream = stream super().__init__( target_values={"stream": self.get_stream().user_object_index}, @@ -129,14 +125,16 @@ def __init__( ) def enter( - self, tx: "InstructionTranslator", *args: tuple[Any] - ) -> "VariableTracker": + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # to stream, from stream is the order of the arguments # we are entering the target, and leaving the initial stream tx.symbolic_stream_state.enter_stream(self.get_stream()) return super().enter(tx) - def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker": + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # to stream, from stream is the order of the arguments # we are leaving the target, and entering the initial stream tx.symbolic_stream_state.exit_stream() @@ -182,7 +180,7 @@ def call_method( name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: assert hasattr(self.value, name), f"no stream method found named {name}" from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c2e3df8e4adce..be28fe9269f44 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -408,6 +408,7 @@ def call_function( torch.cuda.amp.autocast, torch.cpu.amp.autocast, ): + # pyrefly: ignore [bad-argument-type] return AutocastModeVariable.create(self.value, args, kwargs) elif self.value in ( # NOTE any class added here must align with the semantic diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 378e9258459f5..fa8412146a427 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -164,7 +164,8 @@ def __init__( if value is not None: super().__init__(value, **kwargs) self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code + # needed for BC with calling enter from CM code + self.cm_obj = value # type: ignore[assignment] self.source = source # type: ignore[assignment] def reconstruct(self, codegen: 
"PyCodegen") -> None: From 4271ffe91849335ffbcc2014c948694f8ec107fd Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 5 Nov 2025 00:20:24 +0000 Subject: [PATCH 037/651] don't produce invalid grid configs (#166974) Proper fix for #164048, fixes gather too, reverts #164049 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166974 Approved by: https://github.com/eqy --- aten/src/ATen/native/cuda/IndexKernel.cu | 15 +++------------ aten/src/ATen/native/cuda/IndexKernelUtils.cu | 15 +++++++++------ test/test_cuda.py | 7 ------- test/test_scatter_gather_ops.py | 8 ++++++-- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 927af661396cd..db85f62c8d124 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -74,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co char* const out_ptr = static_cast(iter.data_ptr(0)); char* const in_ptr = static_cast(iter.data_ptr(1)); - if (is_gather_like && num_indices==1) { const size_t element_size = iter.element_size(0); constexpr size_t alignment = 16; @@ -84,16 +82,9 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co auto ind_dim_size = index_size[0]; auto inp_stride_bytes = index_stride[0]; auto out_stride_bytes = iter.strides(0)[1]; - // avoid grid overflow in the fast kernel - const int64_t vec_chunks = ceil_div(slice_size, alignment); - const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd); - const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - // if it's an eligible grid we use the fast path, otherwise default to slower path - if (blocks_per_slice_upper <= max_grid_y) { - at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind, - slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true); - return; - } + at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind, + slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true); + return; } } diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.cu b/aten/src/ATen/native/cuda/IndexKernelUtils.cu index 8343c60418952..1e998251dd7be 100644 --- a/aten/src/ATen/native/cuda/IndexKernelUtils.cu +++ b/aten/src/ATen/native/cuda/IndexKernelUtils.cu @@ -13,11 +13,12 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, if (allow_neg_indices) { ind = (ind < 0) ? 
ind + ind_dim_size : ind; } - CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind); - int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits - if (off >= slice_size) return; - auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); - at::native::memory::st_vec(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits + CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"); + // off is guaranteed to be within int32 limits + for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) { + auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); + at::native::memory::st_vec(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits + } } @@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int auto num_threads = at::round_up( at::ceil_div(slice_size_in_bytes, Alignment), static_cast(C10_WARP_SIZE)); - dim3 grid = {static_cast(num_ind), static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1}; + uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + grid_y = std::min(static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y); + dim3 grid = {static_cast(num_ind), grid_y, 1}; auto block = std::min(max_num_threads, num_threads); vectorized_gather_kernel<<>>(out, inp, idx, num_ind, slice_size_in_bytes, ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices); diff --git a/test/test_cuda.py b/test/test_cuda.py index 00c3b00d6049c..329261fba7d3a 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1474,13 +1474,6 @@ def test_huge_index(self): res_cpu = src.cpu()[idx.cpu()] self.assertEqual(res.cpu(), res_cpu) - def test_fast_index_overflow(self): - src = torch.randint(0, 20, (4, 87, 1056, 736), device="cuda") - indices = torch.tensor([True, False, False, True], device="cuda") - res = src[indices] - res_cpu = src.cpu()[indices.cpu()] - self.assertEqual(res.cpu(), res_cpu) - def test_randint_randomness_for_large_range(self) -> None: # For large ranges, randint generation is slightly different. This lead to a subtle bug where some Philox # offsets were not calculated correctly, resulting in reused random states. 
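For context on the kernel change above: instead of bailing to a slower path when one slice needs more blocks than `maxGridSize[1]`, the launch now clamps the grid's y dimension to the device limit and the kernel grid-strides over any remainder. A rough sketch of that launch math in Python, illustrative only; the thread count and the 65535 grid limit stand in for the launcher's `max_num_threads` and the device's `maxGridSize[1]`:

```python
# Illustrative sketch of the clamping in vectorized_gather_kernel_launch (not part of the patch)
def ceil_div(a, b):
    return (a + b - 1) // b

def gather_grid(num_indices, slice_size_bytes, alignment=16,
                max_threads=128,      # stand-in for max_num_threads
                max_grid_y=65535):    # assumed maxGridSize[1]
    # blocks needed to cover one slice if each thread handles `alignment` bytes
    wanted_y = ceil_div(slice_size_bytes, max_threads * alignment)
    grid_y = min(wanted_y, max_grid_y)  # clamp instead of falling back
    # when clamped, threads stride by block_size * grid_y * alignment bytes
    # inside the kernel loop until the whole slice is covered
    return (num_indices, grid_y, 1)

print(gather_grid(num_indices=4, slice_size_bytes=87 * 1056 * 736 * 8))
```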
diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index ba967c142f1e7..96768f34affb0 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -6,7 +6,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM) + (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM, serialTest) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, toleranceOverride, tol,) @@ -65,10 +65,12 @@ def test_gather(self, device, dtype): actual = torch.gather(src, 2, idx) self.assertEqual(actual, expected, atol=0, rtol=0) + @serialTest() @dtypes(torch.int8, torch.bfloat16) def test_gather_large(self, device, dtype): # test larger shapes to check vectorized implementation - for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)): + for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100), (4, 4, 16384 * 8192)): + torch.cuda.empty_cache() src = make_tensor((m, k), device=device, dtype=dtype) alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype) discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src) @@ -111,6 +113,8 @@ def test_gather_large(self, device, dtype): self.assertEqual(res_ind, ref, atol=0, rtol=0) res_gather = torch.gather(misaligned1, dim=dim, index=ind) self.assertEqual(res_gather, ref, atol=0, rtol=0) + del src, alloc0, alloc1, alloc2 + del discontig, misaligned, misaligned1 # test gather along 1st dim that can accidentally trigger fast path # because due to index dimension in the gather dim being 1 # an unexpected squashing in tensorIterator happens From f2fbc81c506d4497d1505a7c27d949e4fc4da8d6 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 4 Nov 2025 13:58:39 -0800 Subject: [PATCH 038/651] [RFC] Add experimental Pallas TorchInductor backend (#166822) Very simple Pallas TorchInductor backend Given ``` import torch def f(x, y): return x.sin() + y torch._inductor.config.cuda_backend="pallas" x = torch.randn(4).cuda() y = torch.randn(4).cuda() compiled = torch.compile(f, backend="inductor", fullgraph=True) torch.testing.assert_close(compiled(x, y), f(x, y)) ``` it outputs ``` import torch import jax import jax.numpy as jnp from jax.experimental import pallas as pl from torch.utils import dlpack as torch_dlpack def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0): tmp0 = in_ptr0[...] tmp1 = jnp.sin(tmp0) tmp2 = in_ptr1[...] tmp3 = tmp1 + tmp2 out_ptr0[...] 
= tmp3 def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None): # Convert Torch -> JAX for inputs in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0)) in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1)) # Prepare output spec from PyTorch tensor # Map PyTorch dtype to JAX dtype string _torch_dtype_to_jax = { torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16, torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8, torch.uint8: jnp.uint8, torch.bool: jnp.bool_, } out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype]) compiled = pl.pallas_call( lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs), out_shape=out_spec, grid=(1,), ) res = compiled(in_ptr0_jax, in_ptr1_jax) # Copy result back into the provided torch output tensor res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res)) out_ptr0.copy_(res_t) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166822 Approved by: https://github.com/jansel ghstack dependencies: #166976, #166982 --- test/inductor/test_pallas.py | 354 ++++++++++++++++++ torch/_inductor/async_compile.py | 36 ++ torch/_inductor/codegen/common.py | 2 + torch/_inductor/codegen/pallas.py | 424 ++++++++++++++++++++++ torch/_inductor/config.py | 5 +- torch/testing/_internal/inductor_utils.py | 9 +- torch/utils/_pallas.py | 82 +++++ 7 files changed, 907 insertions(+), 5 deletions(-) create mode 100644 test/inductor/test_pallas.py create mode 100644 torch/_inductor/codegen/pallas.py create mode 100644 torch/utils/_pallas.py diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py new file mode 100644 index 0000000000000..2d4e6af002ab0 --- /dev/null +++ b/test/inductor/test_pallas.py @@ -0,0 +1,354 @@ +# Owner(s): ["oncall: pt2"] +import functools +import sys +import unittest + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.testing import make_test_cls_with_patches +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS +from torch.testing._internal.inductor_utils import HAS_PALLAS +from torch.utils._triton import has_triton + + +if IS_WINDOWS and IS_CI: + sys.stderr.write( + "Windows CI does not have necessary dependencies for test_torchinductor yet\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + + +try: + from . 
import test_torchinductor +except ImportError: + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library + + +test_classes = {} + + +def make_pallas(cls): + """Create a test class variant that uses Pallas backend.""" + suffix = "_pallas" + cls_prefix = "Pallas" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "cuda_backend", "pallas"), + xfail_prop="_expected_failure_pallas", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +class PallasTests(TestCase): + """Basic tests for Pallas backend functionality.""" + + def test_simple_add(self): + """Test basic element-wise addition.""" + + def fn(a, b): + return a + b + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + def test_simple_mul(self): + """Test basic element-wise multiplication.""" + + def fn(a, b): + return a * b + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + def test_sin(self): + """Test sin operation.""" + + def fn(x): + return torch.sin(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_fused_ops(self): + """Test fused operations (sin + add).""" + + def fn(x, y): + return x.sin() + y + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + y = torch.randn(1024, device="cuda") + result = compiled(x, y) + expected = fn(x, y) + self.assertEqual(result, expected) + + def test_exp_log(self): + """Test exp and log operations.""" + + def fn(x): + return torch.log(torch.exp(x)) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_sqrt(self): + """Test sqrt operation.""" + + def fn(x): + return torch.sqrt(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_tanh(self): + """Test tanh operation.""" + + def fn(x): + return torch.tanh(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_abs_neg(self): + """Test abs and neg operations.""" + + def fn(x): + return torch.abs(-x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_maximum_minimum(self): + 
"""Test maximum and minimum operations.""" + + def fn(a, b): + return torch.maximum(a, b) + torch.minimum(a, b) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + @unittest.skipUnless(has_triton(), "requires triton") + @unittest.skip("Random ops not yet implemented in Pallas backend") + def test_random_consistency(self): + """Test that random number generation is consistent across backends.""" + seed = 1234 + shape = (3, 3) + dtype = torch.float32 + + for rand_fn in [ + functools.partial(torch.rand, shape, dtype=dtype, device="cuda"), + functools.partial(torch.randn, shape, dtype=dtype, device="cuda"), + ]: + + @torch.compile(backend="inductor", options={"cuda_backend": "pallas"}) + def get_rand_pallas(): + return rand_fn() + + @torch.compile(backend="inductor", options={"cuda_backend": "triton"}) + def get_rand_triton(): + return rand_fn() + + torch.manual_seed(seed) + pallas_output = get_rand_pallas() + torch.manual_seed(seed) + triton_output = get_rand_triton() + + self.assertEqual(pallas_output, triton_output) + + def test_compile_options(self): + """Test that Pallas backend is properly configured.""" + + @torch.compile( + backend="inductor", + options={"cuda_backend": "pallas"}, + ) + def pallas_fn(a, b): + return a.sin() + b.cos() + + _, (code,) = run_and_get_code( + pallas_fn, + torch.randn(64, device="cuda"), + torch.randn(64, device="cuda"), + ) + # Verify Pallas-specific code generation + self.assertIn("import jax", code) + self.assertIn("import jax.numpy as jnp", code) + self.assertIn("from jax.experimental import pallas as pl", code) + + def test_2d_tensor(self): + """Test with 2D tensors (though current implementation flattens).""" + + def fn(x, y): + return x + y + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(32, 32, device="cuda") + y = torch.randn(32, 32, device="cuda") + result = compiled(x, y) + expected = fn(x, y) + self.assertEqual(result, expected) + + def test_different_shapes(self): + """Test with different tensor shapes.""" + + def fn(x): + return x * 2.0 + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + for shape in [(64,), (128,), (256,), (1024,)]: + x = torch.randn(shape, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_contiguous_index_validation(self): + """Test that contiguous index validation works correctly end-to-end.""" + + # Test 1: Contiguous operations should work + def contiguous_add(a, b): + return a + b + + compiled = torch.compile( + contiguous_add, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = contiguous_add(a, b) + self.assertEqual(result, expected) + + # Test 2: Operations on contiguous tensors should work + def contiguous_mul(x): + return x * 2.0 + + compiled = torch.compile( + contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(128, 8, device="cuda") + result = compiled(x) + expected = contiguous_mul(x) + self.assertEqual(result, expected) + + # Test 3: Non-contiguous views will fail at runtime with JAX/Pallas + # This demonstrates that the Pallas backend requires contiguous memory layout + 
def operate_on_tensor(x): + return x.sin() + + compiled = torch.compile( + operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"} + ) + + # Create a transposed (non-contiguous) view + x = torch.randn(64, 32, device="cuda") + x_t = x.t() # Non-contiguous view + self.assertFalse(x_t.is_contiguous()) + + # This will fail because JAX/Pallas cannot handle non-contiguous layout via DLPack + # The error indicates that our contiguous-only approach is correct + with self.assertRaises((RuntimeError, Exception)) as cm: + result = compiled(x_t) + + # Verify the error is related to layout/contiguous issues + error_msg = str(cm.exception) + self.assertTrue( + "layout" in error_msg.lower() + or "contiguous" in error_msg.lower() + or "non-default" in error_msg.lower(), + f"Expected layout/contiguous error, got: {error_msg}", + ) + + # But if we make it contiguous first, it should work + x_t_contiguous = x_t.contiguous() + self.assertTrue(x_t_contiguous.is_contiguous()) + result = compiled(x_t_contiguous) + expected = operate_on_tensor(x_t_contiguous) + self.assertEqual(result, expected) + + +# Create test variants using the main test suite +# Note: Only enable GPU tests since Pallas primarily targets GPU +if test_torchinductor.HAS_GPU and HAS_PALLAS: + # Uncomment these to run full test suite with Pallas backend + # make_pallas(test_torchinductor.SweepInputsGPUTest) + # make_pallas(test_torchinductor.GPUTests) + pass + +if __name__ == "__main__": + if HAS_PALLAS: + run_tests(needs="filelock") diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index a2c80002eb928..5ede0cd085010 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -601,6 +601,42 @@ def task(): future = self.submit(task) return LambdaFuture(lambda: future.result()) + def pallas(self, kernel_name: str, source_code: str): + """ + Compile Pallas (JAX experimental) kernels. + + Args: + kernel_name: Name of the kernel to be defined + source_code: Source code of the Pallas kernel, as a string + + Note: + Pallas kernels are Python code that uses JAX and Pallas APIs. + We use the PyCodeCache to write the source code to a file and load it. + """ + from torch._inductor.codegen.pallas import MAIN_SUFFIX, PallasKernelWrapper + + kernel_code_log.info("Pallas Kernel:\n%s", source_code) + + def task(): + key, path = torch._inductor.codecache.PyCodeCache.write(source_code) + mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) + + # Find our special entry point named function + main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find Pallas main kernel function '{main_func_name}'. 
Available callables: {available}" + ) + + return PallasKernelWrapper(getattr(mod, main_func_name), kernel_path=path) + + if get_compile_threads() <= 1: + return task() + else: + future = self.submit(task) + return LambdaFuture(lambda: future.result()) + def wait(self, scope: dict[str, Any]) -> None: if get_compile_threads() > 1: with dynamo_timed( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e6a5c5e8ec176..730c03f1c813c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -510,6 +510,7 @@ def init_backend_registration() -> None: from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .mps import MetalScheduling + from .pallas import PallasScheduling from .python_wrapper_mtia import PythonWrapperMtia from .triton import TritonScheduling from .wrapper import PythonWrapperCodegen @@ -536,6 +537,7 @@ def init_backend_registration() -> None: cuda_backends = { "triton": CUDACombinedScheduling, "halide": HalideScheduling, + "pallas": PallasScheduling, } register_backend_for_device( "cuda", diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py new file mode 100644 index 0000000000000..1fc8e40724bc0 --- /dev/null +++ b/torch/_inductor/codegen/pallas.py @@ -0,0 +1,424 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import hashlib +from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING + +import sympy # noqa: TC002 + +import torch # noqa: TC001 +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..utils import get_fused_kernel_name, get_kernel_metadata +from ..virtualized import V +from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides +from .simd import SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from ..ir import IRNode + from ..scheduler import BaseSchedulerNode + + +# Main function suffix used in generated Pallas code +MAIN_SUFFIX = "main" + +# Logger for Pallas kernel code +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class PallasKernelWrapper: + """Wrapper to provide .run() interface for Pallas kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("Pallas kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the Pallas kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +class Unsupported(RuntimeError): + """Exception raised when an operation is not supported by the Pallas backend.""" + + +class PallasKernelOverrides(OpOverrides): + """ + Map element-wise ops to JAX/Pallas operations. + + For now, we use the default Python operators which are compatible + with JAX numpy broadcasting semantics. 
+ """ + + @staticmethod + def sin(x: str) -> str: + return f"jnp.sin({x})" + + @staticmethod + def cos(x: str) -> str: + return f"jnp.cos({x})" + + @staticmethod + def tan(x: str) -> str: + return f"jnp.tan({x})" + + @staticmethod + def sinh(x: str) -> str: + return f"jnp.sinh({x})" + + @staticmethod + def cosh(x: str) -> str: + return f"jnp.cosh({x})" + + @staticmethod + def tanh(x: str) -> str: + return f"jnp.tanh({x})" + + @staticmethod + def asin(x: str) -> str: + return f"jnp.arcsin({x})" + + @staticmethod + def acos(x: str) -> str: + return f"jnp.arccos({x})" + + @staticmethod + def atan(x: str) -> str: + return f"jnp.arctan({x})" + + @staticmethod + def exp(x: str) -> str: + return f"jnp.exp({x})" + + @staticmethod + def exp2(x: str) -> str: + return f"jnp.exp2({x})" + + @staticmethod + def expm1(x: str) -> str: + return f"jnp.expm1({x})" + + @staticmethod + def log(x: str) -> str: + return f"jnp.log({x})" + + @staticmethod + def log10(x: str) -> str: + return f"jnp.log10({x})" + + @staticmethod + def log2(x: str) -> str: + return f"jnp.log2({x})" + + @staticmethod + def log1p(x: str) -> str: + return f"jnp.log1p({x})" + + @staticmethod + def sqrt(x: str) -> str: + return f"jnp.sqrt({x})" + + @staticmethod + def rsqrt(x: str) -> str: + return f"(1.0 / jnp.sqrt({x}))" + + @staticmethod + def abs(x: str) -> str: + return f"jnp.abs({x})" + + @staticmethod + def neg(x: str) -> str: + return f"(-{x})" + + @staticmethod + def floor(x: str) -> str: + return f"jnp.floor({x})" + + @staticmethod + def ceil(x: str) -> str: + return f"jnp.ceil({x})" + + @staticmethod + def trunc(x: str) -> str: + return f"jnp.trunc({x})" + + @staticmethod + def round(x: str) -> str: + return f"jnp.round({x})" + + @staticmethod + def sigmoid(x: str) -> str: + return f"(1.0 / (1.0 + jnp.exp(-{x})))" + + @staticmethod + def relu(x: str) -> str: + return f"jnp.maximum({x}, 0)" + + @staticmethod + def pow(a: str, b: str) -> str: + return f"jnp.power({a}, {b})" + + @staticmethod + def maximum(a: str, b: str) -> str: + return f"jnp.maximum({a}, {b})" + + @staticmethod + def minimum(a: str, b: str) -> str: + return f"jnp.minimum({a}, {b})" + + @staticmethod + def where(cond: str, a: str, b: str) -> str: + return f"jnp.where({cond}, {a}, {b})" + + +class PallasKernel(SIMDKernel): + """ + Minimal Pallas kernel for simple elementwise operations. + + Strategy: + - Treat loads as full-array refs: "in_ptrX[...]" + - Compute expression with Python operators (compatible with jax.numpy broadcasting) + - Store as full-array ref assignment: "out_ptrY[...] = " + - Generate Python code that defines a Pallas kernel and a host entrypoint. + - Use async_compile.pallas path to compile and load Python code. + """ + + overrides = PallasKernelOverrides # type: ignore[assignment] + + def _get_contiguous_index_str(self, index: sympy.Expr) -> str: + """ + Validate that the index represents contiguous access and return the indexing string. + + For Pallas, we only support simple contiguous access patterns where the index + is a single symbol (e.g., xindex) representing a flattened iteration. + This ensures the load/store order is contiguous. 
+ + Args: + index: The indexing expression to validate + + Returns: + The indexing string to use (currently always "...") + + Raises: + Unsupported: If the index is not a simple contiguous pattern + """ + # Prepare and simplify the index + prepared_index = self.prepare_indexing(index) + + # For contiguous access, we expect a single symbol (like xindex) + # or a simple integer (for scalar operations) + if isinstance(prepared_index, sympy.Symbol): + # This is the expected case: a single symbol representing contiguous iteration + return "..." + elif prepared_index.is_Integer: + # Scalar case + return "..." + else: + # If there's any complex expression (ModularIndexing, FloorDiv, etc.), + # it's not a simple contiguous pattern + raise Unsupported( + f"Pallas backend only supports contiguous access patterns. " + f"Got complex index: {prepared_index}" + ) + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: # type: ignore[override] + buf = self.args.input(name) + dtype = V.graph.get_dtype(name) + # Validate contiguous access and get index string + index_str = self._get_contiguous_index_str(index) + # Pallas refs must be unpacked with [...] to load the array + return self.cse.generate( + self.compute, + f"{buf}[{index_str}]", + dtype=dtype, + ) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: Any = None + ) -> None: # type: ignore[override] + if mode is not None: + raise Unsupported("pallas store mode not supported") + out = self.args.output(name) + self.store_buffer_names.add(name) + # Validate contiguous access and get index string + index_str = self._get_contiguous_index_str(index) + # Pallas refs must use [...] assignment to store back to the ref + self.stores.writeline(f"{out}[{index_str}] = {value}") + + def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[override] + """ + Generate the complete Pallas kernel code as a Python string. + + This includes: + - Import statements for JAX/Pallas + - The kernel function that operates on refs + - The main wrapper function that handles PyTorch<->JAX conversions via DLPack + + Args: + name: Optional kernel name (will use placeholder if not provided) + + Returns: + str: Complete Python source code for the Pallas kernel + """ + # Ensure one (1) output for now + live_outs = list(self.args.live_output_buffers()) + if len(live_outs) != 1: + raise Unsupported( + "Pallas backend currently supports single-output elementwise kernels only" + ) + + code = IndentedBuffer() + code.splice( + """ + import torch + import jax + import jax.numpy as jnp + from jax.experimental import pallas as pl + from torch.utils import dlpack as torch_dlpack + """, + strip=True, + ) + + # Define the Pallas kernel: accepts refs, uses broadcasted expressions + arg_defs, _, _, _ = self.args.python_argdefs() + # Order: inputs (in_ptr*), then outputs (out_ptr*), then sizes/workspaces + kernel_params = [a.name for a in arg_defs] + + kernel_name = name or "" + code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):") + with code.indent(): + # Emit compute (CSE) and store lines; they reference *_ptr[...] 
directly + for line in self.compute._lines: + code.writeline(str(line)) + for line in self.stores._lines: + code.writeline(str(line)) + + # Host entry: convert torch tensors <-> jax, call pallas_call and copy back + main_name = f"{kernel_name}_main" + code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):") + with code.indent(): + # Identify inputs (in_ptr*) and output (out_ptr*) + input_params = [ + p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr")) + ] + output_params = [p for p in kernel_params if p.startswith("out_ptr")] + + if len(output_params) != 1: + raise RuntimeError( + f"Expected exactly 1 output, got {len(output_params)}" + ) + + output_param = output_params[0] + + # Convert inputs to JAX arrays + code.writeline("# Convert Torch -> JAX for inputs") + for inp in input_params: + code.writeline( + f"{inp}_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack({inp}))" + ) + + # Get output spec from PyTorch tensor + code.writeline("# Prepare output spec from PyTorch tensor") + code.writeline("# Map PyTorch dtype to JAX dtype string") + code.writeline("_torch_dtype_to_jax = {") + code.writeline( + " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16," + ) + code.writeline( + " torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8," + ) + code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,") + code.writeline("}") + code.writeline( + f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])" + ) + + # Call pallas + code.writeline("compiled = pl.pallas_call(") + code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),") + code.writeline(" out_shape=out_spec,") + code.writeline(" grid=(1,),") + code.writeline(")") + + jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params]) + code.writeline(f"res = compiled({jax_input_args})") + + # Copy result back + code.writeline("# Copy result back into the provided torch output tensor") + code.writeline( + "res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))" + ) + code.writeline(f"{output_param}.copy_(res_t)") + + return code.getvalue() + + def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: # type: ignore[override] + """Generate the Python code that calls this Pallas kernel.""" + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + + # Generate kernel call: kernel_name.run(arg1, arg2, ...) 
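+        # call_args from python_argdefs() match the generated entry point's
+        # parameter order (in_ptr* inputs first, then out_ptr*); no stream is
+        # forwarded here, so the entry point's stream=None default applies.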
+ # Note: async_compile.pallas loads {name}_main function and wraps it in PallasKernelWrapper + # which exposes a run() method + kernel_call = f"{name}.run({', '.join(map(str, call_args))})" + wrapper.writeline(kernel_call) + + +class PallasScheduling(SIMDScheduling): + kernel_type = PallasKernel # type: ignore[assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + # Start minimal: no special features advertised + return OrderedSet() + + def define_kernel( + self, + src_code: str, + node_schedule: Sequence[BaseSchedulerNode], + kernel: PallasKernel, + ) -> str: # type: ignore[override] + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + return wrapper.src_to_kernel[src_code] + + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"pallas_{kernel_hash}" + else: + kernel_name = f"pallas_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + + # Replace placeholder if any + src_code = src_code.replace("", kernel_name) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.pallas({kernel_name!r}, r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment) + + return kernel_name diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 457f86fe7a77e..66eaf69dd59a8 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1950,8 +1950,9 @@ class rocm: # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" -# Backend to use for CUDA codegen either "triton" or "halide" (experimental) -cuda_backend: Literal["triton", "halide"] = "triton" +# Backend to use for CUDA codegen either +# "triton", "halide" (experimental) or "pallas" (experimental) +cuda_backend: Literal["triton", "halide", "pallas"] = "triton" # Backend to use for XPU codegen either "triton" xpu_backend: Literal["triton"] = "triton" diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index bd11e01a80250..6bd34c812d641 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -33,6 +33,10 @@ OrderedSet, ) from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._helion import has_helion +from torch.utils._pallas import has_pallas +from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -43,9 +47,6 @@ LazyVal, TestCase, ) -from torch.utils._config_module import ConfigModule -from torch.utils._helion import has_helion -from torch.utils._triton import has_triton log: logging.Logger = logging.getLogger(__name__) @@ -67,6 +68,8 @@ def test_cpu(): HAS_TRITON = has_triton() +HAS_PALLAS = has_pallas() + HAS_HELION = has_helion() if HAS_TRITON: diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py new file mode 100644 index 0000000000000..25cc635dbb178 --- /dev/null +++ b/torch/utils/_pallas.py 
@@ -0,0 +1,82 @@ +import functools + +import torch + + +@functools.cache +def has_jax_package() -> bool: + """Check if JAX is installed.""" + try: + import jax # noqa: F401 # type: ignore[import-not-found] + + return True + except ImportError: + return False + + +@functools.cache +def has_pallas_package() -> bool: + """Check if Pallas (JAX experimental) is available.""" + if not has_jax_package(): + return False + try: + from jax.experimental import ( # noqa: F401 # type: ignore[import-not-found] + pallas as pl, + ) + + return True + except ImportError: + return False + + +@functools.cache +def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, int, int]: + """Get JAX version as (major, minor, patch) tuple.""" + try: + import jax # type: ignore[import-not-found] + + version_parts = jax.__version__.split(".") + major, minor, patch = (int(v) for v in version_parts[:3]) + return (major, minor, patch) + except (ImportError, ValueError, AttributeError): + return fallback + + +@functools.cache +def has_jax_cuda_backend() -> bool: + """Check if JAX has CUDA backend support.""" + if not has_jax_package(): + return False + try: + import jax # type: ignore[import-not-found] + + # Check if CUDA backend is available + devices = jax.devices("gpu") + return len(devices) > 0 + except Exception: + return False + + +@functools.cache +def has_pallas() -> bool: + """ + Check if Pallas backend is fully available for use. + + Requirements: + - JAX package installed + - Pallas (jax.experimental.pallas) available + - CUDA backend available (for GPU support) + """ + if not has_pallas_package(): + return False + + # Only enable Pallas if CUDA is available + # (Pallas primarily targets GPU workloads) + if not torch.cuda.is_available(): + return False + + # Check if JAX has GPU/CUDA backend + if not has_jax_cuda_backend(): + return False + + return True From 39160dba0c5120c65705a44e556c8c4af243e573 Mon Sep 17 00:00:00 2001 From: Bruce Chang Date: Wed, 5 Nov 2025 00:54:35 +0000 Subject: [PATCH 039/651] shrink_group implementation to expose ncclCommShrink API (#164518) Closes #164529 To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization. For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518 Approved by: https://github.com/kwen2501 --- docs/source/distributed.md | 4 + test/distributed/test_c10d_nccl.py | 676 +++++++++++++++++- torch/csrc/distributed/c10d/Backend.hpp | 17 + torch/csrc/distributed/c10d/NCCLUtils.cpp | 59 ++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 12 + .../distributed/c10d/ProcessGroupNCCL.cpp | 135 +++- .../distributed/c10d/ProcessGroupNCCL.hpp | 21 + torch/csrc/distributed/c10d/init.cpp | 11 + torch/distributed/distributed_c10d.py | 519 ++++++++++++++ torch/testing/_internal/common_distributed.py | 48 ++ 10 files changed, 1500 insertions(+), 2 deletions(-) diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 1c9d374b8ab02..ca1fe3b5e9099 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective .. 
autofunction:: new_group ``` +```{eval-rst} +.. autofunction:: torch.distributed.distributed_c10d.shrink_group +``` + ```{eval-rst} .. autofunction:: get_group_rank ``` diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index c117bc810b115..cf53896187c20 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2,6 +2,7 @@ import copy import json +import logging import os import pickle import random @@ -21,6 +22,7 @@ import torch import torch.distributed as c10d import torch.distributed._functional_collectives as _functional_collectives +from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT if not c10d.is_available() or not c10d.is_nccl_available(): @@ -47,12 +49,15 @@ from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( + get_required_world_size, get_timeout, init_multigpu_helper, MultiProcessTestCase, requires_multicast_support, requires_nccl, + requires_nccl_shrink, requires_nccl_version, + requires_world_size, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, sm_is_or_higher_than, @@ -88,6 +93,53 @@ ) +_start_time = time.time() +_logger = logging.getLogger(__name__) + + +def _ts(): + return time.time() - _start_time + + +def configure(level=logging.INFO, force=False): + try: + logging.basicConfig( + level=level, + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + force=force, + ) + except TypeError: + logging.basicConfig( + level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s" + ) + + +def log_test_info(rank, message): + _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message) + + +def log_test_success(rank, message): + _logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message) + + +def log_test_validation(rank, message): + _logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message) + + +def log_test_warning(rank, message): + _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message) + + +def log_test_error(rank, message): + _logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message) + + +_log_configure = configure + + +_log_configure(level=logging.INFO, force=True) + + class RendezvousEnvTest(TestCase): @retry_on_connect_failures @requires_nccl() @@ -317,7 +369,7 @@ def tearDown(self): @property def world_size(self): - return 2 + return get_required_world_size(self, 2) @property def rank_to_GPU(self): @@ -1255,6 +1307,628 @@ def test_set_process_group_desc(self): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_basic(self): + """Test basic shrink_group functionality.""" + self._perform_shrink_test([1], "Basic shrink test") + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_validation(self): + """Test input validation in shrink_group.""" + device, pg = self._setup_shrink_test("validation") + + def _test_invalid_input(ranks, description, expected_exception): + """Helper to test invalid inputs.""" + try: + c10d.shrink_group(ranks) + self.fail(f"Expected {expected_exception.__name__} for {description}") + except expected_exception: + log_test_validation(self.rank, f"✓ {description}") + except Exception: + if expected_exception is Exception: # Accept any exception + log_test_validation(self.rank, f"✓ {description}") + else: + raise + + # Test cases + _test_invalid_input([], "Empty exclusion 
list", ValueError) + if self.world_size > 1: + _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception) + _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception) + + log_test_success(self.rank, "All validation tests passed") + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_backend_properties(self): + """Test that backend properties are preserved after shrinking.""" + + test_name = "Backend Properties Test" + ranks_to_exclude = [0] + + # Reuse _setup_shrink_test for complete setup (device, environment, and process group) + device, pg = self._setup_shrink_test("backend_properties") + + # Follow _perform_shrink_test pattern from here + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # Store original backend property values (not references) before shrinking + original_timeout = None + original_high_priority = None + if not is_excluded: + original_backend = pg._get_backend(device) + original_timeout = original_backend.options._timeout + original_high_priority = original_backend.options.is_high_priority_stream + log_test_info( + self.rank, + f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}", + ) + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + dist.destroy_process_group() # hang without it + return + + # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test) + log_test_info(self.rank, "Non-excluded rank calling shrink_group") + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + + # Reuse _validate_shrunk_group helper (same as _perform_shrink_test) + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + # Add custom backend properties validation + new_backend = shrunk_pg._get_backend(device) + log_test_info(self.rank, "Validating backend properties are preserved") + + new_timeout = new_backend.options._timeout + new_high_priority = new_backend.options.is_high_priority_stream + + log_test_info( + self.rank, + f"Timeout comparison - original: {original_timeout}, new: {new_timeout}", + ) + self.assertEqual( + original_timeout, new_timeout, f"{test_name}: timeout not preserved" + ) + + log_test_info( + self.rank, + f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}", + ) + self.assertEqual( + original_high_priority, + new_high_priority, + f"{test_name}: high_priority_stream not preserved", + ) + + log_test_validation( + self.rank, f"{test_name}: Backend properties preserved successfully" + ) + log_test_success( + self.rank, f"{test_name} successful (shrink + backend validation)" + ) + + # Cleanup (same as _perform_shrink_test) + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_multiple_comms(self): + """Test shrink_group with multiple communicators and subgroup invalidation.""" + + device, pg = self._setup_shrink_test("multiple_comms") + + # Create subgroup [0, 1] and test shrinking it + subgroup = c10d.new_group([0, 1]) + if self.rank <= 1: + # Shrink subgroup: exclude rank 1 + if self.rank == 0: # Only rank 0 remains + shrunk_subgroup = c10d.shrink_group([1], group=subgroup) + 
self.assertEqual(shrunk_subgroup.size(), 1) + # Test communication on shrunk subgroup + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_subgroup) + self.assertEqual(tensor.item(), 0) # Only rank 0 + log_test_success(self.rank, "Subgroup shrinking successful") + + dist.barrier() # Sync before default group test + + # Shrink default group: exclude last rank + ranks_to_exclude = [self.world_size - 1] + if self.rank not in ranks_to_exclude: + shrunk_default = c10d.shrink_group(ranks_to_exclude) + expected_size = self.world_size - 1 + self.assertEqual(shrunk_default.size(), expected_size) + + # Test collective on shrunk default group + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_default) + expected_sum = sum( + range(self.world_size - 1) + ) # 0 + 1 + ... + (world_size-2) + self.assertEqual(tensor.item(), expected_sum) + log_test_success(self.rank, "Default group shrinking successful") + + # Note: After shrinking default group, the old subgroup is invalid + # due to global rank reassignment + + dist.destroy_process_group() + + def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude): + """Helper method to test shrink_group with a specific flag.""" + if self.world_size < 2: + log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})") + return + ranks_to_exclude = [rank_to_exclude] + log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})") + if flag_name == "NCCL_SHRINK_ABORT": + log_test_info( + self.rank, + "ABORT flag will terminate ongoing operations before shrinking", + ) + + self._perform_shrink_test( + ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag + ) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_flags(self): + """Test shrink_group with different shrink flags.""" + # Test ABORT flags + log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag") + self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_nccl_config(self): + """Verify that passing NCCL config via pg_options influences the shrunk group's backend options.""" + device, pg = self._setup_shrink_test("config") + if self.rank == self.world_size - 1: + # excluded rank should not call shrink_group + dist.destroy_process_group() + return + + # Prepare pg_options with NCCL config overrides + # Capture parent's current backend options to ensure we can prove override vs inherit + parent_backend = pg._get_backend(torch.device("cuda")) + parent_hp = parent_backend.options.is_high_priority_stream + parent_blocking = parent_backend.options.config.blocking + + # Choose overrides that differ from the parent (flip where possible) + override_hp = not parent_hp + if parent_blocking in (0, 1): + override_blocking = 1 - parent_blocking + else: + # If undefined or unexpected, set to 1 which is a concrete value + override_blocking = 1 + + opts = c10d.ProcessGroupNCCL.Options() + opts.is_high_priority_stream = override_hp + opts.config.blocking = override_blocking + + shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts) + + # Validate backend options propagated + backend = shrunk_pg._get_backend(torch.device("cuda")) + # is_high_priority_stream should exactly match our override and differ from parent + self.assertEqual(backend.options.is_high_priority_stream, override_hp) + self.assertNotEqual(backend.options.is_high_priority_stream, 
parent_hp) + # config is a struct; check representative field and difference from parent when meaningful + self.assertEqual(backend.options.config.blocking, override_blocking) + if parent_blocking in (0, 1): + self.assertNotEqual(backend.options.config.blocking, parent_blocking) + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_performance(self): + """Test shrink_group performance and regression detection.""" + import time + + ranks_to_exclude = self._get_default_ranks_to_exclude() + is_excluded = self.rank in ranks_to_exclude + + if not ranks_to_exclude: + log_test_info(self.rank, "Skipping performance test (world_size=1)") + return + + log_test_info(self.rank, f"Performance test with {self.world_size} processes") + device, pg = self._setup_shrink_test("performance") + + if not is_excluded: + log_test_info(self.rank, "Measuring shrink_group performance") + start_time = time.time() + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + end_time = time.time() + + elapsed_time = end_time - start_time + log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s") + + # Regression check: should complete within reasonable time + self.assertLess( + elapsed_time, + 30.0, + f"shrink_group took {elapsed_time:.3f}s, possible regression", + ) + + # Test collective performance + expected_size = self.world_size - len(ranks_to_exclude) + self._validate_shrunk_group(shrunk_pg, expected_size, "performance") + + collective_start = time.time() + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, "performance" + ) + collective_time = time.time() - collective_start + + log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s") + log_test_success(self.rank, "Performance test passed") + else: + log_test_info(self.rank, "Excluded rank - waiting") + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(4) + def test_shrink_group_multiple_exclusions(self): + """Test shrink_group with multiple ranks excluded at once.""" + # Scale exclusions with world size + ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2 + + self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test") + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_multiple_iterations(self): + """Test multiple shrink operations in sequence.""" + log_test_info( + self.rank, + f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}", + ) + + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + _ = self._create_process_group_nccl(store, self.opts(), device_id=device) + + # Track current effective world size throughout shrinking operations + current_world_size = self.world_size + log_test_info(self.rank, f"Initial world_size: {current_world_size}") + + # First shrinking: exclude the last rank(s) + first_exclusion = [self.world_size - 1] + if self.world_size >= 6: + first_exclusion.append( + self.world_size - 2 + ) # Exclude last two ranks for larger sizes + + log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}") + + if self.rank not in first_exclusion: + # Only non-excluded ranks should call shrink_group + first_pg = c10d.shrink_group(first_exclusion) + self.assertIsNotNone(first_pg) + # IMPORTANT: Update world size after first shrinking + current_world_size = first_pg.size() + expected_first_size = self.world_size - len(first_exclusion) + log_test_info( + self.rank, + 
f"After first shrinking: world_size {self.world_size} -> {current_world_size}", + ) + self.assertEqual(first_pg.size(), expected_first_size) + + # Second shrinking: exclude another rank from the remaining group + # Choose a rank that's in the middle range + if current_world_size >= 3: + second_exclusion = [ + current_world_size - 1 + ] # Exclude the new "last" rank + log_test_info( + self.rank, + f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}", + ) + + if self.rank not in second_exclusion: + # Only non-excluded ranks should call shrink_group for second iteration + second_pg = c10d.shrink_group(second_exclusion, group=first_pg) + self.assertIsNotNone(second_pg) + # IMPORTANT: Update world size after second shrinking + final_world_size = second_pg.size() + expected_final_size = current_world_size - len(second_exclusion) + log_test_info( + self.rank, + f"After second shrinking: world_size {current_world_size} -> {final_world_size}", + ) + self.assertEqual(second_pg.size(), expected_final_size) + + # Test collective on final group + tensor = torch.full((1,), self.rank).cuda(device) + log_test_info( + self.rank, + f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}", + ) + c10d.all_reduce(tensor, group=second_pg) + log_test_info( + self.rank, + f"Final all_reduce completed, result: {tensor.item()}", + ) + + # Calculate expected sum of remaining ranks + all_excluded = set(first_exclusion + second_exclusion) + remaining_ranks = [ + r for r in range(self.world_size) if r not in all_excluded + ] + expected_sum = sum(remaining_ranks) + log_test_info( + self.rank, + f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}", + ) + self.assertEqual(tensor.item(), expected_sum) + log_test_info(self.rank, "Final verification passed") + else: + log_test_info( + self.rank, + "This rank excluded in second shrinking, not calling shrink_group", + ) + else: + log_test_info( + self.rank, "Skipping second shrinking (remaining group too small)" + ) + else: + log_test_info( + self.rank, + "This rank excluded in first shrinking, not calling shrink_group", + ) + + log_test_info(self.rank, "Destroying process group") + dist.destroy_process_group() + log_test_info(self.rank, "test_shrink_group_multiple_iterations completed") + + # Helper methods for optimized shrink group tests + def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True): + """Common setup for shrink group tests.""" + os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" + world_size = world_size or self.world_size + store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size) + device = torch.device(f"cuda:{self.rank}") + c10d.init_process_group( + "nccl", + world_size=world_size, + rank=self.rank, + store=store, + pg_options=self.opts(), + device_id=device, + ) + pg = c10d.distributed_c10d._get_default_group() + + if warmup: + c10d.all_reduce(torch.ones(1).cuda(device), group=pg) + + return device, pg + + def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""): + """Validate properties of a shrunk process group.""" + self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None") + actual_size = shrunk_pg.size() + self.assertEqual( + actual_size, expected_size, f"{test_name}: group size mismatch" + ) + + new_rank = shrunk_pg.rank() + self.assertTrue( + 0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}" + ) + + log_test_info( + self.rank, + 
f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}", + ) + return new_rank + + def _test_collective_on_shrunk_group( + self, shrunk_pg, device, ranks_to_exclude, test_name="" + ): + """Test collective communication on shrunk group and verify correctness.""" + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + c10d.all_reduce(test_tensor, group=shrunk_pg) + + result = test_tensor.item() + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + + self.assertEqual( + result, expected_sum, f"{test_name}: collective result mismatch" + ) + log_test_info( + self.rank, f"{test_name}: collective passed ({result} == {expected_sum})" + ) + return result + + def _perform_shrink_test( + self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True + ): + """Complete shrink test flow: setup, shrink, validate, test collective, cleanup. + + Consistent API: All ranks perform setup to initialize distributed environment. + ONLY non-excluded ranks call shrink_group() for both default and non-default groups. + Excluded ranks perform setup, then exit without calling shrink_group() or waiting. + """ + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # All ranks (including excluded ones) perform setup to initialize distributed environment + device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_")) + is_default_group = pg == c10d.distributed_c10d._get_default_group() + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + if shrink_flags & NCCL_SHRINK_ABORT: + log_test_info(self.rank, f"Using abort for excluded rank {self.rank}") + pg._get_backend(torch.device(device)).abort() + log_test_info( + self.rank, f"cleanup resources for excluded rank {self.rank}" + ) + dist.destroy_process_group() + log_test_info(self.rank, f"Excluded rank {self.rank} - exit") + else: + log_test_info( + self.rank, f"Using regular destroy for excluded rank {self.rank}" + ) + dist.destroy_process_group() + return None + + # Only non-excluded ranks proceed with shrink + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group})", + ) + shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags) + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done", + ) + + # Non-excluded ranks: validate and test the new group + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + if with_collective: + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, test_name + ) + log_test_success(self.rank, f"{test_name} successful (shrink + collective)") + else: + log_test_success(self.rank, f"{test_name} successful (shrink only)") + + dist.destroy_process_group() + return shrunk_pg + + def _get_default_ranks_to_exclude(self): + """Get default ranks to exclude based on world size.""" + if self.world_size <= 1: + return [] + return [self.world_size - 1] # Exclude last rank by default + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_vs_abort_reinit_performance(self): + """Compare performance of shrink_group vs traditional abort+reinit 
(simplified for reliability).""" + log_test_info(self.rank, "=== TEST 1: abort+reinit ===") + + device, pg1 = self._setup_shrink_test("_perf_reinit") + torch.cuda.synchronize(device) + + # Test 1: Traditional abort + reinit + start_time = time.perf_counter() + dist.destroy_process_group() + + device, new_pg = self._setup_shrink_test("perf_shrink_test1") + reinit_time = time.perf_counter() - start_time + + # Test collective with original rank values for fair comparison (non-blocking mode) + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True) + work.wait() + + torch.cuda.synchronize(device) + + # Verify correctness + expected_sum = sum(r for r in range(self.world_size)) + self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed") + + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + dist.destroy_process_group(new_pg) + + # Test 2: shrink_group with NCCL_SHRINK_ABORT + log_test_info(self.rank, "=== TEST 2: shrink_group ===") + + ranks_to_exclude = [self.world_size - 1] + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix + + shrink_time = 0 + if not is_excluded: + torch.cuda.synchronize(device) # Ensure accurate timing + start_time = time.perf_counter() + shrunk_pg = c10d.shrink_group( + ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT + ) + c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg) + shrink_time = time.perf_counter() - start_time + + # Test collective communication on shrunk group (non-blocking mode) + test_tensor = torch.full( + (1,), self.rank, device=device, dtype=torch.float32 + ) + work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True) + work.wait() + + # Verify correctness + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + self.assertEqual( + test_tensor.item(), + expected_sum, + "shrink_test: collective result mismatch", + ) + + torch.cuda.synchronize(device) # Ensure operations complete + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + dist.destroy_process_group() + else: + log_test_info(self.rank, "Excluded from shrink test - exiting immediately") + dist.destroy_process_group() + return + + # Performance analysis (only for participating ranks) + if shrink_time > 0 and reinit_time > 0: + speedup = reinit_time / shrink_time + time_saved = reinit_time - shrink_time + + log_test_info(self.rank, "=== PERFORMANCE RESULTS ===") + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s") + log_test_info(self.rank, f"speedup: {speedup:.2f}x") + + if speedup > 1.1: + log_test_success(self.rank, "shrink_group significantly faster") + elif speedup > 0.9: + log_test_info(self.rank, "≈ comparable performance") + else: + log_test_warning(self.rank, "abort+reinit faster") + + log_test_info(self.rank, "Performance test completed") + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_deterministic_mode_no_break(self): diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 6ffa1529a4de0..72e35e3fc9dd3 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ 
b/torch/csrc/distributed/c10d/Backend.hpp @@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder { return false; } + virtual bool supportsShrinking() const { + return false; + } + + // Shrink the backend by excluding specified ranks. Backends that support + // communicator shrinking should override this and return a new backend + // instance representing the shrunken group. Backends may use opts_override + // to supply backend-specific options for the new group. + virtual c10::intrusive_ptr shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/ = 0, + const c10::intrusive_ptr& /*opts_override*/ = nullptr) { + TORCH_CHECK( + false, + c10::str("Backend ", getBackendName(), " does not support shrink")); + } + virtual void setTimeout(std::chrono::milliseconds timeout) { TORCH_CHECK( false, diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 8074cc98a04f1..a41f654b9ae20 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -259,6 +259,65 @@ std::shared_ptr NCCLComm::split( } #endif +#ifdef NCCL_HAS_COMM_SHRINK +std::shared_ptr NCCLComm::shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags) { + // Preconditions are validated in ProcessGroupNCCL::shrink + + LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr() + << " excluding " << ranks_to_exclude.size() << " ranks"; + + at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_); + auto comm = std::make_shared(); + + // This call will block until the source communicator is initialized + auto sourceComm = source->getNcclComm(); + + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommShrink( + sourceComm, + ranks_to_exclude.data(), + ranks_to_exclude.size(), + reinterpret_cast(&(comm->ncclComm_)), + config, + shrinkFlags), + source->getNcclCommFailureReason()); + + // Wait for the child communicator to be ready + source->waitReady(true); + comm->initialized_ = true; + + // NCCL automatically assigns rank during shrink - query it efficiently + int assigned_rank; + try { + C10D_NCCL_CHECK( + ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt); + comm->rank_ = assigned_rank; + } catch (const std::exception& e) { + // Fallback: if ncclCommUserRank fails, we can't determine the rank + LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what(); + throw; + } + + // Child comm should be on the same device as parent comm + comm->deviceIndex_ = source->deviceIndex_; + if (config != nullptr) { + comm->nonBlocking_ = config->blocking == 0; + } else { + // Inherit parent behavior if no config provided + comm->nonBlocking_ = source->nonBlocking_; + } + + LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm " + << comm->repr() << " with NCCL-assigned rank " << assigned_rank; + + return comm; +} +#endif + void NCCLComm::finalize() { LockType lock(mutex_); if (aborted_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index fdd50f69ef3d7..142633b823744 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -90,6 +90,10 @@ static_assert( #define NCCL_HAS_NVLS_CTAS #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_COMM_SHRINK +#endif + // Macro to throw on a non-successful NCCL return value. 
#define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -294,6 +298,14 @@ class NCCLComm { ncclConfig_t& config); #endif // NCCL_HAS_COMM_SPLIT +#ifdef NCCL_HAS_COMM_SHRINK + static std::shared_ptr shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags = 0); +#endif // NCCL_HAS_COMM_SHRINK + #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index fd7f0b4246517..d051803aa7376 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp( } // Get a key string from device -inline std::string getKeyFromDevice(at::Device& device) { +inline std::string getKeyFromDevice(const at::Device& device) { return std::to_string(device.index()); } @@ -5842,6 +5842,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor( return tensor; } +#ifdef NCCL_HAS_COMM_SHRINK +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& ranks_to_exclude, + int shrink_flags, + const c10::intrusive_ptr& opts_override) { + // Runtime version check with better error message + auto runtime_version = torch::cuda::nccl::version(); + TORCH_CHECK( + runtime_version >= NCCL_VERSION(2, 27, 0), + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. " + "Found version: ", + runtime_version); + + // Early validation with detailed error messages + TORCH_CHECK_VALUE( + !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty"); + TORCH_CHECK_VALUE( + static_cast(ranks_to_exclude.size()) < size_, + "Cannot exclude all ranks (", + ranks_to_exclude.size(), + " >= ", + size_, + ")"); + + // Validate ranks and convert to int efficiently + std::vector int_ranks_to_exclude; + int_ranks_to_exclude.reserve(ranks_to_exclude.size()); + for (int64_t rank : ranks_to_exclude) { + TORCH_CHECK_VALUE( + rank >= 0 && rank < size_, + "Invalid rank ", + rank, + " for group size ", + size_); + int_ranks_to_exclude.push_back(static_cast(rank)); + } + + // Get primary communicator with better error context + auto primary_device_index = guessDeviceId(); + auto primary_device = at::Device(at::kCUDA, primary_device_index); + const auto primary_key = getKeyFromDevice(primary_device); + + std::shared_ptr primary_comm = getNCCLComm(primary_key); + TORCH_CHECK( + primary_comm, + "Primary NCCL communicator for device ", + primary_device, + " (key: ", + primary_key, + ") is not initialized"); + + // Cache device index before shrink operation + at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex(); + + ncclConfig_t* config = nullptr; + // Default to inheriting from parent options + bool high_priority_stream = options_->is_high_priority_stream; + if (opts_override) { + auto nccl_opts = + c10::static_intrusive_pointer_cast( + opts_override); + config = &nccl_opts->config; + // If user provided override options, honor is_high_priority_stream as well + high_priority_stream = nccl_opts->is_high_priority_stream; + } + + std::shared_ptr shrunk_comm = NCCLComm::shrink( + primary_comm.get(), + int_ranks_to_exclude, + (config != nullptr ? 
config : &options_->config), + shrink_flags); + + // Calculate new size and get NCCL-assigned rank + int new_size = size_ - static_cast(ranks_to_exclude.size()); + int new_rank = shrunk_comm->rank_; + + // Create new ProcessGroupNCCL with optimized options cloning + auto new_store = store_->clone(); + auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream); + new_opts->timeout = options_->timeout; + if (config != nullptr) { + new_opts->config = *config; + } else { + new_opts->config = options_->config; + } + + auto new_pg = c10::make_intrusive( + new_store, new_rank, new_size, new_opts); + + // Set up the new process group with optimized device setup + new_pg->initializeDeviceStateForComm( + at::Device(at::kCUDA, parent_device_index), shrunk_comm); + + return c10::static_intrusive_pointer_cast(new_pg); +} + +#else // !NCCL_HAS_COMM_SHRINK +// Backend interface override: raise consistent error when shrink is +// unsupported. +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/, + const c10::intrusive_ptr& /*opts_override*/) { + TORCH_CHECK( + false, + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, " + "but PyTorch was built with an older version or without NCCL shrink support."); +} + +#endif // NCCL_HAS_COMM_SHRINK + +void ProcessGroupNCCL::initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm) { + const auto key = getKeyFromDevice(device); + std::unique_lock lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto stream = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + devNCCLCommMap_[key] = comm; + ncclStreams_.emplace(key, stream); + ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming)); + usedDeviceIdxs_.insert(device.index()); + + if (shouldAllCommunicatorsRegisterAllTensors()) { + std::lock_guard map_lock(ncclCommMemPoolMapMutex); + ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{}); + } +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 286eab14d1a86..2ead1a107394d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend { ErrorType getError() override; + bool supportsShrinking() const override { +#ifdef NCCL_HAS_COMM_SHRINK + return true; +#else + return false; +#endif + } + + // Backend-style shrink override that returns a Backend instance. + c10::intrusive_ptr shrink( + const std::vector& ranks_to_exclude, + int shrink_flags = 0, + const c10::intrusive_ptr& opts_override = + nullptr) override; + std::shared_ptr getMemAllocator() override; // Allocate tensor from communication-optimized memory pool @@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { int p2pRank = 0, bool isSendRecvSelf = false); + // Initialize device-specific state (comm, stream, event, bookkeeping) for a + // given communicator on this process group instance. + void initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm); + // Wrapper method which can be overridden for tests. 
virtual std::exception_ptr checkForNCCLErrors( std::shared_ptr& ncclComm); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index a6c6c6f8c4744..4c6bdbe2ce70f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2734,12 +2734,23 @@ The hook must have the following signature: "supports_time_estimate", &::c10d::Backend::supportsTimeEstimation, "(test whether the backend supports collective time estimation)") + .def_property_readonly( + "supports_shrinking", + &::c10d::Backend::supportsShrinking, + "(test whether the backend supports communicator shrinking)") .def( "set_timeout", &::c10d::Backend::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") + .def( + "shrink", + &::c10d::Backend::shrink, + py::arg("ranks_to_exclude"), + py::arg("shrink_flags") = 0, + py::arg("opts_override") = nullptr, + py::call_guard()) .def( "broadcast", &::c10d::Backend::broadcast, diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bc79408a32ff9..9e4ec1483e960 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -130,6 +130,7 @@ "reduce_scatter_tensor", "get_node_local_rank", "split_group", + "shrink_group", ] _MPI_AVAILABLE = True @@ -5753,3 +5754,521 @@ def _get_process_group_name(pg: ProcessGroup) -> str: def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + + +# Shrink flags for process group backends +SHRINK_DEFAULT = 0x00 +SHRINK_ABORT = 0x01 + + +@_time_logger +def shrink_group( + ranks_to_exclude: list[int], + group: Optional[ProcessGroup] = None, + shrink_flags: int = SHRINK_DEFAULT, + pg_options: Optional[Any] = None, +) -> ProcessGroup: + """ + Shrinks a process group by excluding specified ranks. + + Creates and returns a new, smaller process group comprising only the ranks + from the original group that were not in the ``ranks_to_exclude`` list. + + Args: + ranks_to_exclude (List[int]): A list of ranks from the original + ``group`` to exclude from the new group. + group (ProcessGroup, optional): The process group to shrink. If ``None``, + the default process group is used. Defaults to ``None``. + shrink_flags (int, optional): Flags to control the shrinking behavior. + Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``. + ``SHRINK_ABORT`` will attempt to terminate ongoing operations + in the parent communicator before shrinking. + Defaults to ``SHRINK_DEFAULT``. + pg_options (ProcessGroupOptions, optional): Backend-specific options to apply + to the shrunken process group. If provided, the backend will use + these options when creating the new group. If omitted, the new group + inherits defaults from the parent. + + Returns: + ProcessGroup: a new group comprised of the remaining ranks. If the + default group was shrunk, the returned group becomes the new default group. + + Raises: + TypeError: if the group’s backend does not support shrinking. + ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds, + duplicates, or excludes all ranks). + RuntimeError: if an excluded rank calls this function or the backend + fails the operation. + + Notes: + - Only non-excluded ranks should call this function; excluded ranks + must not participate in the shrink operation. + - Shrinking the default group destroys all other process groups since + rank reassignment makes them inconsistent. 
+ """ + # Step 1: Validate input parameters with comprehensive error checking + _validate_shrink_inputs(ranks_to_exclude, shrink_flags) + + # Step 2: Get target group and essential properties + target_group_info = _prepare_shrink_target_group(group) + + # Step 3: Validate backend requirements and availability + backend_impl = _validate_shrink_backend_requirements(target_group_info) + + # Step 4: Validate ranks against group and check for duplicates + excluded_ranks_set = _validate_and_process_excluded_ranks( + ranks_to_exclude, target_group_info + ) + + # Step 5: Execute the actual shrink operation (backend-specific) + new_backend = backend_impl.shrink( + sorted(excluded_ranks_set), + shrink_flags, + pg_options if pg_options is not None else None, + ) + + # Step 6: Handle cleanup and creation of new process group + target_group_info["pg_options_override"] = pg_options + return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend) + + +def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None: + """Validate input parameters for shrink_group.""" + if not isinstance(ranks_to_exclude, list): + raise TypeError( + f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. " + f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5." + ) + + if not ranks_to_exclude: + raise ValueError( + "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least " + "one rank to exclude. Example: [failed_rank_id]" + ) + + # Validate shrink_flags with clear explanation of valid values + valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT] + if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags: + raise ValueError( + f"Invalid shrink_flags value: {shrink_flags}. Must be one of: " + f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). " + f"Use SHRINK_ABORT to abort ongoing operations before shrinking." + ) + + +def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict: + """Prepare and validate the target group for shrinking.""" + target_pg = group if group is not None else _get_default_group() + + # Cache frequently accessed properties to avoid repeated calls + group_size = int(target_pg.size()) + group_info = { + "process_group": target_pg, + "is_default_group": (target_pg == _get_default_group()), + "group_size": group_size, + "current_rank": target_pg.rank(), + "group_name": _get_process_group_name(target_pg), + } + + # Validate that we have a valid process group + if group_size <= 1: + raise ValueError( + f"Cannot shrink a process group with size {group_size}. " + f"Group must have at least 2 ranks to support shrinking." + ) + + return group_info + + +def _validate_shrink_backend_requirements(group_info: dict) -> Any: + """Return the backend implementation for the target group or raise if unsupported.""" + target_pg = group_info["process_group"] + group_name = group_info["group_name"] + + # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present, + # otherwise try CUDA then fall back to CPU. 
+ try: + preferred_device = getattr(target_pg, "bound_device_id", None) + if preferred_device is not None: + backend_impl = target_pg._get_backend(preferred_device) + else: + # Try CUDA first if available, else CPU + try: + backend_impl = target_pg._get_backend(torch.device("cuda")) + except Exception: + backend_impl = target_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + f"Cannot access device backend for process group '{group_name}'. " + f"Ensure the process group was initialized with a compatible device backend and devices are available." + ) from e + + try: + supports = bool(backend_impl.supports_shrinking) + except Exception: + supports = False + if not supports: + raise TypeError( + f"Process group backend for '{group_name}' does not support shrinking operations." + ) + + return backend_impl + + +def _validate_and_process_excluded_ranks( + ranks_to_exclude: list[int], group_info: dict +) -> set: + """Validate excluded ranks and convert to set for efficient operations.""" + group_size = group_info["group_size"] + current_rank = group_info["current_rank"] + + # Use set for O(1) duplicate detection and membership testing + excluded_ranks_set = set() + + # Validate each rank with detailed error messages + for i, rank in enumerate(ranks_to_exclude): + if not isinstance(rank, int): + raise TypeError( + f"All elements in ranks_to_exclude must be integers. " + f"Element at index {i} is {type(rank).__name__}: {rank}" + ) + + if not (0 <= rank < group_size): + raise ValueError( + f"Rank {rank} at index {i} is out of bounds for group size {group_size}. " + f"Valid ranks are in range [0, {group_size - 1}]." + ) + + if rank in excluded_ranks_set: + raise ValueError( + f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. " + f"Each rank can only be excluded once." + ) + + excluded_ranks_set.add(rank) + + # Ensure we don't exclude all ranks + if len(excluded_ranks_set) >= group_size: + raise ValueError( + f"Cannot exclude all {group_size} ranks from process group. " + f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks." + ) + + # Critical check: current rank should not be in excluded list + if current_rank in excluded_ranks_set: + raise RuntimeError( + f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). " + f"Only non-excluded ranks should participate in the shrinking operation. " + f"Excluded ranks should terminate their processes instead." 
+ ) + + return excluded_ranks_set + + +def _finalize_shrunk_group( + group_info: dict, excluded_ranks_set: set, new_backend +) -> ProcessGroup: + """Clean up old group and create new shrunk process group.""" + target_pg = group_info["process_group"] + is_default_group = group_info["is_default_group"] + + # Handle default group dependencies - destroy other groups first + if is_default_group: + _destroy_all_other_groups(exclude_group=target_pg) + + # Gather original group metadata before cleanup + original_group_metadata = _extract_group_metadata(target_pg) + + # Calculate remaining ranks efficiently + original_ranks = get_process_group_ranks(target_pg) + remaining_ranks = [ + rank for rank in original_ranks if rank not in excluded_ranks_set + ] + + # Clean up the original group + _cleanup_original_group(target_pg, is_default_group) + + # Create and configure the new process group + new_pg = _create_shrunk_process_group( + new_backend, remaining_ranks, original_group_metadata, is_default_group + ) + + # Register the new group in global state + if is_default_group: + _update_default_pg(new_pg) + + # Update global state with new group information + rank_mapping = { + global_rank: group_rank + for group_rank, global_rank in enumerate(remaining_ranks) + } + _update_process_group_global_state( + pg=new_pg, + backend_name=original_group_metadata["backend_name"], + store=original_group_metadata["store"], + group_name=original_group_metadata["new_group_name"], + backend_config=original_group_metadata["backend_config"], + rank_mapping=rank_mapping, + ) + + return new_pg + + +def _extract_group_metadata(target_pg: ProcessGroup) -> dict: + """Extract metadata from the original group before cleanup.""" + original_backend_name, original_store = _world.pg_map[target_pg] + original_backend_config = _world.pg_backend_config.get(target_pg, "") + original_group_name = _get_process_group_name(target_pg) + + # Extract device binding information before cleanup to avoid accessing destroyed group + bound_device_id = None + if hasattr(target_pg, "bound_device_id"): + bound_device_id = target_pg.bound_device_id + + # Generate new group name for the shrunk group; hash for uniqueness across backends + remaining_ranks = list(get_process_group_ranks(target_pg)) + new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True) + + return { + "backend_name": original_backend_name, + "store": original_store, + "backend_config": original_backend_config, + "original_group_name": original_group_name, + "new_group_name": new_group_name, + "bound_device_id": bound_device_id, # Safe to access after cleanup + } + + +def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None: + """Clean up the original process group safely.""" + try: + destroy_process_group(target_pg) + except Exception: + group_type = "default" if is_default_group else "non-default" + logger.warning( + "Failed to destroy %s group during shrinking", group_type, exc_info=True + ) + + # Ensure global state cleanup even if destroy_process_group fails + _cleanup_process_group_global_state(target_pg) + + +def _create_shrunk_process_group( + new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool +) -> ProcessGroup: + """Create and configure the new shrunk process group.""" + # Create new group properties + new_group_rank = new_backend.rank() + new_group_size = new_backend.size() + group_name = metadata["new_group_name"] + + # Generate descriptive group description + if is_default_group: + group_desc = 
"default:shrunken" + else: + group_desc = f"{metadata['original_group_name']}:shrunk" + + # Create process group with new communicator (clone the parent store like split does) + prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone()) + new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size) + + # Configure backend using the device type of the new backend's bound device if available, + # otherwise derive from the original group's bound device or fall back to CPU. + backend_device = metadata.get("bound_device_id") + if backend_device is None: + # Default to CPU if no bound device is present + backend_device = torch.device("cpu") + + # Choose backend enum based on device type + if backend_device.type == "cuda": + backend_type = ProcessGroup.BackendType.NCCL + else: + backend_type = ProcessGroup.BackendType.GLOO + + new_pg._register_backend(backend_device, backend_type, new_backend) + new_pg._set_default_backend(backend_type) + + # Inherit device binding from original group if it was bound + bound_device_id = metadata.get("bound_device_id") + if bound_device_id is not None: + new_pg.bound_device_id = bound_device_id + + # Set group metadata + new_pg._set_group_name(group_name) + new_pg._set_group_desc(group_desc) + + # Persist backend configuration overrides (if provided via shrink_group) + backend_config_override = metadata.get("backend_config") + if backend_config_override is not None: + # Store for introspection/debugging and potential backend hooks + _world.pg_backend_config[new_pg] = backend_config_override + + return new_pg + + +def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None: + """ + Destroy all process groups except the excluded group and clean up all global state. + + This is necessary when shrinking the default group because global ranks + are reassigned by NCCL, making all existing process groups inconsistent. + + Note: Uses abort for non-collective cleanup since excluded ranks may not + participate in collective operations. Backend cleanup is handled independently per group. + + Args: + exclude_group (ProcessGroup, optional): Process group to exclude from destruction. + If None, destroys all process groups. + """ + # Get list of groups to destroy (avoid modifying dict while iterating) + groups_to_destroy = [] + for pg in list(_world.pg_group_ranks.keys()): + if exclude_group is not None and pg == exclude_group: + continue + groups_to_destroy.append(pg) + + # Warn user about automatic destruction + if groups_to_destroy: + group_names = [_get_process_group_name(pg) for pg in groups_to_destroy] + logger.warning( + "Shrinking default group will destroy %d other process groups: %s. 
" + "This is necessary because shrinking the default group reassigns global ranks, " + "making existing groups inconsistent.", + len(groups_to_destroy), + ", ".join(group_names), + ) + + # Destroy each group and clean up global state + for pg in groups_to_destroy: + try: + # First call abort_process_group which handles the C++ cleanup non-collectively + _abort_process_group(pg) + except Exception: + # Log but don't fail - some groups might already be destroyed + logger.warning( + "Failed to abort process group %s", + _get_process_group_name(pg), + exc_info=True, + ) + + # Ensure all global state is cleaned up even if _abort_process_group fails + # or doesn't clean up everything + _cleanup_process_group_global_state(pg) + + +def _cleanup_process_group_global_state(pg: ProcessGroup) -> None: + """ + Clean up all global state associated with a process group. + + This function ensures complete cleanup of process group state from all + global dictionaries and registries, even if destroy_process_group fails + or doesn't clean up everything. This is critical when destroying multiple + groups to prevent inconsistent state. + + The cleanup removes the process group from: + - _world.pg_map (backend and store mapping) + - _world.pg_names (group name mapping) + - _world.pg_group_ranks (rank mappings) + - _world.pg_backend_config (backend configuration) + - _world.tags_to_pg and _world.pg_to_tag (tag mappings) + - _world.pg_coalesce_state (coalescing state) + - C++ internal registries via _unregister_process_group + + Args: + pg (ProcessGroup): The process group to clean up. + """ + try: + # Clean up main process group mappings + _world.pg_map.pop(pg, None) + _world.pg_group_ranks.pop(pg, None) + _world.pg_backend_config.pop(pg, None) + + # Clean up process group name mapping + group_name = _world.pg_names.pop(pg, None) + + # Clean up tag mappings + pg_tag = _world.pg_to_tag.pop(pg, None) + if pg_tag is not None and pg_tag in _world.tags_to_pg: + try: + _world.tags_to_pg[pg_tag].remove(pg) + # Remove the tag entry if list is empty + if not _world.tags_to_pg[pg_tag]: + _world.tags_to_pg.pop(pg_tag, None) + except (ValueError, KeyError): + # Process group was already removed from the list + pass + + # Clean up any registered process group names using C++ unregister function + if group_name is not None: + try: + _unregister_process_group(group_name) + except Exception: + # Process group name might not be registered or already unregistered + pass + + # Clean up coalesce state if present + _world.pg_coalesce_state.pop(pg, None) + + except Exception: + # Log cleanup failures but don't propagate - we want to continue with other cleanups + logger.warning( + "Failed to fully clean up global state for process group", exc_info=True + ) + + +def _update_process_group_global_state( + pg: ProcessGroup, + backend_name: str, + store: Store, + group_name: str, + backend_config: str, + rank_mapping: Optional[dict[int, int]] = None, + pg_tag: Optional[str] = None, + user_tag: Optional[str] = None, +) -> None: + """ + Update all global state dictionaries for a process group. + + This helper function consolidates the common pattern of updating multiple + global state dictionaries when creating or modifying process groups. + + Args: + pg (ProcessGroup): The process group to update state for. + backend_name (str): Backend name for pg_map. + store (Store): Store instance for pg_map. + group_name (str): Group name for pg_names and registration. + backend_config (str): Backend configuration string. 
+ rank_mapping (Dict[int, int], optional): Global rank to group rank mapping. + If None, skips updating pg_group_ranks. + pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}". + user_tag (str, optional): User-provided tag for special tag handling. + If provided, creates "user:{user_tag}" tag and also adds to default "". + """ + # Update main process group mappings + _world.pg_map[pg] = (backend_name, store) + _world.pg_names[pg] = group_name + _world.pg_backend_config[pg] = backend_config + + # Register the process group name + _register_process_group(group_name, pg) + + # Update rank mapping if provided + if rank_mapping is not None: + _world.pg_group_ranks[pg] = rank_mapping + + # Handle tag management + if pg_tag is None: + pg_tag = f"ptd:{group_name}" + + if user_tag is not None: + # Special handling for user-provided tags + # Add to default "" tag first + _world.tags_to_pg.setdefault("", []).append(pg) + # Then create user-specific tag + user_pg_tag = f"user:{user_tag}" + _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg) + _world.pg_to_tag[pg] = user_pg_tag + else: + # Standard process group tag + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 18384b311b936..91f09adf9e816 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -238,6 +238,47 @@ def wrapper(*args, **kwargs): return decorator +def requires_world_size(n: int): + """ + Decorator to request a specific world size for a test. The test harness can + read this attribute to set the number of ranks to spawn. If there are fewer + than `n` CUDA devices available, the test should be skipped by the harness. + + Usage: + @require_world_size(3) + def test_something(self): + ... + """ + + def decorator(func): + func._required_world_size = n + available = torch.cuda.device_count() + return unittest.skipUnless( + available >= n, f"requires {n} GPUs, found {available}" + )(func) + + return decorator + + +def get_required_world_size(obj: Any, default: int) -> int: + """ + Returns the requested world size for the currently running unittest method on `obj` + if annotated via `@require_world_size(n)`, else returns `default`. + """ + try: + # Try MultiProcessTestCase helper first, then unittest fallback + test_name = ( + obj._current_test_name() # type: ignore[attr-defined] + if hasattr(obj, "_current_test_name") and callable(obj._current_test_name) + else obj._testMethodName + ) + fn = getattr(obj, test_name) + value = fn._required_world_size + return int(value) + except Exception: + return default + + # This decorator helps avoiding initializing cuda while testing other backends def nccl_skip_if_lt_x_gpu(backend, x): def decorator(func): @@ -367,6 +408,13 @@ def requires_nccl_version(version, msg): ) +def requires_nccl_shrink(): + """ + Require NCCL shrink support (NCCL available and version >= 2.27). 
+ """ + return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group") + + def requires_nccl(): return skip_but_pass_in_sandcastle_if( not c10d.is_nccl_available(), From 45da6e1fe17dc7fb4d96526e907ed9c9bf002f70 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Wed, 5 Nov 2025 01:02:53 +0000 Subject: [PATCH 040/651] [CD] Upload XPU inductor benchmark test reports to s3 (#166954) As the title Pull Request resolved: https://github.com/pytorch/pytorch/pull/166954 Approved by: https://github.com/atalman --- .github/workflows/_xpu-test.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index e68bc6ead3a26..d27325b8a63dc 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -344,5 +344,21 @@ jobs: if-no-files-found: ignore path: ./**/core.[1-9]* + - name: Authenticate with AWS + uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results + # The max duration enforced by the server side + role-duration-seconds: 18000 + aws-region: us-east-1 + + - name: Upload the benchmark results + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: test/test-reports + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Teardown XPU uses: ./.github/actions/teardown-xpu From 64ae31c5d36255d16c832204acb7709e0762f1b3 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 4 Nov 2025 10:58:04 -0800 Subject: [PATCH 041/651] [HOP][print] Add HOP subclass for printing (#166660) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166660 Approved by: https://github.com/angelayi, https://github.com/anijain2305 Co-authored-by: Angela Yi --- test/higher_order_ops/test_print.py | 44 +++++++++++++++++++++++++++++ torch/_higher_order_ops/__init__.py | 2 ++ torch/_higher_order_ops/print.py | 44 +++++++++++++++++++++++++++++ torch/testing/_internal/hop_db.py | 25 +++++++++++----- 4 files changed, 108 insertions(+), 7 deletions(-) create mode 100644 test/higher_order_ops/test_print.py create mode 100644 torch/_higher_order_ops/print.py diff --git a/test/higher_order_ops/test_print.py b/test/higher_order_ops/test_print.py new file mode 100644 index 0000000000000..aef538854864f --- /dev/null +++ b/test/higher_order_ops/test_print.py @@ -0,0 +1,44 @@ +# Owner(s): ["module: higher order operators"] +import io +from unittest.mock import patch + +import torch +from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestHopPrint(TestCase): + def test_base_print(self): + def f(x): + x = x + x + torch._higher_order_ops.print("moo") + x = x * x + torch._higher_order_ops.print("moo") + return x + + counters.clear() + x = torch.randn(3, 3) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "moo\nmoo") + + def test_para_print(self): + def f(x): + x = x + x + torch._higher_order_ops.print("moo {x} {y}", x=1, y=2) + x = x * x + return x + + counters.clear() + x = torch.randn(3, 3) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "moo 1 2") + + +if __name__ == "__main__": + run_tests() diff --git 
a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 516d58bdf314e..452a080570ebe 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -24,6 +24,7 @@ from torch._higher_order_ops.local_map import local_map_hop from torch._higher_order_ops.map import map from torch._higher_order_ops.out_dtype import out_dtype +from torch._higher_order_ops.print import print from torch._higher_order_ops.run_const_graph import run_const_graph from torch._higher_order_ops.scan import scan from torch._higher_order_ops.strict_mode import strict_mode @@ -75,4 +76,5 @@ "map", "while_loop_stack_output", "local_map_hop", + "print", ] diff --git a/torch/_higher_order_ops/print.py b/torch/_higher_order_ops/print.py new file mode 100644 index 0000000000000..5a14ef23aa24e --- /dev/null +++ b/torch/_higher_order_ops/print.py @@ -0,0 +1,44 @@ +import builtins + +import torch +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator + + +class Print(HigherOrderOperator): + """ + print(format_str, **kwargs) -> None + + This Higher Order Operator (HOP) provides a functional version of print for use in PyTorch graphs. + It enables format printing with named arguments, e.g., torch._higher_order_ops.print("moo {x} {y}", x=1, y=2). + + This HOP enables printing without causing graph break. + """ + + def __init__(self) -> None: + super().__init__("print") + + def __call__(self, format_str: str, **kwargs: object) -> object: + assert isinstance(format_str, str) + return super().__call__(format_str, **kwargs) + + +print = Print() + + +@print.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) +# pyre-ignore +def print_cpu(format_str: str, **kwargs: object) -> None: + # Ensure all immutable_dict/list in kwargs are converted to regular dict/list + map_types: dict[type, type] = { + torch.fx.immutable_collections.immutable_dict: dict, + torch.fx.immutable_collections.immutable_list: list, + } + new_kwargs = pytree.tree_map_only( + tuple(map_types.keys()), + lambda a: map_types[type(a)](a), + kwargs, + lambda a: isinstance(a, tuple(map_types.keys())), + ) + # Use built-in print to avoid recursion with the HOP print + builtins.print(format_str.format(**new_kwargs)) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index fc6cfa8cf7f4e..3b38661c69b8c 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -103,6 +103,7 @@ def f2(x, y0, y1): "dynamo_bypassing_wrapper", # TODO(soulitzer) "foreach_map", "aoti_call_delegate", + "print", ] torch.library.define( @@ -153,6 +154,7 @@ def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs def fn_for_invoke_subgraph(x): return torch.sin(x) + def simple_invoke_subgraph(x): return fn_for_invoke_subgraph(x) @@ -202,6 +204,7 @@ def body_fn(iter_t, x): return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x)) + def simple_while_loop_stack_output(iter_t, x): def cond_fn(iter_t, x): return iter_t > 0 @@ -209,7 +212,9 @@ def cond_fn(iter_t, x): def body_fn(iter_t, x): return iter_t - 1, x.cos() - return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple()) + return torch._higher_order_ops.while_loop_stack_output( + cond_fn, body_fn, (iter_t, x), tuple() + ) def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): @@ -226,18 +231,21 @@ def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): def simple_local_map_hop(inp1, 
inp2): def body_gm(inp1, inp2): return inp1.cos() + inp2.sin() + gm = torch.fx.symbolic_trace(body_gm) assert torch.distributed.is_available() from torch.distributed.tensor.placement_types import Replicate + gm.meta["local_map_kwargs"] = { "in_placements": (Replicate(), Replicate(), Replicate()), - "out_placements": ((Replicate(), Replicate(), Replicate()),) + "out_placements": ((Replicate(), Replicate(), Replicate()),), } # TODO: Dynamo would rewrite this op differently return torch._higher_order_ops.local_map_hop(gm, inp1, inp2) + def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( make_tensor, device=device, dtype=dtype, requires_grad=requires_grad @@ -249,7 +257,6 @@ def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): def simple_scan(init, xs): - def combine_fn(carry, x): result = carry @ x + x return result, carry.clone() @@ -264,15 +271,14 @@ def simple_invoke_quant(x): def fn(x, y): return (torch.sin(x) * y,) - return quant_tracer(fn, x, x)[0] * 2. + return quant_tracer(fn, x, x)[0] * 2.0 def simple_invoke_quant_packed(x): def fn(x): return (torch.sin(x),) - return invoke_quant_packed(fn, x)[0] * 2. - + return invoke_quant_packed(fn, x)[0] * 2.0 hop_db = [ @@ -496,6 +502,11 @@ def fn(x): DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), ), - decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")], + decorators=[ + onlyCUDA, + unittest.skipIf( + not torch.distributed.is_available(), "requires distributed build" + ), + ], ), ] From bcd159bcddf477fe38fd020af403f7d1004c6c2b Mon Sep 17 00:00:00 2001 From: Haifeng Jin Date: Wed, 5 Nov 2025 01:16:54 +0000 Subject: [PATCH 042/651] Fix the vmap op fallback bug (#166032) ## The bug In some environments, if run: ```py def inner_func(x): return x.to(torch.float32, memory_format=torch.channels_last) x = torch.randn(2, 2, 3, 4, device="cpu", dtype=torch.float64) torch.vmap(inner_func)(x) ``` we get: ``` E RuntimeError: Batching rule not implemented for aten::to.dtype_layout; the fallback path doesn't work on out= or view ops. ``` Otherwise, it would always fallback and result in an error for ops like `to.dtype` and `to.dtype_layout` even the kernels are registered. ## The cause The alias key of `FuncTorchBatchedDecomposition` is not properly translated to runtime dispatch keys when updating the dispatch table of `OperatorEntry::dispatchTable_`. [[link](https://github.com/pytorch/pytorch/blob/984b096d10398a615a791fd11296d6d51fdd55a4/aten/src/ATen/core/dispatch/OperatorEntry.cpp#L500-L501)] The [`getRuntimeDispatchKeySet`](https://github.com/pytorch/pytorch/blob/f3fa560dec727380b3e9c074efe05f0ce715a5ca/c10/core/DispatchKeySet.cpp#L62) use if-else to translate all other alias keys but `FuncTorchBatchedDecomposition`. This would result in not finding the kernel in many cases. ## The fix This PR adds one more `if` statement to `getRuntimeDispatchKeySet` to map `FuncTorchBatchedDecomposition` to the corresponding runtime dispatch key, `FuncTorchBatched`. So, that the dispatch table can be properly updated. This fix allows people to use ops inside vmaps in more environments and across more compilers. 
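Concretely, the translation this change adds to `getRuntimeDispatchKeySet` boils down to the following (a condensed view of the diff below, with the unrelated alias-key cases elided; not standalone code):

```cpp
constexpr DispatchKeySet functorch_batched_dispatch_keyset =
    DispatchKeySet(DispatchKey::FuncTorchBatched);

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    // ... existing alias-key cases elided ...
    case DispatchKey::FuncTorchBatchedDecomposition:
      return functorch_batched_dispatch_keyset;
    default:
      return DispatchKeySet(t);
  }
}
```

With this mapping in place, updating the dispatch table for a kernel registered under the alias key writes into the `FuncTorchBatched` slot directly, independent of whether the batched fallback has already been registered.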
## Why does it work without the PR
As long as the `FuncTorchBatchedDecomposition` [[link](https://github.com/pytorch/pytorch/blob/51319ca090bc7458168a8451c04ca7e021a72693/aten/src/ATen/functorch/BatchRulesDecompositions.cpp#L35)] is registered before the fallback method of `FuncTorchBatched` [[link](https://github.com/pytorch/pytorch/blob/d311a3d1dca4bfc01ae44fc5d1f8d7ff22bc551f/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp#L759)], everything runs fine. In this case, it relies on the registration of the fallback method to update the dispatch table, which flushes all the kernels in `OperatorEntry::kernels_` into `dispatchTable_`, among which there are kernels registered with `FuncTorchBatchedDecomposition`.
## When does it fail
However, the order of the op registration and the fallback registration is not guaranteed at all. It relies on the C++ static initialization order, which varies from environment to environment. On our compiler, the fallback registration goes first and the alias-key kernels under `FuncTorchBatchedDecomposition` come later, so they are never flushed into the dispatch table by the fallback registration. Therefore, the dispatcher cannot find the kernel for those ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166032 Approved by: https://github.com/albanD --- c10/core/DispatchKeySet.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 72e72f49a5e40..107530e9e28a2 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -59,6 +59,9 @@ constexpr DispatchKeySet nested_dispatch_keyset = {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); +constexpr DispatchKeySet functorch_batched_dispatch_keyset = + DispatchKeySet(DispatchKey::FuncTorchBatched); + DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); switch (t) { @@ -77,6 +80,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { return backend_dispatch_keyset; case DispatchKey::CompositeExplicitAutogradNonFunctional: return non_functional_backend_dispatch_keyset; + case DispatchKey::FuncTorchBatchedDecomposition: + return functorch_batched_dispatch_keyset; default: return DispatchKeySet(t); }
From 01e6e35c7faf913c3a85c7a64d2939cfa768358a Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Tue, 4 Nov 2025 13:45:20 -0800 Subject: [PATCH 043/651] Send / recv support in local tensor (#166595)
This change introduces LocalRunnerMode, which allows you to run multiple SPMD functions concurrently. SPMD functions execute one at a time, yielding execution while waiting for send or receive operations to complete. Send and receive peer operations are only supported while running under LocalRunnerMode. The example test in this change demonstrates how ranks send data to the next peer and receive data from the previous peer (ring). 
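For orientation, here is a minimal sketch of the API shape introduced here, stripped down from the `test_dp_pp` test added below. The stage body is elided; in the real test it builds DTensors and exchanges them with `local_p2p_op(..., dist.isend)` / `local_p2p_op(..., dist.irecv)` followed by `wait_all(...)`:

```python
from torch.distributed._local_tensor import LocalRunnerMode

def stage_fn(stage_index: int) -> None:
    # Each SPMD function runs here. Only one function executes at a time; a
    # function yields to its peers while it waits on a receive. In the ring
    # test below this is where data is sent to the next stage and received
    # from the previous one.
    ...

# 6 ranks in total, 3 cooperatively scheduled SPMD functions (indices 0..2),
# mirroring the world_size / pp_size used by the test.
with LocalRunnerMode(6, 3, stage_fn):
    pass  # runner threads start on enter and are joined on exit
```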
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166595 Approved by: https://github.com/wconstab, https://github.com/ezyang --- build_variables.bzl | 1 + test/distributed/test_local_tensor.py | 77 ++++++++++ torch/_C/_distributed_c10d.pyi | 7 +- torch/csrc/distributed/c10d/init.cpp | 28 ++++ .../distributed/c10d/python_callback_work.cpp | 64 +++++++++ .../distributed/c10d/python_callback_work.hpp | 28 ++++ torch/distributed/_local_tensor/__init__.py | 134 ++++++++++++++++++ torch/distributed/_local_tensor/_c10d.py | 104 ++++++++++++-- .../distributed/_tensor/common_dtensor.py | 4 + 9 files changed, 434 insertions(+), 13 deletions(-) create mode 100644 torch/csrc/distributed/c10d/python_callback_work.cpp create mode 100644 torch/csrc/distributed/c10d/python_callback_work.hpp diff --git a/build_variables.bzl b/build_variables.bzl index 70121e19d8099..258e739300c1e 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1025,6 +1025,7 @@ libtorch_python_core_sources = [ libtorch_python_distributed_core_sources = [ "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/python_comm_hook.cpp", + "torch/csrc/distributed/c10d/python_callback_work.cpp", ] libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [ diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index 114780627e334..c58ddf0f82ba7 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -7,6 +7,8 @@ import torch.distributed as dist from torch.distributed._local_tensor import ( local_tensor_mode, + LocalIntNode, + LocalRunnerMode, LocalTensor, LocalTensorMode, ) @@ -17,8 +19,10 @@ Partial, Replicate, Shard, + zeros, ) from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import reduce_local_int class LocalTensorTestBase(TestCase): @@ -411,5 +415,78 @@ def test_dtensor_addmm(self): self.assertEqual(full_tensor, local_res) +from torch.distributed._local_tensor._c10d import local_p2p_op, wait_all + + +class TestLocalRunner(LocalTensorTestBase): + world_size = 6 + + @staticmethod + def _get_pp_peer(pp_index, mesh, dim, dir): + pp_meshes = mesh._get_all_submeshes(dim) + pp_ret = {} + for pp_mesh in pp_meshes: + global_rank = pp_mesh.mesh[pp_index].item() + global_peer = pp_mesh.mesh[(pp_index + dir) % pp_mesh.size()].item() + pp_ret[global_rank] = global_peer + + return torch.SymInt(LocalIntNode(pp_ret)) + + def _run_dp_pp( + self, + mesh: DeviceMesh, + pp_index: int, + actual: list[torch.Tensor | None], + expected: list[torch.Tensor | None], + ) -> None: + ltm = LocalTensorMode(mesh.size()) + with ltm: + dp_mesh = mesh["dp"] + pp_mesh = mesh["pp"] + + x = torch.rand(2, 4) + xd = distribute_tensor(x, dp_mesh, [Shard(0)]) + xd = xd * 2 + x = x * 2 + + yd = zeros(*xd.shape, device_mesh=dp_mesh, placements=[Shard(0)]) + + if pp_index != pp_mesh.size(0) - 1: + # Send to next pp rank + pp_next_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", +1) + local_p2p_op(pp_next_rank, xd, dist.isend) + expected[pp_index + 1] = ltm.tensor_map( + x, + lambda r, t: t + if reduce_local_int(pp_next_rank, lambda vals: r in vals.values()) + else torch.zeros_like(t), + ) + + if pp_index != 0: + # Receive from prev pp rank + pp_prev_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", -1) + rw = local_p2p_op(pp_prev_rank, yd, dist.irecv) + wait_all(rw) + + y = yd.full_tensor() + actual[pp_index] = y + + def test_dp_pp(self): + 
pp_size = 3 + mesh = init_device_mesh( + "cpu", (self.world_size // pp_size, pp_size), mesh_dim_names=("dp", "pp") + ) + actual: list[torch.Tensor | None] = [None] * pp_size + expected: list[torch.Tensor | None] = [None] * pp_size + with LocalRunnerMode( + self.world_size, + pp_size, + lambda pp_index: self._run_dp_pp(mesh, pp_index, actual, expected), + ): + pass + + self.assertEqual(actual, expected) + + if __name__ == "__main__": run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 737362be62b48..f3d96860f5584 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -2,7 +2,7 @@ # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum -from typing import Any, Optional, overload, Union +from typing import Any, Callable, Optional, overload, Union import torch from torch import Tensor @@ -616,6 +616,11 @@ class FakeWork(Work): def wait(self, timeout: timedelta = ...) -> bool: ... def getFuture(self) -> Future: ... +class PythonCallbackWork(Work): + def __init__(self, callback: Callable[[timedelta], bool]) -> None: ... + def wait(self, timeout: timedelta = ...) -> bool: ... + def get_future(self) -> Future: ... + class ProcessGroupGloo(Backend): class Device: ... diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 4c6bdbe2ce70f..91bb3469e3e85 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #ifdef USE_C10D_GLOO #include @@ -3887,6 +3888,33 @@ such as `dist.all_reduce(tensor, async_op=True)`. .def("wait", &::c10d::FakeWork::wait, py::arg("timeout") = kNoTimeout) .def("getFuture", &::c10d::FakeWork::getFuture); + auto pythonCallbackWork = + intrusive_ptr_no_gil_destructor_class_<::c10d::PythonCallbackWork>( + module, "PythonCallbackWork", work) + .def(py::init(), py::arg("callback")) + .def( + "wait", + &::c10d::PythonCallbackWork::wait, + py::arg("timeout") = kNoTimeout, + R"( + Waits until the callback completes. Blocking operation. + The callback is invoked with the timeout parameter and should return a boolean. + Throws if the callback completes with an exception. + Returns the boolean value returned by the callback. + )") + .def( + "get_future", + [](::c10d::PythonCallbackWork& work) + -> std::shared_ptr { + return std::make_shared( + work.getFuture()); + }, + R"( + Returns: + A ``torch.futures.Future`` object which is associated with the completion of + the ``PythonCallbackWork``. + )"); + py::class_(module, "DDPLoggingData") .def(py::init<>()) .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map) diff --git a/torch/csrc/distributed/c10d/python_callback_work.cpp b/torch/csrc/distributed/c10d/python_callback_work.cpp new file mode 100644 index 0000000000000..47bef1831a480 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_callback_work.cpp @@ -0,0 +1,64 @@ +#include +#include + +namespace c10d { + +PythonCallbackWork::PythonCallbackWork(py::function callback) + : callback_(std::move(callback)) { + // Create a future that will be marked as complete when wait() is called + future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); +} + +// NOLINTNEXTLINE(bugprone-exception-escape) +PythonCallbackWork::~PythonCallbackWork() { + py::gil_scoped_acquire ag; + callback_.dec_ref(); + // Explicitly set callback_ to nullptr to prevent py::object's dtor + // to decref on the PyObject again. 
+ // See Note [Destructing py::object] in python_ivalue.h + callback_.ptr() = nullptr; +} + +bool PythonCallbackWork::wait(std::chrono::milliseconds timeout) { + py::gil_scoped_acquire ag; + + try { + // Call the Python callback with timeout + py::object result = callback_(timeout); + + // Extract the boolean result + bool success = result.cast(); + + // Mark the work as completed if successful + if (success) { + finish(); + // Mark the future as complete with an empty list + if (!future_->completed()) { + future_->markCompleted(c10::IValue(c10::List())); + } + } + + return success; + } catch (py::error_already_set& e) { + // Capture the Python exception and store it + finish(std::current_exception()); + if (!future_->completed()) { + future_->setErrorIfNeeded(std::current_exception()); + } + throw; + } catch (const std::exception& e) { + // Capture any C++ exception and store it + finish(std::current_exception()); + if (!future_->completed()) { + future_->setErrorIfNeeded(std::current_exception()); + } + throw; + } +} + +c10::intrusive_ptr PythonCallbackWork::getFuture() { + return future_; +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/python_callback_work.hpp b/torch/csrc/distributed/c10d/python_callback_work.hpp new file mode 100644 index 0000000000000..48966e785ad60 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_callback_work.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace c10d { + +// PythonCallbackWork is a subclass of Work that wraps a Python callback +// function that implements wait(). This allows asynchronous work to +// be integrated with Python code, enabling custom completion logic or +// post-processing in Python. +class PythonCallbackWork : public Work { + public: + explicit PythonCallbackWork(py::function callback); + + ~PythonCallbackWork() override; + + bool wait(std::chrono::milliseconds timeout) override; + + c10::intrusive_ptr getFuture() override; + + private: + py::function callback_; + c10::intrusive_ptr future_; +}; + +} // namespace c10d diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index ea9707b2e1e85..c186694df94e7 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -64,6 +64,7 @@ np = None # type: ignore[assignment] import torch +import torch.distributed as dist from torch import Size, SymBool, SymInt, Tensor from torch._C import DispatchKey, DispatchKeySet, ScriptObject from torch._export.wrappers import mark_subclass_constructor_exportable_experimental @@ -921,6 +922,22 @@ def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor: # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor({r: cb(r) for r in self.ranks}) + def tensor_map( + self, tensor: LocalTensor, cb: Callable[[int, Tensor], Tensor | None] + ) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. 
+ """ + + with self.disable(): + results = {} + for r in self.ranks: + if r in tensor._local_tensors: + m = cb(r, tensor._local_tensors[r]) + if m is not None: + results[r] = m + return LocalTensor(results) + def _patch_device_mesh(self) -> None: assert self._old_get_coordinate is None self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] @@ -1049,3 +1066,120 @@ def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: """ lm = local_tensor_mode() return lm.disable() if lm is not None else contextlib.nullcontext() + + +import threading +from queue import Queue + + +_LOCAL_RUNNER_MODE: "LocalRunnerMode | None" = None + + +class LocalRunnerMode: + """ + A class for running multiple SPMD functions concurrently, however at any point + in time only one function can be running. The main use case for the local runner + mode is to enable SPMD functions to be able to use send and recv to communicate + with each other. Without local runner mode send and recv are not supported. + """ + + runner_context = threading.local() + + def __init__( + self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None] + ): + if isinstance(ranks, int): + ranks = frozenset(range(ranks)) + self._ranks = ranks + self._fn = fn + self._run_lock = threading.Lock() + self._run_id = -1 + self._run_cond = threading.Condition(self._run_lock) + + self._recv_objects: dict[int, dict[int, Queue]] = { + dst: {src: Queue() for src in ranks} for dst in ranks + } + self._runners = [ + threading.Thread(target=self._run, args=(i,), name="LocalRunnerMode") + for i in range(concurrency) + ] + + def __enter__(self) -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is None, "LocalRunnerMode is already running" + _LOCAL_RUNNER_MODE = self + + for r in self._runners: + r.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + for r in self._runners: + r.join() + global _LOCAL_RUNNER_MODE + _LOCAL_RUNNER_MODE = None + + def _run(self, id: int) -> None: + LocalRunnerMode.runner_context.id = id + # Only one thread can run at a time, hence must acquire the lock + try: + self._acquire_run_lock() + self._fn(id) + finally: + self._release_run_lock() + + def _acquire_run_lock(self) -> None: + self._run_lock.acquire() + self._run_id = LocalRunnerMode.runner_context.id + + def _release_run_lock(self) -> None: + self._run_id = -1 + self._run_lock.release() + + def _assert_holds_run_lock(self) -> None: + assert self._run_id == LocalRunnerMode.runner_context.id, ( + "Calling thread does not hold the run lock" + ) + + def _get_recv_object(self, src: int, dst: int) -> object | None: + peers = [src] if src != -1 else list(self._ranks) + recv_objects = self._recv_objects[dst] + + for p in peers: + if not recv_objects[p].empty(): + return recv_objects[p].get() + + return None + + def _signal_send(self, src: int, dst: int, obj: object) -> None: + assert obj is not None, "Cannot signal None" + self._assert_holds_run_lock() + # Only a single thread a time executes so it is safe to mutate + # read objects queue (executing thread is already holding the lock) + self._recv_objects[dst][src].put(obj) + # Signal directly condition variable since the calling thread is already + # holding the lock + self._run_cond.notify_all() + + def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: + self._assert_holds_run_lock() + # Wait for the object to be 
available + while True: + obj = self._get_recv_object(src, dst) + if obj is not None: + post(obj) + # Note that we are not releasing the lock here, since the thread + # will continue to run and therefore must hold the lock + return + self._run_cond.wait() + + @staticmethod + def current() -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" + return _LOCAL_RUNNER_MODE diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 30b99931f2514..c9256543e8977 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -2,12 +2,15 @@ import math import operator from collections.abc import Sequence +from datetime import timedelta +from typing import Callable import torch from torch._C import ScriptObject -from torch._C._distributed_c10d import FakeWork +from torch._C._distributed_c10d import FakeWork, PythonCallbackWork from torch.distributed._mesh_layout import _MeshLayout from torch.distributed.distributed_c10d import ( + _check_op, _get_default_group, _resolve_process_group, ProcessGroup, @@ -765,10 +768,19 @@ def _local_send( # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int dst, int tag) -> __torch__.torch.classes.c10d.Work"; - raise NotImplementedError( - "LocalTensor does not support MPMD operations like send. " - "Use SPMD collective operations instead." - ) + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + src = int(tensor.__src_rank__) + + LocalRunnerMode.current()._signal_send(src, dst, tensor._local_tensors[src]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so def _local_recv_( @@ -779,11 +791,26 @@ def _local_recv_( ) -> ScriptObject: # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int src, int tag) -> __torch__.torch.classes.c10d.Work"; + from . import LocalRunnerMode, LocalTensor - raise NotImplementedError( - "LocalTensor does not support MPMD operations like recv. " - "Use SPMD collective operations instead." - ) + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + dst = int(tensor.__src_rank__) + + def _recv_and_store(timeout: timedelta) -> bool: + def _wait_and_store(obj: object) -> None: + assert isinstance(obj, torch.Tensor), "Expected to receive a Tensor" + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + tensor._local_tensors[dst] = obj + + LocalRunnerMode.current()._wait_recv(src, dst, _wait_and_store) + return True + + work = PythonCallbackWork(_recv_and_store) + work_so = Work.boxed(work) + return work_so def _local_recv_any_source_( @@ -792,7 +819,60 @@ def _local_recv_any_source_( # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int tag) -> __torch__.torch.classes.c10d.Work"; - raise NotImplementedError( - "LocalTensor does not support MPMD operations like recv_any_source. " - "Use SPMD collective operations instead." + return _local_recv_(tensors, process_group_so, -1, tag) + + +def _attach_rank(tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Attaches rank as an attribute to given tensor so that the send or recv implementation + knows which rank initiates the operation (note under local tensor mode ). 
+ """ + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + + tensor.__src_rank__ = rank # type: ignore[attr-defined] + return tensor + + +def local_p2p_op( + dst: torch.SymInt, + tensor: torch.Tensor, + op: Callable[[torch.Tensor, int], Work | None], +) -> Work | None | list[Work | None]: + """ + Runs a point-to-point (P2P) operation for all combinations of source and destination ranks. + """ + _check_op(op) + + from . import LocalIntNode + + assert isinstance(dst.node, LocalIntNode), ( + "Expected 'dst' to be a LocalIntNode where the value is the destination rank and key is the source rank" ) + + w = [] + for s, d in dst.node._local_ints.items(): + tensor = _attach_rank(tensor, s) + w.append(op(tensor, d)) + return w + + +def wait_all(work: Work | None | list[Work | None]) -> None: + """ + Waits for all work objects in the input to complete. + + A single Work object, None, or a list of Work objects (possibly containing None). + If None, does nothing. If a single Work, waits for it to complete. If a list, waits + for each non-None Work in the list to complete. + """ + + if work is None: + return + if isinstance(work, Work): + work = [work] + for w in work: + if w is None: + continue + w.wait() diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 17140f40684dd..f4afca4bd1803 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -814,3 +814,7 @@ def map_local_tensor_for_rank(tensor, rank, func): @maybe_run_for_local_tensor def map_local_for_rank(rank, func): return func(rank) + + +def reduce_local_int(val, func): + return func(val.node._local_ints) From cd5d810c3aefa6bcc6a71e849e5eaa9db90f8d35 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 5 Nov 2025 02:22:29 +0000 Subject: [PATCH 044/651] Annotation should be deepcopied (#167017) The annotation should be deepcopied. Otherwise all nodes with the same `seq_nr` share the same underlying dict Pull Request resolved: https://github.com/pytorch/pytorch/pull/167017 Approved by: https://github.com/yiming0416 --- torch/_functorch/_aot_autograd/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 844f34bb576da..9fbb5e5fe9841 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -3,6 +3,7 @@ Contains various utils for AOTAutograd, including those for handling collections. """ +import copy import dataclasses import logging import operator @@ -459,7 +460,9 @@ def _copy_metadata_to_bw_nodes_in_subgraph( node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack") node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") # TODO: better to change to a specific field of custom? - node.meta["custom"] = fwd_node.meta.get("custom") + custom = fwd_node.meta.get("custom") + if custom is not None: + node.meta["custom"] = copy.deepcopy(custom) def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None: From 53b03f1a2b4c8fc0e20bdff4cfbb43aad01bb978 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 02:36:46 +0000 Subject: [PATCH 045/651] Revert "make narrow_tensor_symint DDE-free (#166379)" This reverts commit d7e2d0ad301b5d0db049bf5d2a2fc7ff9c89c58c. 
Reverted https://github.com/pytorch/pytorch/pull/166379 on behalf of https://github.com/malfet due to Need to revert previous PR in the stack ([comment](https://github.com/pytorch/pytorch/pull/166379#issuecomment-3488910172)) --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- test/functorch/test_aotdispatch.py | 2 +- test/test_dynamic_shapes.py | 13 ------------- test/test_proxy_tensor.py | 1 + 4 files changed, 4 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index b3fff5a4bb42f..6136a6aa8c520 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1764,8 +1764,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - c10::SymInt st = start.item().toSymInt(); - return at::narrow_symint(self, dim, std::move(st), std::move(length)); + int64_t st = start.item(); + return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); } std:: diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 6cae42d8929da..b0dd1ff8fa75d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8126,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - skip("narrow"), + xfail("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d3f9e415ff944..b63e0427c26c3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4452,19 +4452,6 @@ def test_narrow_unbacked_start_cpp_wrapper(self): """Test narrow with unbacked start with cpp_wrapper""" self.test_narrow_unbacked_start() - @torch._dynamo.config.patch(capture_scalar_outputs=True) - def test_narrow_with_tensor_start(self): - @torch.compile(backend="inductor", fullgraph=True) - def f(x, start, end): - return torch.narrow(x, 0, start, end) - - x = torch.tensor( - [False], device="cuda:0" if torch.cuda.is_available() else "cpu" - ) - start = torch.tensor(0) - res = f(x, start, 0) - self.assertEqual(res.shape, torch.Size([0])) - instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 0487995a2d1c5..b76895a0a91f3 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,6 +1987,7 @@ def f(t): } only_fake_tensor_failures = { + xfail('narrow'), xfail('tensor_split'), } From a743f9eeb57255f800d4c91ba29da6e0d9c4a229 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 02:39:55 +0000 Subject: [PATCH 046/651] Revert "Avoid DDE in narrow with unbacked start (#166361)" This reverts commit ed45c5f38df6aa419c67d139d932c2c94404223a. 
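For context, the reverted change special-cased negative unbacked starts because `narrow` and plain slicing disagree exactly when `start < 0` and `start + length == 0` (the `x.narrow(0, -2, 2) != x[-2:0]` case noted in the removed comment). A small eager-mode illustration, not part of the original message:

```
import torch

x = torch.arange(6)

# narrow wraps a negative start once and always returns `length` elements ...
print(torch.narrow(x, 0, -2, 2))  # tensor([4, 5])

# ... while the seemingly equivalent slice with end == start + length == 0 is empty.
print(x[-2:0])                    # tensor([], dtype=torch.int64)
```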
Reverted https://github.com/pytorch/pytorch/pull/166361 on behalf of https://github.com/malfet due to Looks like it broke test_torchfuzz subtests, see https://hud.pytorch.org/hud/pytorch/pytorch/01e6e35c7faf913c3a85c7a64d2939cfa768358a/1?per_page=50&name_filter=trunk&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/166361#issuecomment-3488916766)) --- aten/src/ATen/native/TensorShape.cpp | 38 +++--------------- c10/core/SymBool.cpp | 14 ------- c10/core/SymBool.h | 6 --- test/export/test_export.py | 31 +++++--------- test/test_dynamic_shapes.py | 51 ------------------------ test/test_torchfuzz_repros.py | 5 +-- torch/_inductor/codegen/wrapper.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 19 +-------- torch/utils/_sympy/printers.py | 36 ----------------- 9 files changed, 19 insertions(+), 184 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6136a6aa8c520..6df7761d822db 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,6 +1,5 @@ #include #include -#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1711,14 +1710,11 @@ Tensor narrow_symint( "], but got ", start, ")") - // Bounds check without converting start: - // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + - // length <= 0 - // - If start >= 0: need start + length <= cur_size - auto end = start + length; + if (start < 0) { + start = start + cur_size; + } TORCH_SYM_CHECK( - (start.sym_lt(0).sym_and((end).sym_le(0))) - .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), + start.sym_le(cur_size - length), "start (", start, ") + length (", @@ -1726,31 +1722,7 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - - if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) { - return at::slice_symint(self, dim, start, end, 1); - } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) { - // Avoid the complex symbolic expressions path for non-unbacked. - return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1); - } else { - // Cannot statically determine the condition due to unbacked. - // This is an interesting situation; when start is negative and - // start + length == 0, slice and narrow do different things. - // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to - // pass curr_size instead of 0. Otherwise, they would do the same thing. - // This says at runtime: if start < 0 and end == 0, then pass curr_size - // instead of 0. - - auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); - auto result = - at::slice_symint(self, dim, start, end + use_different * cur_size, 1); - - // Ensure slice allocated unbacked size is specialized to length. - SymInt new_size = result.sym_size(dim); - TORCH_SYM_CHECK(new_size.sym_eq(length), "") - - return result; - } + return at::slice_symint(self, dim, start, start + length, 1); } // This overload exists purely for XLA, because they wanted to pass in diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index 48c407b8b069c..d804eb9d27409 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,5 +1,4 @@ #include -#include #include namespace c10 { @@ -112,17 +111,4 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } -SymInt SymBool::toSymInt() const { - // If concrete bool, return concrete SymInt - if (auto ma = maybe_as_bool()) { - return SymInt(*ma ? 
1 : 0); - } - - // Symbolic case: use sym_ite to convert bool to int (0 or 1) - auto node = toSymNodeImpl(); - auto one_node = node->wrap_int(1); - auto zero_node = node->wrap_int(0); - return SymInt(node->sym_ite(one_node, zero_node)); -} - } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index a27a28a5bf8a3..d5d509e239b1d 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,8 +12,6 @@ namespace c10 { -class SymInt; - class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -82,10 +80,6 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } - // Convert SymBool to SymInt (0 or 1) - // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless - SymInt toSymInt() const; - bool is_heap_allocated() const { return ptr_; } diff --git a/test/export/test_export.py b/test/export/test_export.py index cdc18b1d4c564..3908f03b11e55 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,19 +6093,26 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[], + fixes=[ + # Could not guard on data-dependent expression u0 < 0 + "torch._check(i >= 0)", + ], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) + # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[], + fixes=[ + # Could not guard on data-dependent expression u0 < 0 + "torch._check(i >= 0)", + ], ) class cf_tensorsplit(torch.nn.Module): @@ -6159,12 +6166,7 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - if y.item() < 0: - return ( - torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() - ) - else: - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6194,18 +6196,7 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - if box.content < 0: - return ( - torch.stack(xs, 0) - .narrow(0, box.content + xs.size(), 1) - .squeeze() - ) - else: - return ( - torch.stack(xs, 0) - .narrow(0, box.content + xs.size(), 1) - .squeeze() - ) + return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() with self.assertRaisesRegex( error_type, diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b63e0427c26c3..fb1d22805d50a 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,57 +4401,6 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_narrow_unbacked_start(self): - def func(x, start, length): - # unbacked start - u0 = start.item() - return torch.narrow(x, 0, u0, length) - - compiled_func = torch.compile(func, fullgraph=True, backend="inductor") - - x = torch.tensor([1, 2, 3, 4, 5, 6]) - - # Test cases: (start, length) - test_cases = [ - # Negative starts - (-2, 2), # Start from second-to-last element - (-1, 1), # Start from last element - (-3, 3), # Start from third-to-last element - (-6, 2), # Start from beginning (negative) - (-4, 1), # Start from fourth-to-last element - # Positive starts - (0, 2), # Start from beginning - (1, 
3), # Start from second element - (2, 2), # Start from third element - (4, 2), # Start near end - # Edge cases - (0, 6), # Full tensor - (0, 1), # Single element from start - (5, 1), # Single element from end - ] - - for start_val, length in test_cases: - with self.subTest(start=start_val, length=length): - start = torch.tensor([start_val]) - - # Test with compiled function - result_compiled = compiled_func(x, start, length) - - # Test with eager function (expected behavior) - result_eager = func(x, start, length) - - # Compare results - self.assertEqual(result_compiled, result_eager) - - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - @torch._inductor.config.patch("cpp_wrapper", True) - def test_narrow_unbacked_start_cpp_wrapper(self): - """Test narrow with unbacked start with cpp_wrapper""" - self.test_narrow_unbacked_start() - instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 84a00430420cf..3b864aae4f477 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -16,10 +16,6 @@ from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -# Skip all tests in this file if CUDA is not available -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") - - class TestFuzzerCompileIssues(TestCase): """Test cases for fuzzer-discovered eager/compile divergence issues.""" @@ -261,6 +257,7 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! ✅") + @pytest.mark.xfail(reason="Issue #163971") def test_fuzzer_issue_163971(self): torch.manual_seed(0) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 947166cf216cd..e629d9c7bdebd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,8 +2063,7 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - x_cond = self.codegen_sizevar(x) - return f"{pos} if {x_cond} >= 0 else {neg}" + return f"{pos} if {x} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 693d25aea6130..aeccdfbe000db 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,7 +547,6 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) - # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. # @@ -603,23 +602,7 @@ def rebind_unbacked( if u1.node.hint is not None: continue - # unbacked symbols bindings might be replaced to other backed or - # unbacked replacements. - # - # Example: - # u = x.item() - # torch._check(u == 5) - # - # The safest approach is to retrieve raw_u1 from u1.node._expr - # and perform the rebinding on the original unbacked symbol, - # even if it’s no longer directly referenced. - # - # In other words, we should always rebind the original symbol - # before any replacements are applied. - # u0 -> u0 == s1 - raw_u1 = u1.node._expr - - # TODO Do we still need this logic below? 
+ raw_u1 = u1.node.expr # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 915d0e5461f1e..526443577b3f8 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,24 +306,6 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" - def _print_Piecewise(self, expr: sympy.Expr) -> str: - # Convert Piecewise(expr_cond_pairs) to nested ternary expressions - # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) - # becomes: e1 if c1 else (e2 if c2 else (... else eN)) - result: Optional[str] = None - for expr_i, cond_i in reversed(expr.args): - expr_str = self._print(expr_i) - if cond_i == True: # noqa: E712 - # This is the default case - result = expr_str - else: - cond_str = self._print(cond_i) - if result is None: - result = expr_str - else: - result = f"({expr_str} if {cond_str} else {result})" - return result if result else "0" - class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -345,24 +327,6 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" - def _print_Piecewise(self, expr: sympy.Expr) -> str: - # Convert Piecewise(expr_cond_pairs) to nested ternary operators - # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) - # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) - result: Optional[str] = None - for expr_i, cond_i in reversed(expr.args): - expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) - if cond_i == True: # noqa: E712 - # This is the default case - result = expr_str - else: - cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) - if result is None: - result = expr_str - else: - result = f"{cond_str} ? {expr_str} : {result}" - return f"({result})" if result else "0" - def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) From 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 03:03:41 +0000 Subject: [PATCH 047/651] [12/N] Apply ruff UP035 rule (#166929) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. 
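The typical rewrite UP035 drives is importing ABCs from `collections.abc` rather than through their deprecated `typing` aliases. A minimal illustrative sketch, not taken from any specific file in this diff:

```
# Before (flagged by ruff UP035): these ABCs are deprecated aliases in typing.
#   from typing import Any, Callable, Sequence
# After: import the ABCs from collections.abc and keep only the rest in typing.
from collections.abc import Callable, Sequence
from typing import Any


def apply_all(fn: Callable[[int], int], xs: Sequence[int]) -> list[Any]:
    return [fn(x) for x in xs]
```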
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929 Approved by: https://github.com/Lucaskabela --- test/distributed/tensor/test_attention.py | 3 ++- test/higher_order_ops/test_local_map.py | 3 ++- test/inductor/test_caching.py | 3 ++- test/inductor/test_fx_fusion.py | 3 ++- test/inductor/test_native_matmul.py | 2 +- test/quantization/fx/test_quantize_fx.py | 3 ++- test/test_matmul_cuda.py | 2 +- torch/_dynamo/eval_frame.py | 3 ++- torch/_dynamo/graph_bytecode_inputs.py | 3 ++- torch/_dynamo/variables/distributed.py | 3 ++- torch/_dynamo/variables/iter.py | 4 ++-- torch/_dynamo/variables/optimizer.py | 3 ++- torch/_dynamo/variables/script_object.py | 4 ++-- torch/_dynamo/variables/sdpa.py | 3 ++- torch/_dynamo/variables/streams.py | 3 ++- torch/_dynamo/variables/torch_function.py | 4 ++-- torch/_functorch/_aot_autograd/aot_autograd_result.py | 3 ++- torch/_inductor/compile_worker/timer.py | 3 ++- torch/_inductor/fx_passes/bucketing.py | 3 ++- torch/_inductor/fx_passes/ddp_fusion.py | 4 ++-- torch/_inductor/fx_passes/fsdp.py | 2 +- torch/_inductor/fx_passes/memory_estimator.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 6 +++++- torch/_inductor/fx_passes/overlap_scheduling.py | 4 ++-- torch/_inductor/fx_passes/pad_mm.py | 4 ++-- torch/_inductor/fx_passes/post_grad.py | 3 ++- torch/_inductor/fx_passes/reinplace.py | 4 ++-- torch/_inductor/fx_passes/split_cat.py | 5 ++--- torch/_inductor/kernel/custom_op.py | 3 ++- torch/_inductor/kernel/flex/flex_flash_attention.py | 3 ++- torch/_inductor/runtime/benchmarking.py | 4 ++-- torch/_inductor/runtime/caching/interfaces.py | 6 ++++-- torch/_inductor/runtime/caching/locks.py | 5 +++-- torch/distributed/elastic/multiprocessing/tail_log.py | 3 ++- torch/utils/_cxx_pytree.py | 4 ++-- torch/utils/_debug_mode.py | 3 ++- torch/utils/_pytree.py | 3 ++- 37 files changed, 76 insertions(+), 50 deletions(-) diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index eaf3a4042060d..6c3485f9d7025 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,7 +3,8 @@ import itertools import random import unittest -from typing import Any, Callable, ClassVar, Optional +from collections.abc import Callable +from typing import Any, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 9d2870d3b5fdd..fbb21633260e7 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,8 +4,9 @@ import functools import unittest +from collections.abc import Callable from contextlib import contextmanager, ExitStack -from typing import Any, Callable, Optional +from typing import Any, Optional import torch import torch._dynamo diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index bcb66beea700c..aa4c3a1f229f1 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, Generator, Sequence, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,6 +37,7 @@ if TYPE_CHECKING: + from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index ebe98373e622a..63342502d3cd9 
100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index 1870a0e373be0..c37f844e41eae 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from typing import Callable +from collections.abc import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index cd922d94c60c3..faba2f5edc6a7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,7 +204,8 @@ import operator import unittest import io -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 5e54a851812e0..1ba947befd9e7 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from typing import Callable +from collections.abc import Callable import torch diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e23e049e3bbb1..222647eeae9ab 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,10 +39,11 @@ import unittest import warnings import weakref +from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 979950cf3bd1b..16583b89201ec 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,5 +1,6 @@ import weakref -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from torch._dynamo.source import Source diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index eb39dd8fa3e07..187055c26cd00 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -20,7 +20,8 @@ import functools import inspect -from typing import Any, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 5970ba0e1dda7..be765cbbc8bf9 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable -from typing import Any, Sequence, TYPE_CHECKING, Union +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union from .. 
import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 289cebbe8129b..c09cc2163a5f4 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,7 +22,8 @@ import logging import weakref -from typing import Any, Iterable, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 85977104977fb..644c269a23a34 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable -from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from collections.abc import Callable, Iterable +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 75928842cf297..629bf094dc951 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from inspect import getattr_static -from typing import Any, Sequence, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index c353181eb8029..fb5dd775bd636 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,5 +1,6 @@ import collections -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index fa8412146a427..4d0f0b4fae8ab 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Sequence +from collections.abc import Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index ce01e37f03243..7e608933b34c3 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,9 +22,10 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index 7cfeb4217e26b..7c495403b3a55 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,6 +1,7 @@ +from collections.abc import 
Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Callable, Optional, Union +from typing import Optional, Union class Timer: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index ab831c96c94ba..29f070564349c 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,7 +2,8 @@ import logging import operator from collections import defaultdict -from typing import Any, Callable, Literal, TypeAlias +from collections.abc import Callable +from typing import Any, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 8a4de1a604869..44314b912786f 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass from functools import partial -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 6b0c2ad2c94a7..1e71c350ed7b6 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index c6b7c51b948e5..e887d4bf62c8e 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 70b3a3c355dde..214d3bf02f7f4 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, Callable +from typing import Any, TYPE_CHECKING import torch from torch._dynamo.utils import counters @@ -35,6 +35,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable + + if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index a47aa960e58c5..f383ab63dc261 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 30768fda9bb72..b511403d4874c 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Sequence -from typing import Any, 
Callable +from collections.abc import Callable, Sequence +from typing import Any import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7d995adec04ef..91b4e10bf7238 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,7 +5,8 @@ import logging import operator from collections import Counter, defaultdict -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 52222f3da8344..e42e8a1139770 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 92e1e6f375f44..0bad4fa7cc635 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,9 +4,8 @@ import operator import os from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Callable -from typing_extensions import TypeAlias +from collections.abc import Callable, Sequence +from typing import Any, TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 303110a561b5e..d35309c01d07c 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,7 +2,8 @@ import functools import logging -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch from torch._inductor.codegen.subgraph import SubgraphTemplate diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index c100df84d5a73..0d3721aa730a4 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,8 +3,9 @@ import functools import importlib +from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index d592a8c8c00f9..d9d92e363879d 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Optional, Union -from typing_extensions import Concatenate, ParamSpec, Self, TypeVar +from typing import Any, Concatenate, Optional, Union +from typing_extensions import ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 0758e11134018..03d2957493679 100644 --- 
a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import override, TypeAlias +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override from filelock import FileLock @@ -21,6 +21,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index e7e1f1adc3622..8e8cd011e2d44 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import Generator, TYPE_CHECKING -from typing_extensions import Protocol, TypeAlias +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol from filelock import FileLock, Timeout @@ -21,6 +21,7 @@ if TYPE_CHECKING: + from collections.abc import Generator from threading import Lock diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 7ad35115cd34a..034740810dcdd 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,9 +10,10 @@ import logging import os import time +from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union +from typing import Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 603625ed97c12..897279bd39b1e 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, Self, TypeAlias, TypeIs +from typing import Any, Optional, overload, TypeAlias, TypeVar, Union +from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5e24ce086e1aa..5a6ee246abf7e 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -3,7 +3,8 @@ import functools import traceback import weakref -from typing import Any, Callable, Optional, TYPE_CHECKING +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 56704bb3f8024..147340f58d66e 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,10 +36,11 @@ Optional, overload, Protocol, + TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeAlias +from typing_extensions import deprecated, NamedTuple, Self from torch.torch_version import TorchVersion as _TorchVersion From 56fc99915b7e5c653c30460052644204caaddbcf Mon Sep 17 00:00:00 2001 From: Michael Klamkin Date: Wed, 5 Nov 2025 03:05:04 +0000 Subject: [PATCH 048/651] Fix typos in complex numbers docs (#166671) 
This PR fixes two small typos in the complex numbers docs: 1. "numbercial" -> "numerical" 2. "easily to switch" -> "easily switch to" Pull Request resolved: https://github.com/pytorch/pytorch/pull/166671 Approved by: https://github.com/jcaip, https://github.com/Arpitha781, https://github.com/mlazos, https://github.com/cyyever --- docs/source/complex_numbers.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/complex_numbers.md b/docs/source/complex_numbers.md index 610f9a06615a1..095401879f09b 100644 --- a/docs/source/complex_numbers.md +++ b/docs/source/complex_numbers.md @@ -45,7 +45,7 @@ supported for complex tensors. ## Transition from the old representation Users who currently worked around the lack of complex tensors with real tensors of shape {math}`(..., 2)` -can easily to switch using the complex tensors in their code using {func}`torch.view_as_complex` +can easily switch to using the complex tensors in their code using {func}`torch.view_as_complex` and {func}`torch.view_as_real`. Note that these functions don’t perform any copy and return a view of the input tensor. @@ -140,7 +140,7 @@ through the same optimizer on the {func}`torch.view_as_real` equivalent of the c `real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers -and capturable vs default optimizers. For more details, see [numbercial accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). +and capturable vs default optimizers. For more details, see [numerical accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their `p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the From 08ef852a4b0f8cab0d35c30be33dfde812bfc6d8 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 5 Nov 2025 03:09:52 +0000 Subject: [PATCH 049/651] [unified v2][apple] Clean up `APPLETVOS` from caffe2 (#166953) Summary: This is not used, so delete it Test Plan: ``` $ buck targets xplat/... 
> /dev/null ``` Reviewed By: dtolnay Differential Revision: D86125712 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166953 Approved by: https://github.com/seemethere --- .../quantized/cpu/qnnpack/buckbuild.bzl | 20 +-- buckbuild.bzl | 4 +- third_party/xnnpack.buck.bzl | 114 +++++++++--------- 3 files changed, 69 insertions(+), 69 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl index 180442b4b09a4..fecce634ec08c 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl @@ -1,7 +1,7 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX") # Shared by internal and OSS BUCK def define_qnnpack(third_party, labels = []): @@ -21,7 +21,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", @@ -82,7 +82,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -129,7 +129,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -184,7 +184,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -236,7 +236,7 @@ def define_qnnpack(third_party, labels = []): ], ), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", ], @@ -291,7 +291,7 @@ def define_qnnpack(third_party, labels = []): ("src", "qnnpack/*.h"), ("include", "*.h"), ]), - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", @@ -398,7 +398,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -465,7 +465,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", "-Wno-unused-command-line-argument", @@ -525,7 +525,7 @@ def define_qnnpack(third_party, labels = []): ("src", "qnnpack/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", diff --git a/buckbuild.bzl b/buckbuild.bzl index 4c1affd10e1bc..9f18ad4849dde 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -8,7 +8,7 @@ load("//tools/build_defs:fb_xplat_genrule.bzl", 
"fb_xplat_genrule") load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX") +load("//tools/build_defs:platform_defs.bzl", "IOS", "MACOSX") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build") load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build", is_profile_build_ios = "is_profile_build") @@ -1090,7 +1090,7 @@ def define_buck_targets( srcs = [ "caffe2/core/common.cc", ], - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = get_pt_compiler_flags(), labels = labels, # @lint-ignore BUCKLINT link_whole diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index b353d5d0d5982..217cc8db68864 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,7 +1,7 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX", "WINDOWS") load( "@fbsource//xplat/caffe2/third_party:xnnpack_buck_shim.bzl", "LOGGING_SRCS", @@ -55,7 +55,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_headers = { "xnnpack.h": "XNNPACK/include/xnnpack.h", }, - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -70,7 +70,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = SUBGRAPH_SRCS + ["XNNPACK/src/datatype.c"], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -97,7 +97,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = TABLE_SRCS, headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -121,7 +121,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = prod_srcs_for_arch_wrapper("scalar"), headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-fno-fast-math", @@ -147,7 +147,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -179,7 +179,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -211,7 +211,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() 
else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -243,7 +243,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse2_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -275,7 +275,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -307,7 +307,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_ssse3_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -339,7 +339,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -371,7 +371,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse41_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -403,7 +403,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -443,7 +443,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx", @@ -476,7 +476,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -531,7 +531,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnnigfni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -568,7 +568,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -625,7 +625,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -660,7 +660,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = prod_srcs_for_arch_wrapper("avxvnni") if 
is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavxvnni", @@ -697,7 +697,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avxvnni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -729,7 +729,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -770,7 +770,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_f16c_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mf16c", @@ -804,7 +804,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -853,7 +853,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fma3_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mfma", @@ -894,7 +894,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -948,7 +948,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx2_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx2", @@ -994,7 +994,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1039,7 +1039,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1108,7 +1108,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx512f", @@ -1141,7 +1141,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1206,7 +1206,7 @@ def define_xnnpack(third_party, labels = [], 
XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512skx_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx512f", @@ -1259,7 +1259,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-fno-fast-math", @@ -1301,7 +1301,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1350,7 +1350,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -1378,7 +1378,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1430,7 +1430,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -1460,7 +1460,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default @@ -1532,7 +1532,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1582,7 +1582,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1645,7 +1645,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1690,7 +1690,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1729,7 +1729,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, 
MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1774,7 +1774,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1815,7 +1815,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1860,7 +1860,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1900,7 +1900,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1959,7 +1959,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2004,7 +2004,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ("XNNPACK/src", "**/*.S"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2053,7 +2053,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ("XNNPACK/src", "**/*.S"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2088,7 +2088,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "arm64_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2114,7 +2114,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "x86_and_x86_64_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2138,7 +2138,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "x86_and_x86_64_lib_ovr_win32", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2165,7 +2165,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "arm_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2193,7 +2193,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "armv7_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ 
-2209,7 +2209,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "prod_ukernels", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2234,7 +2234,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "XNNPACK", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, deps = [ ":tables", From 066c5c57a97ca1876e58040caa7a23b4d3d00065 Mon Sep 17 00:00:00 2001 From: "Sv. Lockal" Date: Wed, 5 Nov 2025 04:13:57 +0000 Subject: [PATCH 050/651] Fix typo in gloo_hip library name (#166502) The typo was never noticed; conditions to enable it require system gloo: `-DUSE_SYSTEM_GLOO=ON -DUSE_GLOO=ON -DUSE_DISTRIBUTED=ON -DUSE_ROCM=ON`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166502 Approved by: https://github.com/jerryzh168, https://github.com/cyyever --- cmake/Modules/FindGloo.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Modules/FindGloo.cmake b/cmake/Modules/FindGloo.cmake index 944cd4d8d2573..0bdfe275d9c06 100644 --- a/cmake/Modules/FindGloo.cmake +++ b/cmake/Modules/FindGloo.cmake @@ -26,7 +26,7 @@ find_library(Gloo_CUDA_LIBRARY # if Gloo + HIP is desired, Gloo_HIP_LIBRARY # needs to be linked to desired target find_library(Gloo_HIP_LIBRARY - NAMES gloo_hiop + NAMES gloo_hip DOC "Gloo's HIP support/code" ) From 14956eaef4a14901a95a6d0779d99db11fd7406b Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 5 Nov 2025 04:18:04 +0000 Subject: [PATCH 051/651] [ROCm][CI] revert ROCm magma commit hash to last known good (#167044) PR https://github.com/pytorch/pytorch/pull/166693 updated the magma commit hash but this has been linked to ROCm 7.1 CI failures. Go back to last known working magma version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167044 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/magma-rocm/build_magma.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/magma-rocm/build_magma.sh b/.ci/magma-rocm/build_magma.sh index 7d95fed873dc0..c7c7780227ea5 100755 --- a/.ci/magma-rocm/build_magma.sh +++ b/.ci/magma-rocm/build_magma.sh @@ -6,8 +6,8 @@ set -eou pipefail # The script expects DESIRED_CUDA and PACKAGE_NAME to be set ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -# post merge of https://github.com/icl-utk-edu/magma/pull/65 -MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f +# https://github.com/icl-utk-edu/magma/pull/65 +MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec # Folders for the build PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata @@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE # Fetch magma sources and verify checksum pushd ${PACKAGE_DIR} -git clone https://github.com/icl-utk-edu/magma +git clone https://github.com/jeffdaily/magma pushd magma git checkout ${MAGMA_VERSION} popd From 9ffc480c5a928eaccb4ac0e1755a1c596674d884 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 4 Nov 2025 06:46:06 -0800 Subject: [PATCH 052/651] Add min/max support for barebones uint types (#166813) Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/166813 Approved by: https://github.com/Skylion007 --- aten/src/ATen/cuda/NumericLimits.cuh | 31 +++++++++++++++++++ .../ATen/native/cpu/ReduceAllOpsKernel.cpp | 13 ++++---- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 28 +++++++++-------- .../ATen/native/cpu/TensorCompareKernel.cpp | 13 ++++---- .../ATen/native/cuda/ReduceAMinMaxKernel.cu | 13 ++++---- .../ATen/native/cuda/ReduceMaxValuesKernel.cu | 17 +++++----- .../ATen/native/cuda/ReduceMinValuesKernel.cu | 13 ++++---- .../_internal/common_methods_invocations.py | 14 ++++----- 8 files changed, 90 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/cuda/NumericLimits.cuh b/aten/src/ATen/cuda/NumericLimits.cuh index 7081e94837caa..ebbc004382380 100644 --- a/aten/src/ATen/cuda/NumericLimits.cuh +++ b/aten/src/ATen/cuda/NumericLimits.cuh @@ -55,6 +55,14 @@ struct numeric_limits { static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } }; +template <> +struct numeric_limits { + static inline __host__ __device__ uint16_t lowest() { return 0; } + static inline __host__ __device__ uint16_t max() { return UINT16_MAX; } + static inline __host__ __device__ uint16_t lower_bound() { return 0; } + static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; } +}; + template <> struct numeric_limits { static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } @@ -63,6 +71,14 @@ struct numeric_limits { static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } }; +template <> +struct numeric_limits { + static inline __host__ __device__ uint32_t lowest() { return 0; } + static inline __host__ __device__ uint32_t max() { return UINT32_MAX; } + static inline __host__ __device__ uint32_t lower_bound() { return 0; } + static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; } +}; + template <> struct numeric_limits { static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } @@ -71,6 +87,21 @@ struct numeric_limits { static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } }; +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ uint64_t lowest() { return 0; } + static inline __host__ __device__ uint64_t max() { return _UI64_MAX; } + static inline __host__ __device__ uint64_t lower_bound() { return 0; } + static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; } +#else + static inline __host__ __device__ uint64_t lowest() { return 0; } + static inline __host__ __device__ uint64_t max() { return UINT64_MAX; } + static inline __host__ __device__ uint64_t lower_bound() { return 0; } + static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; } +#endif +}; + template <> struct numeric_limits { #ifdef _MSC_VER diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp index c7eaa802af125..c5dbf05039eb1 100644 --- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -78,12 +79,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, upper_bound(), [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] { + AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] { using 
Vec = Vectorized>; reduce_all_impl_vec(result, input, upper_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return minimum(a, b); }); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); } } @@ -103,12 +104,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, lower_bound(), [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] { + AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, lower_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); }); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); } } @@ -199,7 +200,7 @@ void aminmax_allreduce_kernel( } ); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] { + AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] { using Vec = Vectorized>; using scalar_t_pair = std::pair; reduce_all_impl_vec_two_outputs( @@ -214,7 +215,7 @@ void aminmax_allreduce_kernel( [=](Vec a, Vec b) -> Vec { return minimum(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf); } } diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 2e62936501948..3bad49a32d98c 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -347,34 +348,35 @@ struct MinValuesOps: public at::native::MinOps { }; void min_values_kernel_impl(TensorIterator& iter) { - if (iter.dtype() == kLong) { - // This case is special because of Vectorized does not - // handle upper_bound(). - // See: https://github.com/pytorch/pytorch/issues/43254 - using scalar_t = int64_t; - binary_kernel_reduce( - iter, - MinValuesOps{}, - std::pair(upper_bound(), -1)); + // This case is special because of Vectorized does not + // handle upper_bound(). 
+ // See: https://github.com/pytorch/pytorch/issues/43254 + if (iter.dtype() == kLong || iter.dtype() == kUInt64) { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { + binary_kernel_reduce( + iter, + MinValuesOps{}, + std::pair(upper_bound(), -1)); + }), kLong, kUInt64); return; } - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, static_cast(upper_bound())); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_values_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] { + AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); }, [](Vectorized a, Vectorized b) { return maximum(a, b); }, lower_bound()); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void argmax_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index c479e1610cbeb..22c85735ad6ab 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -106,7 +107,7 @@ void min_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -128,7 +129,7 @@ void min_kernel_impl( *indice_data = index; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); } void max_kernel_impl( @@ -139,7 +140,7 @@ void max_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -161,7 +162,7 @@ void max_kernel_impl( *indice_data = index; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); } void aminmax_kernel( @@ -186,7 +187,7 @@ void aminmax_kernel( return; } - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] { compare_base_kernel(min_result, max_result, self, wrap_dim, keepdim, [&] ( scalar_t* min_result_data, scalar_t* max_result_data, const scalar_t* self_data, auto self_dim_stride) { @@ -209,7 +210,7 @@ void aminmax_kernel( *max_result_data = max_number; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), 
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half); } void where_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu index cdd5daab2d983..0b7823863047a 100644 --- a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include +#include #include #include #include @@ -28,22 +29,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { } void aminmax_allreduce_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { + AT_DISPATCH_V2( + iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] { _min_max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void aminmax_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { + AT_DISPATCH_V2( + iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MinMaxOps{}, thrust::pair( at::numeric_limits::upper_bound(), at::numeric_limits::lower_bound())); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu index e8d1e88ebb3ec..bcbc4c0359943 100644 --- a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include +#include #include #include #include @@ -33,27 +34,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) { } void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { + AT_DISPATCH_V2( + iter.dtype(), "max_values_cuda", AT_WRAP([&]() { max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { + AT_DISPATCH_V2( + iter.input_dtype(), "max_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MaxOps{}, thrust::pair( at::numeric_limits::lower_bound(), 0)); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { + AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] { max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda) diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu index e01ca6c88ebc8..0006a24dbc466 100644 --- a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -33,24 +34,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) { } void min_values_kernel_cuda(TensorIterator& iter) { - 
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { + AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() { min_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void min_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { + AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MinOps{}, thrust::pair(at::numeric_limits::upper_bound(), 0)); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void min_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { + AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] { min_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 92f212a3c650e..0413c9bf6b6e0 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14311,7 +14311,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('max', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14320,7 +14320,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True), OpInfo('max', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), supports_out=True, supports_forward_ad=True, @@ -14465,7 +14465,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False,), OpInfo('min', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14474,7 +14474,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('min', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14784,7 +14784,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_fwgrad_bwgrad=True), OpInfo('aminmax', ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, 
axis=dim, keepdims=keepdim)), - dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), decorators=(onlyNativeDeviceTypes,), supports_autograd=False, @@ -21126,7 +21126,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), ref=reference_reduction_numpy(np.amax), skips=( # FIXME: reduces all dimensions when dim=[] @@ -21141,7 +21141,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), ref=reference_reduction_numpy(np.amin), skips=( # FIXME: reduces all dimensions when dim=[] From c00696144dae1f02e04ce345480b55e46c7d32a8 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 4 Nov 2025 16:09:28 -0800 Subject: [PATCH 053/651] Add model code stack trace to torch.profile (#166677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```python python test/test_fx.py -k profiler ``` Insert `torch._C._profiler._RecordFunctionFast` markers into the fx graph codegen. We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add stack traces to the dumped trace. `map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hashed fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry. One can call `fx_g.enrich_profiler_metadata()` to add debugging info, or `fx_g.enrich_profiler_metadata(enable=False)` to remove it. `aot_eager` calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or `_dynamo.config.enrich_profiler_metadata=True`. [Screenshot omitted] Example of the generated code:
``` def forward(self, args_list): args_iter = iter(args_list) arg0_1 = next(args_iter) arg1_1 = next(args_iter) args_list.clear() _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__() repeated_subgraph0 = self.repeated_subgraph0 _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__() invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None _rf_invoke_subgraph.__exit__(None, None, None) _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__() getitem = invoke_subgraph[0]; invoke_subgraph = None _rf_getitem.__exit__(None, None, None) return (getitem,) _rf.__exit__(None, None, None) def forward(self, arg0_1, arg1_1): _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__() _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__() mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None _rf_mul.__exit__(None, None, None) _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__() sin = torch.ops.aten.sin.default(mul); mul = None _rf_sin.__exit__(None, None, None) _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__() add = torch.ops.aten.add.Tensor(sin, 5); sin = None _rf_add.__exit__(None, None, None) return (add,) _rf.__exit__(None, None, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677 Approved by: https://github.com/ezyang --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ++++++++++++++++++ torch/autograd/profiler_util.py | 40 ++++ torch/fx/graph.py | 23 +++ torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +++++++++++++++- 6 files changed, 425 insertions(+), 5 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977ee..12f6ba2228db8 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') 
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index d6f33d426aee7..c16c42805b921 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,6 +75,12 @@ ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +207,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. + + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4187,6 +4223,150 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. 
+ """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. + """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. 
+ """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e61..a61aee321fcff 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. + """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines: list[str] = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 697b2f4084ca5..fd6835d2b301b 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -798,6 +799,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. 
It depends # on delete_unused_values to append one @@ -807,8 +812,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1760,6 +1779,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1827,6 +1847,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1839,6 +1860,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1849,6 +1871,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 297f76732584f..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,14 +861,18 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -885,7 +889,6 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" - filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -905,6 +908,13 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff 
--git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c9..47df87ce1678d 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + +def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = 
timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier == ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name From 431dfe8692f3f927c19c739884054d7f1d42a33d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 4 Nov 2025 12:18:14 +0800 Subject: [PATCH 054/651] [dynamo] extend `collections.defaultdict` support with `*args`, `**kwargs` and custom `default_factory` (#166793) Fixes #166238 Extend `collections.defaultdict` to accept `*args` and `**kwargs` in the constructor. And also support custom `default_factory`, such as `dd.default_factory` (a `GetAttrVariable`). 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166793 Approved by: https://github.com/guilhermeleobas --- test/dynamo/test_dicts.py | 69 +++++++++++++++++-- ...tDefaultDict.test_keyerror_without_factory | 0 ...13-test_dict-DictTest.test_dict_copy_order | 0 ...redDictSubclassTests.test_sorted_iterators | 0 ...thonOrderedDictTests.test_sorted_iterators | 0 ...313-test_set-TestGraphs.test_cuboctahedron | 0 torch/_dynamo/polyfills/__init__.py | 62 ++++++++++------- torch/_dynamo/variables/builtin.py | 8 ++- torch/_dynamo/variables/dicts.py | 17 ++++- torch/_dynamo/variables/user_defined.py | 21 +++--- 10 files changed, 135 insertions(+), 42 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory delete mode 100644 test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order delete mode 100644 test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators delete mode 100644 test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators delete mode 100644 test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 966acd1d81394..4a4d2ff87718f 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -36,6 +36,15 @@ class DummyUserDict(UserDict): pass +class FakeMapping: + def __init__(self, value: Any) -> None: + self._value = value + self.keys = lambda: ["a", "b", "c"] # not required to be a method + + def __getitem__(self, key: str) -> Any: + return self._value + + class DictTests(torch._dynamo.test_case.TestCase): def test_dict_subclass_instantiation(self): def fn(x): @@ -666,6 +675,18 @@ def fn(): for k1, m2 in zip(modules, module_dict.children()): self.assertTrue(modules[k1] is m2) + # FIXME: see comment in torch/_dynamo/polyfills/__init__.py:mutable_mapping_update + @unittest.expectedFailure + def test_dict_construct_from_mapping_like(self): + def fn(x): + fm = FakeMapping(x) + d = dict(fm, x=x) + return d + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + def test_dict_subclass_initialization_in_graph(self): for super_class in ( OrderedDict, @@ -1087,12 +1108,52 @@ def f(x): self.assertEqual(ref, res) - @unittest.expectedFailure + def test_newly_constructed_default_dict_no_default_factory(self): + def f1(x): + d = defaultdict() + try: + d[1] += 42 + except KeyError: + d[1] = 1 + return x + 1, d + + x = torch.ones(2) + ref = f1(x) + res = torch.compile(f1, backend="eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + + def f2(x): + d = defaultdict(None) + try: + d[1] += 42 + except KeyError: + d[1] = 1 + return x + 1, d + + ref = f2(x) + res = torch.compile(f2, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + + def f3(x): + d = defaultdict(None, {1: 10}) + d[1] += 42 + try: + d[2] += 24 + except KeyError: + d[2] = 1 + return x + 1, d + + ref = f3(x) + res = torch.compile(f3, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + def test_newly_constructed_default_dict_with_dict(self): def f(x): - d = defaultdict(dict, {2: {"a": 1}}) - d[0] = {"b": 2} - return x + 1, d + d = dict([("a", 1), ("b", 2)], c=3) # noqa: C406 + dd = defaultdict(list, d, d=4, e=5) + dd["x"].append(42) + return x + 1, d, dd x = torch.ones(2) ref = f(x) diff --git 
a/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory b/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron b/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index a8dcf3e00c166..59f6f76317e6d 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -10,7 +10,7 @@ import types from collections import OrderedDict -from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from itertools import repeat as _repeat from operator import eq, ne from typing import Any, TYPE_CHECKING @@ -276,7 +276,7 @@ def getattr_and_trace(*args, **kwargs): return fn(*args[2:], **kwargs) -def mapping_get(obj, key, value=None): +def mapping_get(obj, key, value=None, /): try: return obj.__getitem__(key) except KeyError: @@ -293,31 +293,45 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): return obj -# Used with something like dict(obj) -def construct_dict(cls, /, *args, **kwargs): - dst = cls.__new__(cls) - - if args: - src = args[0] - - if not isinstance(src, Iterable): - raise TypeError(f"{type(src)} object is not iterable") - - # Ensure that the overridden __iter__ method is invoked - if isinstance(src, (dict, MutableMapping, types.MappingProxyType)): - for key in src: - # This will inline the __getitem__ of the src object - dst[key] = src[key] - else: - # likely a sequence like tuple of pairs - for key, value in src: - dst[key] = value +def mutable_mapping_update(self, data=(), /, **kwargs): + if isinstance(data, Mapping): + # Merge standard mapping with PyMapping_Items + for key, value in data.items(): + self[key] = value + # FIXME: Enabling the `elif`-branch below needs too many `VariableClass.call_obj_hasattr` changes. + # >>> class Foo: + # ... def __init__(self): + # ... self.keys = lambda: ['a', 'b', 'c'] # not required to be a method + # ... + # ... def __getitem__(self, key): + # ... return 0 + # ... + # >>> dict(Foo()) + # {'a': 0, 'b': 0, 'c': 0} + # + # > This is a rare case, so we comment it out for now. 
+ # + # elif hasattr(data, "keys"): + # # Merge mapping-like object with PyMapping_Keys + PyObject_GetItem + # for key in data.keys(): + # self[key] = data[key] + else: + if not isinstance(data, Iterable): + raise TypeError(f"{type(data).__name__!r} object is not iterable") + # Likely a sequence of pairs + for key, value in data: + self[key] = value if kwargs: - for key in kwargs: - dst[key] = kwargs[key] + for key, value in kwargs.items(): + self[key] = value - return dst + +# Used with something like dict(obj) +def construct_dict(cls, data=(), /, **kwargs): + self = cls.__new__(cls) + mutable_mapping_update(self, data, **kwargs) + return self def foreach_map_fn(*args): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 1817a5f3c7ed1..0f198377605ec 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2061,7 +2061,11 @@ def call_dir( return None def call_dict( - self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + self, + tx: "InstructionTranslator", + /, + *args: VariableTracker, + **kwargs: VariableTracker, ) -> VariableTracker: return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) @@ -2069,6 +2073,7 @@ def call_dict( def call_custom_dict( tx: "InstructionTranslator", user_cls: type, + /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: @@ -2093,6 +2098,7 @@ def call_custom_dict( def call_custom_dict_fromkeys( tx: "InstructionTranslator", user_cls: type, + /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 4f1f84a55b0b0..f70ba99c0c93d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -911,6 +911,8 @@ class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict + if default_factory is None: + default_factory = ConstantVariable.create(None) self.default_factory = default_factory def is_python_constant(self): @@ -930,7 +932,13 @@ def is_supported_arg(arg): if isinstance(arg, variables.BuiltinVariable): return arg.fn in (list, tuple, dict, set) else: - return isinstance(arg, variables.functions.BaseUserFunctionVariable) + return isinstance( + arg, + ( + variables.functions.BaseUserFunctionVariable, + variables.functions.PolyfilledFunctionVariable, + ), + ) def call_method( self, @@ -946,8 +954,11 @@ def call_method( if args[0] in self: return self.getitem_const(tx, args[0]) else: - if self.default_factory is None: - raise KeyError(f"{args[0]}") + if ( + istype(self.default_factory, ConstantVariable) + and self.default_factory.value is None + ): + raise_observed_exception(KeyError, tx, args=[args[0]]) else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 085b5e0c648c5..9dd154dacbb9e 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -419,9 +419,7 @@ def call_method( self.value in {collections.OrderedDict, collections.defaultdict} and name == "fromkeys" ): - from .builtin import BuiltinVariable - - return BuiltinVariable.call_custom_dict_fromkeys( + return variables.BuiltinVariable.call_custom_dict_fromkeys( tx, self.value, *args, **kwargs ) elif self.value is collections.OrderedDict and name 
== "move_to_end": @@ -501,15 +499,18 @@ def call_function( [self, *args], kwargs, ) - elif ( - self.value is collections.defaultdict - and len(args) <= 1 - and DefaultDictVariable.is_supported_arg(args[0]) - ): + elif self.value is collections.defaultdict: + if len(args) == 0: + default_factory = variables.ConstantVariable.create(None) + else: + default_factory, *args = args + dict_vt = variables.BuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs + ) return DefaultDictVariable( - {}, + dict_vt.items, collections.defaultdict, - args[0], + default_factory, mutation_type=ValueMutationNew(), ) elif is_typeddict(self.value): From 59a6c83dfe9d88d44d0e5440aa61d2e883a88122 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Wed, 5 Nov 2025 06:39:26 +0000 Subject: [PATCH 055/651] [fx] Add strict argument validation to Interpreter.boxed_run (#166784) # Summary This PR fixes an issue where `torch.fx.Interpreter.boxed_run` would silently ignore extra input arguments instead of validating the argument count. Previously, `boxed_run` would only consume as many inputs as there were placeholder nodes and then clear the entire `args_list`, hiding potential bugs. This change introduces a strict check to ensure `len(args_list)` matches the number of placeholder nodes, raising a `RuntimeError` on a mismatch. Fixes #166583. # Changes * Validate `len(args_list)` against the number of placeholder nodes at the beginning of `boxed_run`. * Raise a `RuntimeError` with a clear message ("extra arguments" or "missing arguments") if the counts do not match. * Move `args_list.clear()` to only execute after successful validation and environment setup. If an error is raised, `args_list` is preserved for debugging. # Testing * Added `test_interpreter_boxed_run_argument_validation` to `test/test_fx.py`. * This test covers three scenarios: 1. Correct number of arguments (succeeds, `args_list` is cleared). 2. Extra arguments (raises `RuntimeError`, `args_list` is preserved). 3. Missing arguments (raises `RuntimeError`, `args_list` is preserved). # User-facing impact / BC notes This is a bug fix. Code that was incorrectly passing the wrong number of arguments to `boxed_run` will now fail fast with a `RuntimeError` instead of executing silently with unintended inputs. Correctly written code is unaffected. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166784 Approved by: https://github.com/ezyang, https://github.com/xmfan --- test/test_fx.py | 25 +++++++++++++++++++++++++ torch/fx/interpreter.py | 22 +++++++++++++++++----- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index c16c42805b921..e12189dfea461 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2070,6 +2070,31 @@ def forward(self, x): self.assertEqual(interpreter.run(input), gm(input)) self.assertEqual(interpreter.run(input), m(input)) + def test_interpreter_boxed_run_argument_validation(self): + class AddModule(torch.nn.Module): + def forward(self, lhs, rhs): + return lhs + rhs + + gm = torch.fx.symbolic_trace(AddModule()) + interpreter = Interpreter(gm) + + lhs = torch.tensor(1.0) + rhs = torch.tensor(2.0) + good_args = [lhs.clone(), rhs.clone()] + result = interpreter.boxed_run(good_args) + torch.testing.assert_close(result, lhs + rhs) + self.assertEqual(good_args, []) + + extra_args = [lhs.clone(), rhs.clone(), torch.tensor(3.0)] + with self.assertRaisesRegex(RuntimeError, "extra arguments"): + interpreter.boxed_run(extra_args) + self.assertEqual(len(extra_args), 3) + + missing_args = [lhs.clone()] + with self.assertRaisesRegex(RuntimeError, "missing arguments"): + interpreter.boxed_run(missing_args) + self.assertEqual(len(missing_args), 1) + def test_interpreter_other_graph(self): class MyModule(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index a3114a14a657e..5ad1424c4e489 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -220,11 +220,23 @@ def boxed_run(self, args_list): calling convention, where you pass a list of arguments, which will be cleared by the interpreter. This ensures that input tensors are promptly deallocated. """ - args_iter = iter(args_list) - env = {} - for n in self.graph.nodes: - if n.op == "placeholder": - env[n] = next(args_iter) + # Collect placeholder nodes first + placeholder_nodes = [n for n in self.graph.nodes if n.op == "placeholder"] + + # Check argument count + if len(args_list) != len(placeholder_nodes): + detail = ( + "extra arguments" + if len(args_list) > len(placeholder_nodes) + else "missing arguments" + ) + raise RuntimeError( + f"Interpreter.boxed_run expected {len(placeholder_nodes)} arguments for placeholders " + f"but received {len(args_list)} ({detail})" + ) + + # Assign arguments to placeholders + env = dict(zip(placeholder_nodes, args_list)) args_list.clear() return self.run(initial_env=env) From 658c5f879c37142b1df51c7eb6c5a5bb06318597 Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Wed, 5 Nov 2025 06:51:30 +0000 Subject: [PATCH 056/651] [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003) Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036?fbclid=IwY2xjawN3RL1leHRuA2FlbQIxMQBicmlkETExOEcxcnVhNVA1TzRSVmhiAR63GOEpJbZA-JhQ0CSj9ji8H_RHBUhDwYNDtxjOYfDol56OGqmC4r7jPP96Fw_aem_bWvtMfVifLQrnpv1YB_fJA, which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs. 
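For orientation, the backend is opt-in through the usual Inductor autotune knobs. A rough usage sketch, loosely mirroring the unit test added in this PR (assumes a Blackwell GPU with the CuTeDSL package installed; the shapes below are made up):

```python
import torch
from torch._inductor import config

G, K, N, M_per_group = 8, 64, 128, 128
A = torch.randn(G * M_per_group, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(G, K, N, device="cuda", dtype=torch.bfloat16)
# Cumulative row offsets per group (no leading zero), as expected by torch._grouped_mm.
offs = torch.arange(1, G + 1, device="cuda", dtype=torch.int32) * M_per_group

def grouped_gemm_fn(a, b, offs):
    return torch._grouped_mm(a, b, offs=offs)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTEDSL",
        "test_configs.autotune_choice_name_regex": "cutedsl",
    }
):
    compiled = torch.compile(grouped_gemm_fn, backend="inductor", dynamic=False)
    out = compiled(A, B, offs)  # bf16 output of shape (G * M_per_group, N)
```
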
Test Plan: Inductor test (fbcode): `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"` Tritonbench (fbcode): `clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Tritonbench(oss): `clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Unit Tests(oss): `clear; python test/inductor/test_cutedsl_grouped_mm.py` Differential Revision: D86231180 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167003 Approved by: https://github.com/jananisriram --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 + setup.py | 34 ++ test/inductor/test_cutedsl_grouped_mm.py | 154 ++++++++ torch/_inductor/config.py | 4 + torch/_inductor/kernel/mm_common.py | 7 + torch/_inductor/kernel/mm_grouped.py | 93 +++-- .../templates/cutedsl_mm_grouped.py.jinja | 333 ++++++++++++++++++ .../_inductor/template_heuristics/cutedsl.py | 141 ++++++++ torch/_inductor/utils.py | 78 ++++ 10 files changed, 814 insertions(+), 33 deletions(-) create mode 100644 test/inductor/test_cutedsl_grouped_mm.py create mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja create mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 26996b5a32d56..9ae2578758939 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index d1b3b17445dac..3b4323051073a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py +torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index 31e78d0245d93..dd8a52cbeb7c7 100644 --- a/setup.py +++ b/setup.py @@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +def mirror_inductor_external_kernels() -> None: + """ + Copy 
external kernels into Inductor so they are importable. + """ + paths = [ + ( + CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", + CWD + / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", + ), + ] + for new_path, orig_path in paths: + # Create the dirs involved in new_path if they don't exist + if not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + + # Copy the files from the orig location to the new location + if orig_path.is_file(): + shutil.copyfile(orig_path, new_path) + continue + if orig_path.is_dir(): + if new_path.exists(): + # copytree fails if the tree exists already, so remove it. + shutil.rmtree(new_path) + shutil.copytree(orig_path, new_path) + continue + raise RuntimeError( + "Check the file paths in `mirror_inductor_external_kernels()`" + ) + + # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1616,6 +1647,8 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() + mirror_inductor_external_kernels() + ( ext_modules, cmdclass, @@ -1649,6 +1682,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", + "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py new file mode 100644 index 0000000000000..c26def3a54099 --- /dev/null +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -0,0 +1,154 @@ +# Owner(s): ["module: inductor"] + + +import unittest + +import torch +from torch import Tensor +from torch._inductor import config +from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch +from torch._inductor.test_case import run_tests, TestCase as InductorTestCase +from torch._inductor.utils import ensure_cute_available +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +@unittest.skipIf( + not (ensure_cute_available() and is_datacenter_blackwell_arch()), + "CuTeDSL library or Blackwell device not available", +) +@instantiate_parametrized_tests +class TestCuTeDSLGroupedGemm(InductorTestCase): + def _get_inputs( + self, + group_size: int, + M_hint: int, + K: int, + N: int, + device: str, + dtype: torch.dtype, + alignment: int = 16, + ) -> tuple[Tensor, Tensor, Tensor]: + # --- Random, tile-aligned M sizes --- + M_sizes = ( + torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) + * alignment + ) + + M_total = torch.sum(M_sizes).item() + + # --- Construct input tensors --- + A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 + B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 + + # --- Build offsets (no leading zero, strictly increasing) --- + offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) + + return (A, B, offsets) + + @parametrize("group_size", (2, 8)) + @parametrize("M_hint", (256, 1024)) + @parametrize("K", (64, 128)) + @parametrize("N", (128, 256)) + def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): + device = "cuda" + dtype = torch.bfloat16 + + A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # Eager execution + c_eager = grouped_gemm_fn(A, B, offsets) 
+ + # Test with Cute backend + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) + @parametrize("layout_B", ("contiguous", "broadcasted")) + def test_grouped_gemm_assorted_layouts( + self, + layout_A: str, + layout_B: str, + ): + device = "cuda" + dtype = torch.bfloat16 + + G, K, N = 8, 64, 128 + M_sizes = [128] * G + sum_M = sum(M_sizes) + offsets = torch.tensor( + [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device + ) + + A_base = torch.randn(sum_M, K, device=device, dtype=dtype) + A = A_base + + if layout_A == "offset": + # allocate bigger buffer than needed, use nonzero storage offset + storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) + offset = 128 # skip first 128 elements + A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) + elif layout_A == "padded": + # simulate row pitch > K (row_stride = K + pad) + row_pitch = K + 8 + storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) + A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) + elif layout_A == "view": + A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) + A = A_storage.view(sum_M, K) + assert A._base is not None + assert A.shape == (sum_M, K) + + B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 + + if layout_B == "broadcasted": + # Broadcast B across groups (zero stride along G) + B = B[0].expand(G, K, N) + assert B.stride(0) == 0 + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # --- eager --- + c_eager = grouped_gemm_fn(A, B, offsets) + + # --- compiled (CUTE backend) --- + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 66eaf69dd59a8..bd1fa7710b06c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -546,6 +546,10 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +cutedsl_enable_autotuning: bool = ( + os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" +) + # DEPRECATED. This setting is ignored. 
autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index b95073e769f31..eb22b95af2afc 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence +from functools import partial +from pathlib import Path from typing import Any import torch @@ -12,6 +14,7 @@ from .. import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox +from ..utils import load_template log = logging.getLogger(__name__) @@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True + + +_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 881c14fd43d0d..0a44b728a5a93 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs import logging -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters +from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -18,19 +19,25 @@ TritonTemplate, ) from ..utils import ( + ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, + load_kernel_template, persistent_grouped_mm_grid, ) +if ensure_cute_available(): + from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs + log = logging.getLogger(__name__) aten = torch.ops.aten @@ -513,6 +520,11 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) +cutedsl_grouped_mm_template = CuteDSLTemplate( + name="grouped_gemm_cutedsl", + source=load_kernel_template("cutedsl_mm_grouped"), +) + def grouped_mm_args( mat1: TensorBox, @@ -714,43 +726,44 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
+ if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -788,6 +801,22 @@ def _tuned_grouped_mm_common( **config.kwargs, ) + if use_blackwell_cutedsl_grouped_mm( + mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result + ): + for config in get_groupgemm_configs(): + kwargs = dict( + ACC_DTYPE="cutlass.Float32", + ) + + cutedsl_grouped_mm_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + **asdict(config), + ) + input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 0000000000000..989f297c5f80f --- /dev/null +++ b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. 
At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. +# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. + """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. 
+ """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = input_b.shape + + dev = 
input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + 
max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 0000000000000..db337b9d8a271 --- /dev/null +++ b/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
+ """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. + """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3f8652882af79..efdb4a9a58912 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,6 +1975,84 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() +@functools.lru_cache(maxsize=1) +def ensure_cute_available() -> bool: + """Check if CuTeDSL is importable; cache the result for reuse. + + Call ensure_cute_available.cache_clear() after installing CuTeDSL + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("cutlass.cute") is not None + except ImportError: + return False + + +def use_blackwell_cutedsl_grouped_mm( + mat_a: Any, + mat_b: Any, + layout: Layout, + a_is_2d: bool, + b_is_2d: bool, + offs: Optional[Any], + bias: Optional[Any], + scale_result: Optional[Any], +) -> bool: + """ + Returns True if we can use the blackwell kernel for grouped mm. + Required conditions: + 1. CuTeDSL backend is enabled + 2. CuTeDSL is available + 3. We are on a blackwell arch + 4. The dtype is bf16 + 5. Max autotune or max autotune gemm is enabled + 6. A, B, and the output are 16B aligned + 7. We are not using dynamic shapes + 8. A is 2d + 9. B is 3d + 10. Offsets are provided + 11. 
Bias and Scale are not provided + """ + if not ensure_cute_available(): + return False + + if not _use_autotune_backend("CUTEDSL"): + return False + + from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch + + if not is_gpu(layout.device.type): + return False + + if not is_datacenter_blackwell_arch(): + return False + + layout_dtypes = [torch.bfloat16] + if not _use_template_for_gpu(layout, layout_dtypes): + return False + + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Checks for 16B ptr and stride alignment + if not can_use_tma(mat_a, mat_b, output_layout=layout): + return False + + if any(is_dynamic(x) for x in [mat_a, mat_b]): + return False + + if not a_is_2d or b_is_2d: + return False + + if offs is None: + return False + + if bias is not None or scale_result is not None: + return False + + return True + + def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From edd8d356b6d9a00cfa34fa323578e5cf1c7e0463 Mon Sep 17 00:00:00 2001 From: arkadip-maitra Date: Wed, 5 Nov 2025 08:07:42 +0000 Subject: [PATCH 057/651] fixes keyerror when loading parameter with unsaved optimizer state (#165228) Fixes #164257 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165228 Approved by: https://github.com/fegin --- torch/distributed/checkpoint/state_dict.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 16d988a79103e..9202851537fba 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -1007,7 +1007,14 @@ def _split_optim_state_dict( raise AssertionError(f"Expected list, got {type(params)}") params.append(fqn) if param.requires_grad: - state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + if fqn in cast(DictValueType, optim_state_dict[_STATE]): + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + elif info.strict: + raise RuntimeError( + f"Missing optimizer state for parameter '{fqn}' in checkpoint. " + "The parameter requires gradients but has no saved optimizer state. " + "To load anyway, use StateDictOptions(strict=False)." + ) for loaded_param_group in cast( ListDictValueType, optim_state_dict[_PG] ): From 0b4dd08e047bda63e1e8dc78f52bcda51562caa5 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 4 Nov 2025 19:59:19 -0800 Subject: [PATCH 058/651] [dynamo] Introduce _set_lru_cache (#167038) Addresses the short-term plan for https://github.com/pytorch/pytorch/issues/166926. This PR can't be defaulted on, that would be terrible for cache look up times. There's a proper fix in the works by @williamwen42. 
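A rough usage sketch based on the tests added in this PR (`_set_lru_cache` is a private binding, not a stable API):

```python
import torch

# Default: a cache hit moves that entry to the front of the per-code cache (LRU).
# Disabling it keeps Dynamo cache entries in insertion order instead.
torch._C._dynamo.eval_frame._set_lru_cache(False)
try:
    @torch.compile(backend="eager")
    def fn(x):
        return x + 10

    fn(torch.randn(10, 10))   # compiles a static-shapes entry
    fn(torch.randn(20, 20))   # recompile appends a dynamic-shapes entry after it
    fn(torch.randn(10, 10))   # a hit no longer reorders the entries
finally:
    torch._C._dynamo.eval_frame._set_lru_cache(True)
```
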
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167038 Approved by: https://github.com/williamwen42 --- test/dynamo/test_repros.py | 88 +++++++++++++++++++++++++++++++ torch/csrc/dynamo/extra_state.cpp | 27 ++++++++-- torch/csrc/dynamo/extra_state.h | 1 + torch/csrc/dynamo/init.cpp | 1 + 4 files changed, 114 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c6138f7574fd4..f3766fe0c973e 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -48,6 +48,7 @@ CompileCounter, CompileCounterWithBackend, EagerAndRecordGraphs, + expectedFailureDynamic, rand_strided, same, skipIfNotPy312, @@ -7455,6 +7456,93 @@ def forward(self, x): msg, ) + @expectedFailureDynamic + def test_dynamo_default_lru_cache_behavior(self): + @torch.compile(backend="eager") + def fn(x): + return x + 10 + + torch._dynamo.reset() + assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + + # Step 1: Compile a static shapes graph + x = torch.randn(10, 10) + fn(x) + a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(a), 1) + static_shapes_cache_entry = a[0] + + # Step 2: Compile a dynamic shapes graph + y = torch.randn(20, 20) + fn(y) + b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(b), 2) + self.assertEqual(b[1], static_shapes_cache_entry) + dynamic_shapes_cache_entry = b[0] + + # Step 3: Run with Step 1's inputs + # LRU cache will match against dynamic shape graph first + fn(x) + c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(c), 2) + self.assertEqual(c[0], dynamic_shapes_cache_entry) + self.assertEqual(c[1], static_shapes_cache_entry) + + @expectedFailureDynamic + def test_dynamo_disable_lru_cache_behavior(self): + @torch.compile(backend="eager") + def fn(x): + return x + 10 + + def run(): + torch._dynamo.reset() + assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + + # Step 1: Compile a static shapes graph + x = torch.randn(10, 10) + fn(x) + a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(a), 1) + static_shapes_cache_entry = a[0] + + # Step 2: Compile a dynamic shapes graph + y = torch.randn(20, 20) + fn(y) + b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(b), 2) + self.assertEqual(b[0], static_shapes_cache_entry) + dynamic_shapes_cache_entry = b[1] + + # Step 3: Run with Step 1's inputs + # LRU cache is disabled, we should still have static entry first + fn(x) + c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(c), 2) + self.assertEqual(c[0], static_shapes_cache_entry) + self.assertEqual(c[1], dynamic_shapes_cache_entry) + + try: + torch._C._dynamo.eval_frame._set_lru_cache(False) + run() + finally: + torch._C._dynamo.eval_frame._set_lru_cache(True) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index b9dccb456fd65..8dc316b98e63c 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ 
b/torch/csrc/dynamo/extra_state.cpp @@ -13,6 +13,11 @@ #define _PyCode_SetExtra PyUnstable_Code_SetExtra #endif +namespace { +// Short-term fix for: https://github.com/pytorch/pytorch/issues/166926 +bool use_lru = true; +} // namespace + Py_ssize_t extra_index = -1; CacheEntry* ExtraState::get_first_entry() { @@ -190,7 +195,9 @@ void lookup( ++index; } if (found) { - extra_state->move_to_front(found); + if (use_lru) { + extra_state->move_to_front(found); + } *maybe_cached_code = found->code.ptr(); *trace_annotation = found->trace_annotation.c_str(); return; @@ -202,8 +209,14 @@ CacheEntry* create_cache_entry( ExtraState* extra_state, PyObject* guarded_code, PyObject* backend) { - extra_state->cache_entry_list.emplace_front(guarded_code, backend); - auto new_iter = extra_state->cache_entry_list.begin(); + std::list::iterator new_iter; + if (use_lru) { + extra_state->cache_entry_list.emplace_front(guarded_code, backend); + new_iter = extra_state->cache_entry_list.begin(); + } else { + extra_state->cache_entry_list.emplace_back(guarded_code, backend); + new_iter = std::prev(extra_state->cache_entry_list.end()); + } new_iter->_owner = extra_state; new_iter->_owner_loc = new_iter; // Set guard_manager references to extra_state and CacheEntry @@ -269,6 +282,14 @@ void _load_precompile_entry( extra->precompile_entries.push_back(std::move(entry)); } +void _set_lru_cache(py::object boolean) { + if (py::cast(boolean)) { + use_lru = true; + } else { + use_lru = false; + } +} + py::list _debug_get_precompile_entries(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 1630ac90b21dd..bc62e93bf3f1d 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -203,5 +203,6 @@ void _load_precompile_entry( py::object guard_manager, py::object dynamo_code); py::list _debug_get_precompile_entries(const py::handle& code_obj); +void _set_lru_cache(py::object boolean); #endif diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index f1590e19d49cf..790ff9acff3a1 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -254,6 +254,7 @@ void initDynamoBindings(PyObject* torch) { m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries); + m.def("_set_lru_cache", &_set_lru_cache); py::bind_vector>(m, "VectorUInt8"); init_THPCaches(); if (THP_PyOpcode_Caches != nullptr) { From 5c639466f7b1f9453c2a9c0e25b41c3774a12af8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 14:30:15 +0000 Subject: [PATCH 059/651] Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003)" This reverts commit 658c5f879c37142b1df51c7eb6c5a5bb06318597. 
Reverted https://github.com/pytorch/pytorch/pull/167003 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19093785744/job/54553796743) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/658c5f879c37142b1df51c7eb6c5a5bb06318597) ([comment](https://github.com/pytorch/pytorch/pull/167003#issuecomment-3491527704)) --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 - setup.py | 34 -- test/inductor/test_cutedsl_grouped_mm.py | 154 -------- torch/_inductor/config.py | 4 - torch/_inductor/kernel/mm_common.py | 7 - torch/_inductor/kernel/mm_grouped.py | 93 ++--- .../templates/cutedsl_mm_grouped.py.jinja | 333 ------------------ .../_inductor/template_heuristics/cutedsl.py | 141 -------- torch/_inductor/utils.py | 78 ---- 10 files changed, 33 insertions(+), 814 deletions(-) delete mode 100644 test/inductor/test_cutedsl_grouped_mm.py delete mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja delete mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9ae2578758939..26996b5a32d56 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index 3b4323051073a..d1b3b17445dac 100644 --- a/.gitignore +++ b/.gitignore @@ -127,7 +127,6 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py -torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index dd8a52cbeb7c7..31e78d0245d93 100644 --- a/setup.py +++ b/setup.py @@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") -def mirror_inductor_external_kernels() -> None: - """ - Copy external kernels into Inductor so they are importable. - """ - paths = [ - ( - CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", - CWD - / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", - ), - ] - for new_path, orig_path in paths: - # Create the dirs involved in new_path if they don't exist - if not new_path.exists(): - new_path.parent.mkdir(parents=True, exist_ok=True) - - # Copy the files from the orig location to the new location - if orig_path.is_file(): - shutil.copyfile(orig_path, new_path) - continue - if orig_path.is_dir(): - if new_path.exists(): - # copytree fails if the tree exists already, so remove it. 
- shutil.rmtree(new_path) - shutil.copytree(orig_path, new_path) - continue - raise RuntimeError( - "Check the file paths in `mirror_inductor_external_kernels()`" - ) - - # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1647,8 +1616,6 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() - mirror_inductor_external_kernels() - ( ext_modules, cmdclass, @@ -1682,7 +1649,6 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", - "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py deleted file mode 100644 index c26def3a54099..0000000000000 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ /dev/null @@ -1,154 +0,0 @@ -# Owner(s): ["module: inductor"] - - -import unittest - -import torch -from torch import Tensor -from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch -from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch._inductor.utils import ensure_cute_available -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -@unittest.skipIf( - not (ensure_cute_available() and is_datacenter_blackwell_arch()), - "CuTeDSL library or Blackwell device not available", -) -@instantiate_parametrized_tests -class TestCuTeDSLGroupedGemm(InductorTestCase): - def _get_inputs( - self, - group_size: int, - M_hint: int, - K: int, - N: int, - device: str, - dtype: torch.dtype, - alignment: int = 16, - ) -> tuple[Tensor, Tensor, Tensor]: - # --- Random, tile-aligned M sizes --- - M_sizes = ( - torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) - * alignment - ) - - M_total = torch.sum(M_sizes).item() - - # --- Construct input tensors --- - A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 - B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 - - # --- Build offsets (no leading zero, strictly increasing) --- - offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) - - return (A, B, offsets) - - @parametrize("group_size", (2, 8)) - @parametrize("M_hint", (256, 1024)) - @parametrize("K", (64, 128)) - @parametrize("N", (128, 256)) - def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): - device = "cuda" - dtype = torch.bfloat16 - - A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # Eager execution - c_eager = grouped_gemm_fn(A, B, offsets) - - # Test with Cute backend - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) - @parametrize("layout_B", ("contiguous", "broadcasted")) - def 
test_grouped_gemm_assorted_layouts( - self, - layout_A: str, - layout_B: str, - ): - device = "cuda" - dtype = torch.bfloat16 - - G, K, N = 8, 64, 128 - M_sizes = [128] * G - sum_M = sum(M_sizes) - offsets = torch.tensor( - [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device - ) - - A_base = torch.randn(sum_M, K, device=device, dtype=dtype) - A = A_base - - if layout_A == "offset": - # allocate bigger buffer than needed, use nonzero storage offset - storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) - offset = 128 # skip first 128 elements - A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) - elif layout_A == "padded": - # simulate row pitch > K (row_stride = K + pad) - row_pitch = K + 8 - storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) - A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) - elif layout_A == "view": - A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) - A = A_storage.view(sum_M, K) - assert A._base is not None - assert A.shape == (sum_M, K) - - B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 - - if layout_B == "broadcasted": - # Broadcast B across groups (zero stride along G) - B = B[0].expand(G, K, N) - assert B.stride(0) == 0 - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # --- eager --- - c_eager = grouped_gemm_fn(A, B, offsets) - - # --- compiled (CUTE backend) --- - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index bd1fa7710b06c..66eaf69dd59a8 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -546,10 +546,6 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] -cutedsl_enable_autotuning: bool = ( - os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" -) - # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index eb22b95af2afc..b95073e769f31 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from functools import partial -from pathlib import Path from typing import Any import torch @@ -14,7 +12,6 @@ from .. 
import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import load_template log = logging.getLogger(__name__) @@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True - - -_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" -load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 0a44b728a5a93..881c14fd43d0d 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-defs import logging -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters -from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -19,25 +18,19 @@ TritonTemplate, ) from ..utils import ( - ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, - use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, - load_kernel_template, persistent_grouped_mm_grid, ) -if ensure_cute_available(): - from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs - log = logging.getLogger(__name__) aten = torch.ops.aten @@ -520,11 +513,6 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) -cutedsl_grouped_mm_template = CuteDSLTemplate( - name="grouped_gemm_cutedsl", - source=load_kernel_template("cutedsl_mm_grouped"), -) - def grouped_mm_args( mat1: TensorBox, @@ -726,44 +714,43 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
- if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False - if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -801,22 +788,6 @@ def _tuned_grouped_mm_common( **config.kwargs, ) - if use_blackwell_cutedsl_grouped_mm( - mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result - ): - for config in get_groupgemm_configs(): - kwargs = dict( - ACC_DTYPE="cutlass.Float32", - ) - - cutedsl_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - **asdict(config), - ) - input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja deleted file mode 100644 index 989f297c5f80f..0000000000000 --- a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja +++ /dev/null @@ -1,333 +0,0 @@ -import functools -from torch._inductor.runtime.runtime_utils import ceildiv -from cutlass.utils import TensorMapUpdateMode -{{gen_defines()}} -# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- -from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( - GroupedGemmKernel, -) - - -# Note about caching: -# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor -# maintains its own local caching system. 
At this stage, all compile-time -# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel -# name itself ({{kernel_name}}) are permanently baked into the file, so they -# do not need to be included in any cache key. -# -# The caching mechanism is split into two levels: -# -# 1. prep_cache -# Caches the compiled executor for build_group_ptrs_from_bases(). This -# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, -# and can therefore be safely reused across runs with different group -# partitioning (`offs`). -# -# 2. gemm_cache -# Caches the compiled Grouped GEMM executor. Its key extends the prep -# cache key with hardware- and grid-specific parameters: -# (prep_cache_key, max_active_clusters, total_num_clusters). -# This is necessary because different `offs` tensors can change the -# per-group problem sizes and thus alter `total_num_clusters`, which in -# turn changes the grid shape and persistent scheduler configuration. -# Kernels compiled for one grid cannot be safely reused for another. -# -# -# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, -# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, -# despite depending only on the GPU type. We cache this function to mitigate -# redundant recompiles even when shape/stride/dtype cache misses force kernel -# regeneration. A follow-up study will investigate the root cause. - -prep_cache = {} -gemm_cache = {} - - -@functools.lru_cache -def get_hardware_info(): - hw = cutlass.utils.HardwareInfo() - sm_count = hw.get_max_active_clusters(1) - max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) - - return (sm_count, max_active_clusters) - - -def get_prep_cache_key(input_a, input_b, output): - """ - Returns a tuple key for caching the preprocessing kernel executor based on kernel name, - shapes, strides, and dtypes of input/output tensors. - """ - return ( - tuple(input_a.shape), - tuple(input_a.stride()), - input_a.dtype, - tuple(input_b.shape), - tuple(input_b.stride()), - input_b.dtype, - tuple(output.shape), - tuple(output.stride()), - output.dtype, - ) - - -def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): - """ - Returns a tuple key for caching the gemm kernel executor by extending the - prep cache key with hardware- and grid-specific parameters. 
- """ - return ( - prep_cache_key, - max_active_clusters, - total_num_clusters, - ) - - -@cute.kernel -def build_group_ptrs_from_bases_kernel( - base_A_u64: cutlass.Int64, # device addr of input_a (bytes) - base_B_u64: cutlass.Int64, # device addr of input_b (bytes) - base_C_u64: cutlass.Int64, # device addr of Output (bytes) - offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Int32, # bytes - # -------- STRIDES (in ELEMENTS) -------- - stride_A_m_elems: cutlass.Constexpr, # A.stride(0) - stride_A_k_elems: cutlass.Constexpr, # A.stride(1) - stride_B0_elems: cutlass.Constexpr, # B.stride(0) - stride_Bk_elems: cutlass.Constexpr, # B.stride(1) - stride_Bn_elems: cutlass.Constexpr, # B.stride(2) - stride_C_m_elems: cutlass.Constexpr, # C.stride(0) - stride_C_n_elems: cutlass.Constexpr, # C.stride(1) - # -------- OUTPUTS -------- - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) - out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) - out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] -): - tidx, _, _ = cute.arch.thread_idx() - g = tidx - - m_beg_i32 = 0 - if g > 0: - m_beg_i32 = offs[g - 1] - m_end_i32 = offs[g] - m_g_i32 = m_end_i32 - m_beg_i32 - - a_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) - ) - c_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) - ) - b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) - - # ---- pointers ---- - out_ptrs[g, 0] = base_A_u64 + a_byte_off - out_ptrs[g, 1] = base_B_u64 + b_byte_off - out_ptrs[g, 2] = base_C_u64 + c_byte_off - - # ---- (m, n, k, 1) ---- - out_problem[g, 0] = m_g_i32 - out_problem[g, 1] = N - out_problem[g, 2] = K - out_problem[g, 3] = cutlass.Int32(1) - - # ---- strides ---- - out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) - out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) - out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) - out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) - out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) - out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) - - -@cute.jit -def launch_build_group_ptrs_from_bases( - base_A_u64: cutlass.Int64, - base_B_u64: cutlass.Int64, - base_C_u64: cutlass.Int64, - offs: cute.Tensor, - G: cutlass.Constexpr, - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Constexpr, - stride_A_m_elems: cutlass.Constexpr, - stride_A_k_elems: cutlass.Constexpr, - stride_B0_elems: cutlass.Constexpr, - stride_Bk_elems: cutlass.Constexpr, - stride_Bn_elems: cutlass.Constexpr, - stride_C_m_elems: cutlass.Constexpr, - stride_C_n_elems: cutlass.Constexpr, - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 - out_problem: cute.Tensor, # [G,4] cutlass.Int32 - out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 - stream: cuda.CUstream, -): - build_group_ptrs_from_bases_kernel( - base_A_u64, - base_B_u64, - base_C_u64, - offs, - K, - N, - sizeof_element, - stride_A_m_elems, - stride_A_k_elems, - stride_B0_elems, - stride_Bk_elems, - stride_Bn_elems, - stride_C_m_elems, - stride_C_n_elems, - out_ptrs, - out_problem, - out_strides_abc, - ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) - - -{{def_kernel("input_a", "input_b", "input_a_offs")}} - stream = cuda.CUstream(stream) - - input_b = input_b.transpose(1, 2) - - sumM, K = input_a.shape - G, N, Kb = input_b.shape - - dev = 
input_a.device - - base_A_u64 = int(input_a.data_ptr()) - base_B_u64 = int(input_b.data_ptr()) - base_C_u64 = int({{get_output()}}.data_ptr()) - - ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) - probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) - strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) - ptrs = from_dlpack(ptrs_t) - probs = from_dlpack(probs_t) - strides = from_dlpack(strides_t) - - prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) - prep_executor = prep_cache.get(prep_cache_key) - - if prep_executor is None: - sizeof_element = int(input_a.element_size()) - sA_m, sA_k = map(int, input_a.stride()) - sB_0, sB_n, sB_k = map(int, input_b.stride()) - sC_m, sC_n = map(int, {{get_output()}}.stride()) - - prep_executor = cute.compile( - launch_build_group_ptrs_from_bases, - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - G=int(G), - K=int(K), - N=int(N), - sizeof_element=sizeof_element, - stride_A_m_elems=sA_m, - stride_A_k_elems=sA_k, - stride_B0_elems=sB_0, - stride_Bk_elems=sB_k, - stride_Bn_elems=sB_n, - stride_C_m_elems=sC_m, - stride_C_n_elems=sC_n, - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - prep_cache[prep_cache_key] = prep_executor - - prep_executor( - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - # --- Tensormap workspace per SM --- - num_tensormap_buffers, max_active_clusters = get_hardware_info() - tensormap_shape = ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ) - tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) - tensormap_workspace = from_dlpack(tensormap_workspace_t) - - # --- Total clusters --- - def compute_total_num_clusters( - problem_sizes_mnkl, - cluster_tile_shape_mn, - ): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) - ) - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - # Compute cluster tile shape - def compute_cluster_tile_shape( - mma_tiler_mn, - cluster_shape_mn, - use_2cta_instrs, - ): - cta_tile_shape_mn = list(mma_tiler_mn) - if use_2cta_instrs: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cluster_tile_shape_mn = compute_cluster_tile_shape( - (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) - ) - - total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) - - gemm_cache_key = get_gemm_cache_key( - prep_cache_key, max_active_clusters, total_num_clusters - ) - gemm_executor = gemm_cache.get(gemm_cache_key) - - if gemm_executor is None: - grouped_gemm = GroupedGemmKernel( - acc_dtype=ACC_DTYPE, - use_2cta_instrs=USE_2_CTA, - mma_tiler_mn=(TILE_M, TILE_N), - cluster_shape_mn=(CLUSTER_M, CLUSTER_N), - tensormap_update_mode=TENSORMAP_UPDATE_MODE, - ) - - gemm_executor = cute.compile( - grouped_gemm, - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - G, - probs, - strides, - ptrs, - total_num_clusters, - tensormap_workspace, - 
max_active_clusters, - stream, - ) - - gemm_cache[gemm_cache_key] = gemm_executor - - gemm_executor( - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - probs, - strides, - ptrs, - tensormap_workspace, - stream, - ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py deleted file mode 100644 index db337b9d8a271..0000000000000 --- a/torch/_inductor/template_heuristics/cutedsl.py +++ /dev/null @@ -1,141 +0,0 @@ -from dataclasses import dataclass -from enum import auto, Enum -from itertools import product - -import torch._inductor.config as config - - -class TensorMapUpdateMode(Enum): - """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" - - SMEM = auto() - GMEM = auto() - - -@dataclass(frozen=True) -class CuTeGemmConfig: - TILE_M: int = 128 - TILE_N: int = 192 - CLUSTER_M: int = 2 - CLUSTER_N: int = 1 - USE_2_CTA: bool = False - TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM - - -def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - For information regarding valid config sets, see: - https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py - """ - - # Tile_n is always the same regardless of 2cta - tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] - - # Valid clusters - clusters_no_2cta = [ - (1, 1), - (1, 2), - (1, 4), - (1, 8), - (1, 16), - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - clusters_2cta = [ - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - - configs: list[CuTeGemmConfig] = [] - - for use_2cta, cluster_set, tile_m_range in [ - (False, clusters_no_2cta, [64, 128]), - (True, clusters_2cta, [128, 256]), - ]: - for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( - [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], - tile_m_range, - tile_n_vals, - cluster_set, - ): - configs.append( - CuTeGemmConfig( - tile_m, - tile_n, - cluster_m, - cluster_n, - USE_2_CTA=use_2cta, - TENSORMAP_UPDATE_MODE=tensormap_update_mode, - ) - ) - - return configs - - -def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
- """ - - config_tuples = [ - (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), - (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), - (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), - (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), - (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), - (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - ] - - return [CuTeGemmConfig(*args) for args in config_tuples] - - -def get_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - - Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures - or unstable results. By default, autotuning is disabled and we return only - a single baseline config. - """ - if ( - config.cutedsl_enable_autotuning - and config.max_autotune_gemm_search_space == "EXHAUSTIVE" - ): - return get_exhaustive_groupgemm_configs() - elif config.cutedsl_enable_autotuning: - return get_default_groupgemm_configs() - else: - return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index efdb4a9a58912..3f8652882af79 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,84 +1975,6 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() -@functools.lru_cache(maxsize=1) -def ensure_cute_available() -> bool: - """Check if CuTeDSL is importable; cache the result for reuse. - - Call ensure_cute_available.cache_clear() after installing CuTeDSL - in the same interpreter to retry the import. - """ - try: - return importlib.util.find_spec("cutlass.cute") is not None - except ImportError: - return False - - -def use_blackwell_cutedsl_grouped_mm( - mat_a: Any, - mat_b: Any, - layout: Layout, - a_is_2d: bool, - b_is_2d: bool, - offs: Optional[Any], - bias: Optional[Any], - scale_result: Optional[Any], -) -> bool: - """ - Returns True if we can use the blackwell kernel for grouped mm. - Required conditions: - 1. CuTeDSL backend is enabled - 2. CuTeDSL is available - 3. We are on a blackwell arch - 4. The dtype is bf16 - 5. Max autotune or max autotune gemm is enabled - 6. A, B, and the output are 16B aligned - 7. We are not using dynamic shapes - 8. A is 2d - 9. B is 3d - 10. Offsets are provided - 11. 
Bias and Scale are not provided - """ - if not ensure_cute_available(): - return False - - if not _use_autotune_backend("CUTEDSL"): - return False - - from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch - - if not is_gpu(layout.device.type): - return False - - if not is_datacenter_blackwell_arch(): - return False - - layout_dtypes = [torch.bfloat16] - if not _use_template_for_gpu(layout, layout_dtypes): - return False - - if not (config.max_autotune or config.max_autotune_gemm): - return False - - # Checks for 16B ptr and stride alignment - if not can_use_tma(mat_a, mat_b, output_layout=layout): - return False - - if any(is_dynamic(x) for x in [mat_a, mat_b]): - return False - - if not a_is_2d or b_is_2d: - return False - - if offs is None: - return False - - if bias is not None or scale_result is not None: - return False - - return True - - def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From 59563dfe56a086a4a95025f0ccfe373bc1fd3759 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 4 Nov 2025 15:36:00 -0800 Subject: [PATCH 060/651] Refactor out headeronly ArrayRef (#164991) Differential Revision: [D85091961](https://our.internmc.facebook.com/intern/diff/D85091961) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991 Approved by: https://github.com/swolchok --- c10/util/ArrayRef.h | 238 ++++++----------- test/cpp/aoti_abi_check/CMakeLists.txt | 1 + .../test_headeronlyarrayref.cpp | 52 ++++ torch/header_only_apis.txt | 3 + torch/headeronly/util/HeaderOnlyArrayRef.h | 247 ++++++++++++++++++ 5 files changed, 388 insertions(+), 153 deletions(-) create mode 100644 test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp create mode 100644 torch/headeronly/util/HeaderOnlyArrayRef.h diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 64605f5153595..1311867ef797e 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -40,200 +41,99 @@ namespace c10 { /// /// This is intended to be trivially copyable, so it should be passed by /// value. +/// +/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct +/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of +/// the underlying constexpr calls, we rely on apparent-type dispatch for +/// inheritance. This should be fine because their memory format is the same, +/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods. +/// However, you should prefer to use ArrayRef when possible, because its use +/// of TORCH_CHECK will lead to better user-facing error messages. template -class ArrayRef final { +class ArrayRef final : public HeaderOnlyArrayRef { public: - using iterator = const T*; - using const_iterator = const T*; - using size_type = size_t; - using value_type = T; - - using reverse_iterator = std::reverse_iterator; - - private: - /// The start of the array, in an external buffer. - const T* Data; - - /// The number of elements. - size_type Length; - - void debugCheckNullptrInvariant() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - Data != nullptr || Length == 0, - "created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal"); - } - - public: - /// @name Constructors + /// @name Constructors, all inherited from HeaderOnlyArrayRef except for + /// SmallVector. 
As inherited constructors won't work with class template + /// argument deduction (CTAD) until C++23, we add deduction guides after + /// the class definition to enable CTAD. /// @{ - /// Construct an empty ArrayRef. - /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} - - /// Construct an ArrayRef from a single element. - // TODO Make this explicit - constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} - - /// Construct an ArrayRef from a pointer and length. - constexpr ArrayRef(const T* data, size_t length) - : Data(data), Length(length) { - debugCheckNullptrInvariant(); - } - - /// Construct an ArrayRef from a range. - constexpr ArrayRef(const T* begin, const T* end) - : Data(begin), Length(end - begin) { - debugCheckNullptrInvariant(); - } + using HeaderOnlyArrayRef::HeaderOnlyArrayRef; /// Construct an ArrayRef from a SmallVector. This is templated in order to /// avoid instantiating SmallVectorTemplateCommon whenever we /// copy-construct an ArrayRef. + /// NOTE: this is the only constructor that is not inherited from + /// HeaderOnlyArrayRef. template /* implicit */ ArrayRef(const SmallVectorTemplateCommon& Vec) - : Data(Vec.data()), Length(Vec.size()) { - debugCheckNullptrInvariant(); - } - - template < - typename Container, - typename U = decltype(std::declval().data()), - typename = std::enable_if_t< - (std::is_same_v || std::is_same_v)>> - /* implicit */ ArrayRef(const Container& container) - : Data(container.data()), Length(container.size()) { - debugCheckNullptrInvariant(); - } - - /// Construct an ArrayRef from a std::vector. - // The enable_if stuff here makes sure that this isn't used for - // std::vector, because ArrayRef can't work on a std::vector - // bitfield. - template - /* implicit */ ArrayRef(const std::vector& Vec) - : Data(Vec.data()), Length(Vec.size()) { - static_assert( - !std::is_same_v, - "ArrayRef cannot be constructed from a std::vector bitfield."); - } - - /// Construct an ArrayRef from a std::array - template - /* implicit */ constexpr ArrayRef(const std::array& Arr) - : Data(Arr.data()), Length(N) {} - - /// Construct an ArrayRef from a C array. - template - // NOLINTNEXTLINE(*c-arrays*) - /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} - - /// Construct an ArrayRef from a std::initializer_list. - /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) - : Data( - std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) - : std::begin(Vec)), - Length(Vec.size()) {} + : HeaderOnlyArrayRef(Vec.data(), Vec.size()) {} /// @} - /// @name Simple Operations + /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef /// @{ - constexpr iterator begin() const { - return Data; - } - constexpr iterator end() const { - return Data + Length; - } - - // These are actually the same as iterator, since ArrayRef only - // gives you const iterators. - constexpr const_iterator cbegin() const { - return Data; - } - constexpr const_iterator cend() const { - return Data + Length; - } - - constexpr reverse_iterator rbegin() const { - return reverse_iterator(end()); - } - constexpr reverse_iterator rend() const { - return reverse_iterator(begin()); - } - - /// Check if all elements in the array satisfy the given expression - constexpr bool allMatch(const std::function& pred) const { - return std::all_of(cbegin(), cend(), pred); - } - - /// empty - Check if the array is empty. 
- constexpr bool empty() const { - return Length == 0; - } - - constexpr const T* data() const { - return Data; - } - - /// size - Get the array size. - constexpr size_t size() const { - return Length; - } - /// front - Get the first element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& front() const { TORCH_CHECK( - !empty(), "ArrayRef: attempted to access front() of empty list"); - return Data[0]; + !this->empty(), "ArrayRef: attempted to access front() of empty list"); + return this->Data[0]; } /// back - Get the last element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& back() const { - TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list"); - return Data[Length - 1]; - } - - /// equals - Check for element-wise equality. - constexpr bool equals(ArrayRef RHS) const { - return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + TORCH_CHECK( + !this->empty(), "ArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; } /// slice(n, m) - Take M elements of the array starting at element N + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr ArrayRef slice(size_t N, size_t M) const { TORCH_CHECK( - N + M <= size(), + N + M <= this->size(), "ArrayRef: invalid slice, N = ", N, "; M = ", M, "; size = ", - size()); - return ArrayRef(data() + N, M); + this->size()); + return ArrayRef(this->data() + N, M); } /// slice(n) - Chop off the first N elements of the array. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr ArrayRef slice(size_t N) const { TORCH_CHECK( - N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); - return slice(N, size() - N); + N <= this->size(), + "ArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); // should this slice be this->slice? } /// @} /// @name Operator Overloads /// @{ - constexpr const T& operator[](size_t Index) const { - return Data[Index]; - } /// Vector compatibility + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& at(size_t Index) const { TORCH_CHECK( - Index < Length, + Index < this->Length, "ArrayRef: invalid index Index = ", Index, "; Length = ", - Length); - return Data[Index]; + this->Length); + return this->Data[Index]; } /// Disallow accidental assignment from a temporary. 
@@ -253,16 +153,48 @@ class ArrayRef final { std::enable_if_t, ArrayRef>& operator=( std::initializer_list) = delete; - /// @} - /// @name Expensive Operations - /// @{ - std::vector vec() const { - return std::vector(Data, Data + Length); - } - /// @} }; +/// Deduction guides for ArrayRef to support CTAD with inherited constructors +/// These mirror the constructors inherited from HeaderOnlyArrayRef +/// @{ + +// Single element constructor +template +ArrayRef(const T&) -> ArrayRef; + +// Pointer and length constructor +template +ArrayRef(const T*, size_t) -> ArrayRef; + +// Range constructor (begin, end) +template +ArrayRef(const T*, const T*) -> ArrayRef; + +// Generic container constructor (anything with .data() and .size()) +template +ArrayRef(const Container&) -> ArrayRef< + std::remove_pointer_t().data())>>; + +// std::vector constructor +template +ArrayRef(const std::vector&) -> ArrayRef; + +// std::array constructor +template +ArrayRef(const std::array&) -> ArrayRef; + +// C array constructor +template +ArrayRef(const T (&)[N]) -> ArrayRef; + +// std::initializer_list constructor +template +ArrayRef(const std::initializer_list&) -> ArrayRef; + +/// @} + template std::ostream& operator<<(std::ostream& out, ArrayRef list) { int i = 0; diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 4763621f60394..1695e65cb4a1b 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp diff --git a/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp b/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp new file mode 100644 index 0000000000000..184c0ade8360e --- /dev/null +++ b/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp @@ -0,0 +1,52 @@ +#include + +#include + +#include + +using torch::headeronly::HeaderOnlyArrayRef; + +TEST(TestHeaderOnlyArrayRef, TestEmpty) { + HeaderOnlyArrayRef arr; + ASSERT_TRUE(arr.empty()); +} + +TEST(TestHeaderOnlyArrayRef, TestSingleton) { + float val = 5.0f; + HeaderOnlyArrayRef arr(val); + ASSERT_FALSE(arr.empty()); + EXPECT_EQ(arr.size(), 1); + EXPECT_EQ(arr[0], val); +} + +TEST(TestHeaderOnlyArrayRef, TestAPIs) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr(vec); + ASSERT_FALSE(arr.empty()); + EXPECT_EQ(arr.size(), 7); + for (size_t i = 0; i < arr.size(); i++) { + EXPECT_EQ(arr[i], i + 1); + EXPECT_EQ(arr.at(i), i + 1); + } + EXPECT_EQ(arr.front(), 1); + EXPECT_EQ(arr.back(), 7); + ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3))); +} + +TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr({1, 2, 3, 4, 5, 6, 7}); + auto res_vec = arr.vec(); + for (size_t i = 0; i < vec.size(); i++) { + EXPECT_EQ(vec[i], res_vec[i]); + } +} + +TEST(TestHeaderOnlyArrayRef, TestFromRange) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr(vec.data() + 3, vec.data() + 7); + auto res_vec = arr.vec(); + for (size_t i = 0; i < res_vec.size(); i++) { + EXPECT_EQ(vec[i + 3], res_vec[i]); + } +} diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 70165a7493e59..c0cd5d9a2c689 100644 --- 
a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -42,6 +42,9 @@ fp16_ieee_to_fp32_value # fp32_from_bits called from fp16_ieee_to_fp32_value # fp32_to_bits called from fp16_ieee_from_fp32_value +# torch/headeronly/util/HeaderOnlyArrayRef.h +HeaderOnlyArrayRef + # c10/util/complex.h, torch/headeronly/util/complex.h complex diff --git a/torch/headeronly/util/HeaderOnlyArrayRef.h b/torch/headeronly/util/HeaderOnlyArrayRef.h new file mode 100644 index 0000000000000..2387578ab8f5f --- /dev/null +++ b/torch/headeronly/util/HeaderOnlyArrayRef.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// HeaderOnlyArrayRef - A subset of ArrayRef that is implemented only +/// in headers. This will be a base class from which ArrayRef inherits, so that +/// we can keep much of the implementation shared. +/// +/// [HeaderOnlyArrayRef vs ArrayRef note] +/// As HeaderOnlyArrayRef is a subset of ArrayRef, it has slightly less +/// functionality than ArrayRef. We document the minor differences below: +/// 1. ArrayRef has an extra convenience constructor for SmallVector. +/// 2. ArrayRef uses TORCH_CHECK. HeaderOnlyArrayRef uses header-only +/// STD_TORCH_CHECK, which will output a std::runtime_error vs a +/// c10::Error. Consequently, you should use ArrayRef when possible +/// and HeaderOnlyArrayRef only when necessary to support headeronly code. +/// In all other aspects, HeaderOnlyArrayRef is identical to ArrayRef, with the +/// positive benefit of being header-only and thus independent of libtorch.so. +template +class HeaderOnlyArrayRef { + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + protected: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_type Length; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty HeaderOnlyArrayRef. + /* implicit */ constexpr HeaderOnlyArrayRef() : Data(nullptr), Length(0) {} + + /// Construct a HeaderOnlyArrayRef from a single element. + // TODO Make this explicit + constexpr HeaderOnlyArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct a HeaderOnlyArrayRef from a pointer and length. + constexpr HeaderOnlyArrayRef(const T* data, size_t length) + : Data(data), Length(length) {} + + /// Construct a HeaderOnlyArrayRef from a range. + constexpr HeaderOnlyArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) {} + + template < + typename Container, + typename U = decltype(std::declval().data()), + typename = std::enable_if_t< + (std::is_same_v || std::is_same_v)>> + /* implicit */ HeaderOnlyArrayRef(const Container& container) + : Data(container.data()), Length(container.size()) {} + + /// Construct a HeaderOnlyArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. 
+ template + /* implicit */ HeaderOnlyArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same_v, + "HeaderOnlyArrayRef cannot be constructed from a std::vector bitfield."); + } + + /// Construct a HeaderOnlyArrayRef from a std::array + template + /* implicit */ constexpr HeaderOnlyArrayRef(const std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct a HeaderOnlyArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-arrays*) + /* implicit */ constexpr HeaderOnlyArrayRef(const T (&Arr)[N]) + : Data(Arr), Length(N) {} + + /// Construct a HeaderOnlyArrayRef from a std::initializer_list. + /* implicit */ constexpr HeaderOnlyArrayRef( + const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return this->Data; + } + constexpr iterator end() const { + return this->Data + this->Length; + } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return this->Data; + } + constexpr const_iterator cend() const { + return this->Data + this->Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// Check if all elements in the array satisfy the given expression + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return this->Length == 0; + } + + constexpr const T* data() const { + return this->Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return this->Length; + } + + /// front - Get the first element. + constexpr const T& front() const { + STD_TORCH_CHECK( + !this->empty(), + "HeaderOnlyArrayRef: attempted to access front() of empty list"); + return this->Data[0]; + } + + /// back - Get the last element. + constexpr const T& back() const { + STD_TORCH_CHECK( + !this->empty(), + "HeaderOnlyArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(HeaderOnlyArrayRef RHS) const { + return this->Length == RHS.Length && + std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + constexpr HeaderOnlyArrayRef slice(size_t N, size_t M) const { + STD_TORCH_CHECK( + N + M <= this->size(), + "HeaderOnlyArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + this->size()); + return HeaderOnlyArrayRef(this->data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. 
+ constexpr HeaderOnlyArrayRef slice(size_t N) const { + STD_TORCH_CHECK( + N <= this->size(), + "HeaderOnlyArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return this->Data[Index]; + } + + /// Vector compatibility + constexpr const T& at(size_t Index) const { + STD_TORCH_CHECK( + Index < this->Length, + "HeaderOnlyArrayRef: invalid index Index = ", + Index, + "; Length = ", + this->Length); + return this->Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, HeaderOnlyArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, HeaderOnlyArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { + return std::vector(this->Data, this->Data + this->Length); + } + + /// @} +}; + +} // namespace c10 + +namespace torch::headeronly { +using c10::HeaderOnlyArrayRef; +using IntHeaderOnlyArrayRef = HeaderOnlyArrayRef; +} // namespace torch::headeronly From 7a6ff88196e12f9eebc8769d5fcbb8225a047e28 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 4 Nov 2025 15:36:01 -0800 Subject: [PATCH 061/651] Widen ops support to take in IntHOArrayRef vs only std::vec (#165152) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #164991 --- .../libtorch_agnostic/csrc/kernel.cpp | 12 ++++++------ torch/csrc/stable/ops.h | 17 +++++++---------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 58c812b08cccb..87aaa46e64c95 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -311,10 +311,9 @@ void boxed_fill_infinity( } Tensor my_pad(Tensor t) { - std::vector padding = {1, 2, 2, 1}; std::string mode = "constant"; double value = 0.0; - return pad(t, padding, mode, value); + return pad(t, {1, 2, 2, 1}, mode, value); } void boxed_my_pad( @@ -342,6 +341,9 @@ void boxed_my_narrow( } Tensor my_new_empty_dtype_variant(Tensor t) { + // Still using a std::vector below even though people can just pass in an + // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef) + // directly. 
std::vector sizes = {2, 5}; auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); return new_empty(t, sizes, dtype); @@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui } Tensor my_new_zeros_dtype_variant(Tensor t) { - std::vector sizes = {2, 5}; auto dtype = std::make_optional(at::ScalarType::Float); - return new_zeros(t, sizes, dtype); + return new_zeros(t, {2, 5}, dtype); } void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { @@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) } Tensor my_amax_vec(Tensor t) { - std::vector v = {0,1}; - return amax(t, v, false); + return amax(t, {0,1}, false); } void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 5c2959e69ae0b..d5fbba9fbbfd7 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -5,13 +5,13 @@ #include #include #include -#include #include #include #include #include #include +#include HIDDEN_NAMESPACE_BEGIN(torch, stable) @@ -68,7 +68,7 @@ inline torch::stable::Tensor narrow( // only dtype information. inline torch::stable::Tensor new_empty( const torch::stable::Tensor& self, - std::vector size, + torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt) { int32_t device_type; TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); @@ -107,7 +107,7 @@ inline torch::stable::Tensor new_empty( // only dtype information. inline torch::stable::Tensor new_zeros( const torch::stable::Tensor& self, - std::vector size, + torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt) { int32_t device_type; TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); @@ -144,12 +144,10 @@ inline torch::stable::Tensor new_zeros( // We expect this to be the stable version of the pad.default op. // pad.default takes in a SymInt[] as the pad argument however pad is typed as -// use std::vector because -// (1) IntArrayRef is not yet header-only -// (2) SymInt is not yet header-only +// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only. inline torch::stable::Tensor pad( const torch::stable::Tensor& self, - std::vector pad, + torch::headeronly::IntHeaderOnlyArrayRef pad, const std::string& mode = "constant", double value = 0.0) { AtenTensorHandle ret0 = nullptr; @@ -181,11 +179,10 @@ inline torch::stable::Tensor amax( // This function is an overload to compute the maximum value along each slice of // `self` reducing over all the dimensions in the vector `dims`. 
The // amax.default op takes in a SymInt[] as the dims argument, however dims is -// typed as use std::vector here because (1) IntArrayRef is not yet -// header-only (2) SymInt is not yet header-only +// typed as IntHeaderOnlyArrayRef here because SymInt is not yet header-only inline torch::stable::Tensor amax( const torch::stable::Tensor& self, - std::vector dims, + torch::headeronly::IntHeaderOnlyArrayRef dims, bool keepdim = false) { AtenTensorHandle ret = nullptr; TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax( From d2d13bf62dc848348196f91d3f104f84ac1e47e7 Mon Sep 17 00:00:00 2001 From: eellison Date: Wed, 5 Nov 2025 05:54:27 -0800 Subject: [PATCH 062/651] Invert unary read and write for fusion (#161404) For [this repro](https://gist.github.com/eellison/75a99616a0fcca0436316bbfd8987fae), this change enables fusion of `to_blocked` with the prior `to_mx` calculation, so that there is only a single kernel per tensor, resulting in a 10% speedup of the non-conversion code (need to update my local devserver to 12.9 to time the matmul as well). The `to_mx` kernel has a contiguous write: ```Py op6_op7: FusedSchedulerNode(SchedulerNode,SchedulerNode) op6_op7.writes = [MemoryDep('buf6', c0, {c0: 2097152}), MemoryDep('buf7', c0, {c0: 67108864})] op6_op7.unmet_dependencies = [] op6_op7.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 67108864})] op6_op7.outputs = [ buf6: ComputedBuffer buf6.layout = FixedLayout('cuda:0', torch.float32, size=[8192, 256], stride=[256, 1]) buf6.users = [ NodeUser(node=SchedulerNode(name='op7'), can_inplace=False, is_weak=False), NodeUser(node=SchedulerNode(name='op9'), can_inplace=False, is_weak=False), ] buf7: ComputedBuffer buf7.layout = FixedLayout('cuda:0', torch.float8_e4m3fn, size=[8192, 256, 32], stride=[8192, 32, 1]) buf7.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)] ] ``` The `to_blocked` kernel, by contrast, has a single discontiguous read and a single contiguous write. ```Py op9: SchedulerNode(ComputedBuffer) op9.writes = [MemoryDep('buf9', c0, {c0: 2097152})] op9.unmet_dependencies = [ MemoryDep('buf6', 32768*((c0//32768)) + 8192*(((ModularIndexing(c0, 1, 16))//4)) + 256*(ModularIndexing(c0, 16, 32)) + 4*(ModularIndexing(c0, 512, 64)) + (ModularIndexing(ModularIndexing(c0, 1, 16), 1, 4)), {c0: 2097152})] op9.met_dependencies = [] op9.outputs = [ buf9: ComputedBuffer buf9.layout = FixedLayout('cuda:0', torch.float8_e8m0fnu, size=[2097152], stride=[1]) buf9.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)] ] ``` To enable fusion, we invert the read, giving op9 a contiguous read and a discontiguous write. More explanation here: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9 [Tlparse with this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000). [Tlparse without this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000).
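To make the inversion concrete, here is a small standalone sketch (an editor's illustration, not code from this PR; it spells out the op9 index with plain integer arithmetic rather than the sympy-based analysis the PR adds). The read index above is a bijection on [0, 2097152) that only reorders fixed-width fields of `c0`, so each field can be recovered from the output and reassembled:

```Py
# Illustration only (editor's sketch, not the PR's implementation).
import numpy as np

n = 2097152
c0 = np.arange(n, dtype=np.int64)

# Forward mapping: the discontiguous read index of op9, with
# ModularIndexing(x, b, d) written out as (x // b) % d.
y = (
    32768 * (c0 // 32768)
    + 8192 * ((c0 % 16) // 4)
    + 256 * ((c0 // 16) % 32)
    + 4 * ((c0 // 512) % 64)
    + ((c0 % 16) % 4)
)

# Inverse mapping: pull each field back out of y and rebuild c0. This is the
# kind of formula that lets the scheduler move the discontiguity from op9's
# read over to its write.
c0_back = (
    32768 * (y // 32768)       # FloorDiv(c0, 32768) field
    + 4 * ((y // 8192) % 4)    # (c0 % 16) // 4 field
    + 16 * ((y // 256) % 32)   # (c0 // 16) % 32 field
    + 512 * ((y // 4) % 64)    # (c0 // 512) % 64 field
    + (y % 4)                  # (c0 % 16) % 4 field, i.e. c0 % 4
)

assert np.array_equal(c0_back, c0)
```

With the read inverted this way, op9 reads `buf6` contiguously and pushes the reordering into its own write, so it can fuse with the contiguous-write producer; the `loop_index_inversion_in_fusion` config flag added below gates this behavior.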
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161404 Approved by: https://github.com/shunting314 --- test/inductor/test_fp8.py | 236 +++++++++++++++++++++++- test/inductor/test_loop_ordering.py | 108 +++++++++++ torch/_inductor/config.py | 11 ++ torch/_inductor/invert_expr_analysis.py | 208 +++++++++++++++++++++ torch/_inductor/loop_body.py | 4 +- torch/_inductor/scheduler.py | 158 +++++++++++++++- 6 files changed, 722 insertions(+), 3 deletions(-) create mode 100644 torch/_inductor/invert_expr_analysis.py diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index f26a2347e4e86..f1067b8ffebb3 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -6,6 +6,7 @@ import torch from torch import Tensor +from torch._C import FileCheck from torch._inductor import config, utils from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.test_case import run_tests, TestCase @@ -29,7 +30,6 @@ HAS_CPU, HAS_CUDA_AND_TRITON, ) -from torch.testing._internal.jit_utils import FileCheck from torch.utils._triton import has_triton_tma_device @@ -953,6 +953,240 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @torch._inductor.config.patch("emulate_precision_casts", True) + def test_mx_fusion(self): + # Register fake_scaled_mm custom op scoped to this test + with torch.library._scoped_library("test_fp8", "FRAGMENT") as lib: + # Define the op schema + lib.define( + "fake_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, " + "Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, " + "bool use_fast_accum=False) -> Tensor" + ) + input_values = [] + + # Register CUDA implementation + @torch.library.impl(lib, "fake_scaled_mm", "CUDA") + def fake_scaled_mm_impl( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + ): + """Software-emulated scaled_mm for testing without CUDA 12.8""" + out_dtype = out_dtype or torch.bfloat16 + # just using add, because without real dtypes, + # was seeing overflow/instability + nonlocal input_values + input_values.append((mat_a, mat_b, scale_a, scale_b)) + result = mat_a.to(torch.float32) + mat_b.to(torch.float32) + if bias is not None: + result = result + bias.to(torch.float32) + return result.to(out_dtype) + + # Register fake implementation + @torch.library.impl(lib, "fake_scaled_mm", "Meta") + def fake_scaled_mm_meta( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + ): + """FakeTensor implementation""" + out_dtype = out_dtype or torch.bfloat16 + M, K = mat_a.shape + K2, N = mat_b.shape + torch._check( + K == K2, + lambda: f"Incompatible shapes: {mat_a.shape} @ {mat_b.shape}", + ) + return torch.empty((M, N), dtype=out_dtype, device=mat_a.device) + + def forward( + arg0_1, + arg1_1, + ): + view = torch.ops.aten.reshape.default(arg0_1, [8192, 256, 32]) + abs_1 = torch.ops.aten.abs.default(view) + amax = torch.ops.aten.amax.default(abs_1, [-1]) + unsqueeze = torch.ops.aten.unsqueeze.default(amax, -1) + view_1 = torch.ops.aten.view.dtype(unsqueeze, torch.int32) + bitwise_right_shift = torch.ops.aten.bitwise_right_shift.Tensor_Scalar( + view_1, 23 + ) + bitwise_and = torch.ops.aten.bitwise_and.Scalar( + bitwise_right_shift, 255 + ) + sub = 
torch.ops.aten.sub.Tensor(bitwise_and, 127) + sub_1 = torch.ops.aten.sub.Tensor(sub, 8) + clamp_min = torch.ops.aten.clamp_min.default(sub_1, -127) + clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 128) + add = torch.ops.aten.add.Tensor(clamp_max, 127) + convert_element_type = torch.ops.prims.convert_element_type.default( + add, torch.uint8 + ) + isnan = torch.ops.aten.isnan.default(unsqueeze) + scalar_tensor = torch.ops.aten.scalar_tensor.default( + 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + ) + where = torch.ops.aten.where.self( + isnan, scalar_tensor, convert_element_type + ) + convert_element_type_1 = torch.ops.prims.convert_element_type.default( + where, torch.int32 + ) + bitwise_left_shift = torch.ops.aten.bitwise_left_shift.Tensor_Scalar( + convert_element_type_1, 23 + ) + view_2 = torch.ops.aten.view.dtype(bitwise_left_shift, torch.float32) + clamp_min_1 = torch.ops.aten.clamp_min.default( + view_2, 1.1754943508222875e-38 + ) + div = torch.ops.aten.div.Tensor(view, clamp_min_1) + clamp_min_2 = torch.ops.aten.clamp_min.default(div, -448.0) + clamp_max_1 = torch.ops.aten.clamp_max.default(clamp_min_2, 448.0) + convert_element_type_2 = torch.ops.prims.convert_element_type.default( + clamp_max_1, torch.float8_e4m3fn + ) + view_3 = torch.ops.aten.reshape.default( + convert_element_type_2, [8192, 8192] + ) + convert_element_type_2 = None + view_4 = torch.ops.aten.view.dtype(where, torch.float8_e8m0fnu) + squeeze = torch.ops.aten.squeeze.dim(view_4, -1) + + view_5 = torch.ops.aten.reshape.default(arg1_1, [8192, 256, 32]) + abs_2 = torch.ops.aten.abs.default(view_5) + amax_1 = torch.ops.aten.amax.default(abs_2, [-1]) + unsqueeze_1 = torch.ops.aten.unsqueeze.default(amax_1, -1) + view_6 = torch.ops.aten.view.dtype(unsqueeze_1, torch.int32) + bitwise_right_shift_1 = ( + torch.ops.aten.bitwise_right_shift.Tensor_Scalar(view_6, 23) + ) + bitwise_and_1 = torch.ops.aten.bitwise_and.Scalar( + bitwise_right_shift_1, 255 + ) + sub_2 = torch.ops.aten.sub.Tensor(bitwise_and_1, 127) + sub_3 = torch.ops.aten.sub.Tensor(sub_2, 8) + clamp_min_3 = torch.ops.aten.clamp_min.default(sub_3, -127) + clamp_max_2 = torch.ops.aten.clamp_max.default(clamp_min_3, 128) + add_1 = torch.ops.aten.add.Tensor(clamp_max_2, 127) + convert_element_type_3 = torch.ops.prims.convert_element_type.default( + add_1, torch.uint8 + ) + isnan_1 = torch.ops.aten.isnan.default(unsqueeze_1) + unsqueeze_1 = None + scalar_tensor_1 = torch.ops.aten.scalar_tensor.default( + 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + ) + where_1 = torch.ops.aten.where.self( + isnan_1, scalar_tensor_1, convert_element_type_3 + ) + convert_element_type_4 = torch.ops.prims.convert_element_type.default( + where_1, torch.int32 + ) + bitwise_left_shift_1 = torch.ops.aten.bitwise_left_shift.Tensor_Scalar( + convert_element_type_4, 23 + ) + convert_element_type_4 = None + view_7 = torch.ops.aten.view.dtype(bitwise_left_shift_1, torch.float32) + bitwise_left_shift_1 = None + clamp_min_4 = torch.ops.aten.clamp_min.default( + view_7, 1.1754943508222875e-38 + ) + div_1 = torch.ops.aten.div.Tensor(view_5, clamp_min_4) + clamp_min_5 = torch.ops.aten.clamp_min.default(div_1, -448.0) + clamp_max_3 = torch.ops.aten.clamp_max.default(clamp_min_5, 448.0) + convert_element_type_5 = torch.ops.prims.convert_element_type.default( + clamp_max_3, torch.float8_e4m3fn + ) + view_8 = torch.ops.aten.reshape.default( + convert_element_type_5, [8192, 8192] + ) + view_9 = torch.ops.aten.view.dtype(where_1, torch.float8_e8m0fnu) + squeeze_1 = 
torch.ops.aten.squeeze.dim(view_9, -1) + + permute = torch.ops.aten.permute.default(view_8, [1, 0]) + + view_13 = torch.ops.aten.reshape.default(squeeze, [64, 128, 64, 4]) + permute_2 = torch.ops.aten.permute.default(view_13, [0, 2, 1, 3]) + clone = torch.ops.aten.clone.default( + permute_2, memory_format=torch.contiguous_format + ) + view_14 = torch.ops.aten.reshape.default(clone, [4096, 4, 32, 4]) + permute_3 = torch.ops.aten.permute.default(view_14, [0, 2, 1, 3]) + clone_1 = torch.ops.aten.clone.default( + permute_3, memory_format=torch.contiguous_format + ) + view_15 = torch.ops.aten.reshape.default(clone_1, [4096, 32, 16]) + + view_16 = torch.ops.aten.reshape.default(view_15, [2097152]) + + view_18 = torch.ops.aten.reshape.default(squeeze_1, [64, 128, 64, 4]) + permute_5 = torch.ops.aten.permute.default(view_18, [0, 2, 1, 3]) + clone_2 = torch.ops.aten.clone.default( + permute_5, memory_format=torch.contiguous_format + ) + view_19 = torch.ops.aten.reshape.default(clone_2, [4096, 4, 32, 4]) + permute_6 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]) + clone_3 = torch.ops.aten.clone.default( + permute_6, memory_format=torch.contiguous_format + ) + view_20 = torch.ops.aten.reshape.default(clone_3, [4096, 32, 16]) + + view_21 = torch.ops.aten.reshape.default(view_20, [2097152]) + + _scaled_mm = torch.ops.test_fp8.fake_scaled_mm.default( + view_3, permute, view_16, view_21, None, None, torch.float32 + ) + return (_scaled_mm,) + + # Run with largest shape + M, K, N = 8192, 8192, 8192 + device = "cuda" + + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + f_c = torch.compile(fullgraph=True)(forward) + + _, code = run_and_get_code(f_c, A, B) + + FileCheck().check(".run(").check(".run(").check("fake_scaled_mm").run( + code[0] + ) + + for seed in range(5): + input_values.clear() + torch.manual_seed(seed) + # without dividing, outputs get way too large + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + + # Uses fake_scaled_mm custom op (no CUDA 12.8 needed!) 
+ torch._dynamo.reset() + torch.compile(forward)(A, B) + + torch._dynamo.reset() + with config.patch({"loop_index_inversion_in_fusion": False}): + torch.compile(forward)(A, B) + + assert len(input_values) == 2 + for i in range(4): + self.assertEqual( + input_values[0][i], + input_values[1][i], + msg=f"idx {i} seed {seed}", + ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index c77b3574b2227..051a5f5905997 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -16,6 +16,7 @@ from torch._inductor import config as inductor_config, ir, metrics from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.graph import GraphLowering +from torch._inductor.invert_expr_analysis import generate_inverse_formula from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_operators import realize @@ -1188,6 +1189,113 @@ def fn(nodes): torch.compile(f)(x) +class TestIndexInversion(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + gm = torch.fx.symbolic_trace(lambda: 0) + graph = GraphLowering(gm) + graph.scheduler = MockScheduler + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context(V.set_graph_handler(graph)) + + def _check_expr(self, expr, reconstruction, val_range): + import numpy as np + from sympy import lambdify + + assert len(expr.free_symbols) == 1 + p0 = next(iter(expr.free_symbols)) + + def floordiv_replacement(a, b): + """Replace FloorDiv(a, b) with a // b""" + return a // b + + def modularindexing_replacement(x, base, divisor): + """Replace ModularIndexing(x, base, divisor) with (x // base) % divisor""" + return (x // base) % divisor + + # Replace custom functions with sympy equivalents + expr_numpy_ready = expr.replace(FloorDiv, floordiv_replacement).replace( + ModularIndexing, modularindexing_replacement + ) + reconstruction_numpy_ready = reconstruction.replace( + FloorDiv, floordiv_replacement + ).replace(ModularIndexing, modularindexing_replacement) + + # Now lambdify with standard numpy + forward_func = lambdify(p0, expr_numpy_ready, modules="numpy") + inverse_func = lambdify(p0, reconstruction_numpy_ready, modules="numpy") + + test_values = np.arange(0, val_range, dtype=np.int64) + forward_values = forward_func(test_values).astype(np.int64) + recovered_values = inverse_func(forward_values).astype(np.int64) + torch.testing.assert_close(test_values, recovered_values) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._exit_stack.close() + + def test_original_complex_expression(self): + """Test the original motivating complex expression.""" + p0 = sympy.Symbol("p0") + expr = ( + 32768 * FloorDiv(p0, 32768) + + 8192 * FloorDiv(ModularIndexing(p0, 1, 16), 4) + + ModularIndexing(p0, 1, 4) + + 256 * ModularIndexing(p0, 16, 32) + + 4 * ModularIndexing(p0, 512, 64) + ) + + reconstruction = generate_inverse_formula(expr, p0) + self.assertIsNotNone(reconstruction) + self._check_expr(expr, reconstruction, 2097152) + + def test_inversion_cases(self): + """Test various expressions for correct inversion behavior.""" + p = sympy.Symbol("p") + + cases = [ + # (expression, should_be_invertible, test_range) + # Simple 2-term base-10 style: 10 = 1 × 10 ✓ + (10 * ModularIndexing(p, 10, 10) + ModularIndexing(p, 1, 10), True, 100), + # 
Simple 2-term base-2 style: 2 = 1 × 2 ✓ + (2 * ModularIndexing(p, 2, 2) + ModularIndexing(p, 1, 2), True, 4), + # 3-term decimal: 100 = 10×10, 10 = 1×10 ✓ + ( + 100 * FloorDiv(p, 100) + + 10 * FloorDiv(ModularIndexing(p, 1, 100), 10) + + ModularIndexing(p, 1, 10), + True, + 1000, + ), + (4 * p, False, 64), # expr and inverse not bijections + # when sorted, invertible + (ModularIndexing(p, 1, 10) + 10 * ModularIndexing(p, 10, 10), True, None), + # Wrong coefficient ratios: 4 ≠ 1×2 + (4 * ModularIndexing(p, 1, 8) + ModularIndexing(p, 8, 2), False, None), + ( + 100 * FloorDiv(p, 100) + 7 * ModularIndexing(p, 1, 100), + False, + None, + ), # Wrong ratios + (FloorDiv(p, 100) + FloorDiv(p, 10) + p, False, None), # Overlapping ranges + (p**2 + 10 * p + 1, False, None), # Quadratic + (sympy.sin(p) + sympy.cos(p), False, None), # Trigonometric + ] + + for expr, should_invert, test_range in cases: + reconstruction = generate_inverse_formula(expr, p) + + if should_invert: + self.assertIsNotNone(reconstruction, f"Expected invertible: {expr}") + # Test correctness on sample values + self._check_expr(expr, reconstruction, test_range) + else: + self.assertIsNone(reconstruction, f"Expected non-invertible: {expr}") + + if __name__ == "__main__": if HAS_GPU: run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 66eaf69dd59a8..aaf7fbd2f7f54 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -674,6 +674,17 @@ def use_autoheuristic(name: str) -> bool: == "1" ) + +# When trying to fuse two nodes, one with: +# a[contiguous_writes] = fn(...) +# and another node: +# b[contiguous_writes] = a[discontiguous_reads] +# If b is unary, and we can figure out an inverse formula for +# discontiguous writes, invert b as : +# b[inverse(discontiguous_writes)] = a[contiguous_reads] +# so that the nodes can fuse. for more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9 +loop_index_inversion_in_fusion: bool = True + # If fusing two nodes only save less then score_fusion_memory_threshold memory, # we should not bother fusing the nodes. # diff --git a/torch/_inductor/invert_expr_analysis.py b/torch/_inductor/invert_expr_analysis.py new file mode 100644 index 0000000000000..816482dba020c --- /dev/null +++ b/torch/_inductor/invert_expr_analysis.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass +from typing import Optional + +import sympy + +from torch._inductor.utils import _IntLike, argsort_sym +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from .virtualized import V + + +def static_eq(a: _IntLike, b: _IntLike) -> bool: + return V.graph.sizevars.statically_known_equals(a, b) + + +@dataclass +class Term: + coefficient: _IntLike + range: Optional[_IntLike] # None for unbounded + original_expr: sympy.Expr + reconstruction_multiplier: _IntLike # The multiplier needed for reconstruction + + +def generate_inverse_formula( + expr: sympy.Expr, var: sympy.Symbol +) -> Optional[sympy.Expr]: + """ + Analyze an expression to see if it matches a specific invertible pattern that we + know how to reverse. + + We're looking for expressions that are sums of terms where each term extracts a + distinct bounded range from the input variable, like: + + y = c₀*a₀ + c₁*a₁ + c₂*a₂ + ... 
+ cₙ*aₙ + + where each aᵢ must be one of these specific patterns: + - ModularIndexing(var, divisor, modulo) + - FloorDiv(ModularIndexing(var, 1, modulo), divisor) + - FloorDiv(var, divisor) + - var (the variable itself) + + The key pattern we need is: + - Coefficients are strictly decreasing: c₀ > c₁ > c₂ > ... > cₙ + - Each coefficient matches the product of ranges of later terms (mixed-radix property) + - Each term extracts a bounded range, creating non-overlapping "slots" + + If we find this pattern, we can generate the reconstruction transformation that + decomposes the variable and rebuilds it using the correct multipliers. + + EXAMPLE: + Input: 100*((p//100)) + 10*((p%100)//10) + (p%10) + + Returns the reconstruction expression: + remainder₀ = p + component₀ = remainder₀ // 100 # hundreds digit + remainder₁ = remainder₀ % 100 + component₁ = remainder₁ // 10 # tens digit + remainder₂ = remainder₁ % 10 + component₂ = remainder₂ # ones digit + result = component₀*100 + component₁*10 + component₂*1 + + This decomposes p into its components and rebuilds it using the original + multipliers, which should equal the input expression. + + Args: + expr: Expression to analyze (sum of terms with ModularIndexing, FloorDiv, etc.) + var: The variable being decomposed + + Returns: + None if not invertible, or the reconstruction expression + + References: + Mixed-radix systems: https://en.wikipedia.org/wiki/Mixed_radix + """ + # Step 1: Parse all terms + terms = parse_terms(expr, var) + if not terms: + return None + + # Step 2: Sort by coefficient (descending) + coeffs = [t.coefficient for t in terms] + idxs = reversed(argsort_sym(V.graph.sizevars.shape_env, coeffs)) + terms = [terms[i] for i in idxs] + + # Step 3: Check invertibility conditions + if not check_invertibility(terms): + return None + + return generate_reconstruction_expr(terms, var) + + +def parse_terms(expr: sympy.Expr, var: sympy.Symbol) -> Optional[list[Term]]: + """Parse expression into terms.""" + if not isinstance(expr, sympy.Add): + # Single term + term = parse_single_term(expr, var) + return [term] if term else [] + + terms = [] + for arg in expr.args: + term = parse_single_term(arg, var) + if term: + terms.append(term) + else: + return None # If any term fails to parse, fail completely + + return terms + + +def parse_single_term(term: sympy.Expr, var: sympy.Symbol) -> Optional[Term]: + """Parse a single term and extract coefficient, range, and reconstruction multiplier.""" + # Extract coefficient and expression parts + coefficient, expr_parts = term.as_coeff_mul() + + if len(expr_parts) == 0: + # Pure constant term + return Term( + coefficient=coefficient, + range=1, + original_expr=1, + reconstruction_multiplier=0, + ) + elif len(expr_parts) == 1: + expr = expr_parts[0] + else: + # Multiple non-constant factors, too complex + return None + + # Now determine the range and reconstruction multiplier + range_val, reconstruction_multiplier = analyze_expression_properties(expr, var) + if reconstruction_multiplier is None: + return None + + return Term( + coefficient=coefficient, + range=range_val, + original_expr=expr, + reconstruction_multiplier=reconstruction_multiplier, + ) + + +def analyze_expression_properties( + expr: sympy.Expr, var: sympy.Symbol +) -> tuple[Optional[_IntLike], Optional[_IntLike]]: + """Analyze an expression to determine its range and reconstruction multiplier.""" + # ModularIndexing(var, divisor, modulo) = (var // divisor) % modulo + if isinstance(expr, ModularIndexing): + x, div, mod = expr.args + if 
static_eq(x, var): + return mod, div # Range is mod, multiplier is div + + # FloorDiv cases + if isinstance(expr, FloorDiv): + base, divisor = expr.args + + # FloorDiv(ModularIndexing(var, 1, mod), div) = (var % mod) // div + if isinstance(base, ModularIndexing): + x, inner_div, mod = base.args + if static_eq(x, var) and static_eq(inner_div, 1): + range_val = FloorDiv(mod, divisor) + return range_val, divisor # Range is mod//div, multiplier is div + + # FloorDiv(var, divisor) = var // divisor (unbounded) + elif static_eq(base, var): + return None, divisor # Unbounded range, multiplier is div + + return None, None + + +def check_invertibility(terms: list[Term]) -> bool: + """Check if the terms represent an invertible transformation.""" + if not terms: + return False + + # Coefficients must be strictly decreasing + coeffs = [t.coefficient for t in terms] + if argsort_sym(V.graph.sizevars.shape_env, coeffs) != list( + reversed(range(len(coeffs))) + ): + return False + + # Check mixed-radix property: each coeff[i] = coeff[i+1] * range[i+1] + expected_coeff = 1 + for term in reversed(terms): + if not static_eq(term.coefficient, expected_coeff): + return False + if term.range is not None: + expected_coeff *= term.range + + return True + + +def generate_reconstruction_expr(terms: list[Term], var: sympy.Symbol) -> sympy.Expr: + y = var + reconstruction = sympy.S.Zero + remainder = y + + for i, term in enumerate(terms): + if i < len(terms) - 1: + component = FloorDiv(remainder, term.coefficient) + remainder = ModularIndexing(remainder, 1, term.coefficient) + else: + # Last term should also divide by its coefficient + component = FloorDiv(remainder, term.coefficient) + + reconstruction += component * term.reconstruction_multiplier + + return reconstruction diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 53ae1d8f63f6b..3921aa955a836 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -95,7 +95,6 @@ class LoopBody: """ indexing_exprs: dict[str, sympy.Expr] - indexing_exprs_name: dict[sympy.Expr, str] submodules: dict[str, Any] subblocks: dict[str, LoopBodyBlock] indirect_vars: list[sympy.Symbol] @@ -104,6 +103,9 @@ class LoopBody: memory_usage: dict[MemoryUsageType, list[MemoryEntry]] op_counts: collections.Counter[str] + # defined only temporarily + indexing_exprs_name: dict[sympy.Expr, str] + def __init__( self, fn, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index df1d2f729b34a..2930a33b465a6 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3345,7 +3345,10 @@ def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: ) break - if config.loop_ordering_after_fusion: + if ( + config.loop_ordering_after_fusion + or config.loop_index_inversion_in_fusion + ): nodes = self.fuse_nodes_once(nodes, is_reorder_round=True) return nodes @@ -4302,6 +4305,148 @@ def decide_fusion_fail_reason( return str(reasons) + def shared_data_after_inverting_indexing( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Attempts to enable fusion between two nodes by inverting indexing patterns. + + This optimization targets cases where node1 has a contiguous write and + node2 has a contiguous write but discontiguous read. By inverting the + indexing in node2's read and write operations, we can make them compatible + with node1 for potential fusion. 
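+
+        Schematically (pseudo-indexing, with ``g`` standing for the
+        discontiguous index map):
+
+            node1:  a[i] = fn(...)       # contiguous write
+            node2:  b[i] = a[g(i)]       # contiguous write, discontiguous read
+
+        If ``generate_inverse_formula`` can invert ``g``, node2 is rewritten
+        as ``b[g_inv(i)] = a[i]``: its read of ``a`` becomes contiguous and
+        lines up with node1's write, so the normal fusion scoring can
+        succeed.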
+ + Args: + node1: First scheduler node (source) + node2: Second scheduler node (target for inversion) + + Returns: + int: Fusion score if successful, 0 if optimization not applicable + """ + + if not config.loop_index_inversion_in_fusion: + return -1 + + if any(n.is_cpu() for n in [node1, node2]): + return -1 + + # Check for shared buffers between nodes + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + common_buffer_names = node1_buffer_names & node2_buffer_names + + if not common_buffer_names: + return -1 + + # only invert if node1 is single unmet dep + node2_unmet_dependencies = OrderedSet( + dep.name for dep in node2.unmet_dependencies + ) + if node2_unmet_dependencies - node1_buffer_names: + return -1 + + if len(node2_unmet_dependencies) > 1: + return -1 + + # Currently only handle single read/write operations + if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1: + return -1 + + node2_read = next(iter(node2.read_writes.reads)) + node2_write = next(iter(node2.read_writes.writes)) + + if not isinstance(node2_read, MemoryDep) or not isinstance( + node2_write, MemoryDep + ): + return -1 + + node1_writes = {dep.name: dep for dep in node1.read_writes.writes} + if node2_read.name not in node1_writes: + return -1 + + node1_write = node1_writes[node2_read.name] + + if not isinstance(node1_write, MemoryDep): + return -1 + + # We are checking for compatibility with the normalized node1 write + # then modifying node2 reads/writes. since the node1 write will be just used + # for compatibility, while node2 will be used in actual modification, just + # normalize node1 not node2. + node1_write = node1_write.normalize() + + if ( + node1_write.index != node2_write.index + and node1_write.size != node2_write.size + ): + return -1 + + if node2_read.size != node2_write.size or len(node2_read.var_names) != 1: + return -1 + + # Verify we have exactly two indexing expressions (one read, one write) + if len(node2._body.indexing_exprs) != 2: # type: ignore[attr-defined] + return -1 + + # No subblocks allowed for this optimization + if node2._body.subblocks: # type: ignore[attr-defined] + return -1 + + assert ( + "index0" in node2._body.indexing_exprs # type: ignore[attr-defined] + and "index1" in node2._body.indexing_exprs # type: ignore[attr-defined] + ) + + # Extract and verify single read expression + node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs()) # type: ignore[attr-defined] + if len(node2_read_exprs) != 1: + return -1 + + read_expr = next(iter(node2_read_exprs)) + + # Determine which index is for reading vs writing + if read_expr == node2._body.indexing_exprs["index0"]: # type: ignore[attr-defined] + read_expr_index = "index0" + write_expr_index = "index1" + else: + assert read_expr == node2._body.indexing_exprs["index1"] # type: ignore[attr-defined] + read_expr_index = "index1" + write_expr_index = "index0" + + from torch._inductor.invert_expr_analysis import generate_inverse_formula + + index_vars = node2._body.vars[0] # type: ignore[attr-defined] + if len(index_vars) != 1: + return -1 + + simplified_terms = [] + for term in sympy.Add.make_args(read_expr): + simplified_terms.append( + V.graph.sizevars.combine_modular_indexing_pairs(term) + ) + simplified_read_expr = sum(simplified_terms) + + inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0]) + + # formula is not invertible + if inverse_formula is None: + return -1 + + # === Apply Inversion === + + # Swap the 
indexing expressions using the inverse formula + node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[ # type: ignore[attr-defined] + write_expr_index + ] + node2._body.indexing_exprs[write_expr_index] = inverse_formula # type: ignore[attr-defined] + + # Refresh dependencies and calculate fusion score + node2.refresh_dependencies(True, False) # type: ignore[attr-defined] + score = self.score_fusion_memory(node1, node2) + + fusion_log.info("Shared memory after inversion: %d", score) + return score + def shared_data_after_reordering_loop( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> int: @@ -4686,6 +4831,7 @@ def can_fuse( del device2 shared_data_score = self.score_fusion_memory(node1, node2) + if ( can_reorder and shared_data_score < config.score_fusion_memory_threshold @@ -4702,6 +4848,16 @@ def can_fuse( smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size) shared_data_score = self.score_fusion_memory(node1, node2) + if ( + config.loop_index_inversion_in_fusion + and shared_data_score < config.score_fusion_memory_threshold + ): + new_shared_data_score = self.shared_data_after_inverting_indexing( + node1, node2 + ) + if new_shared_data_score >= 0: + shared_data_score = new_shared_data_score + if loop_ordering_log.isEnabledFor(logging.DEBUG): loop_ordering_log.debug( "%s and %s has %s shared data", From aba2fa32593c6d7cfa55d488814984c421eaafb7 Mon Sep 17 00:00:00 2001 From: Siddhartha Menon Date: Wed, 5 Nov 2025 16:55:51 +0000 Subject: [PATCH 063/651] Fix clang-21 warnings (#166859) Fixes compiler warnings thrown by Clang-21 Fixes #166755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166859 Approved by: https://github.com/aditew01, https://github.com/fadara01, https://github.com/malfet --- aten/src/ATen/cpu/vec/sve/vec_bfloat16.h | 84 +++++++++---------- .../src/ATen/native/cpu/GridSamplerKernel.cpp | 4 +- caffe2/CMakeLists.txt | 2 +- 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h index 9e0b189bdac89..757ef839f965a 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h +++ b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h @@ -191,7 +191,7 @@ class Vectorized { auto vals = svreinterpret_u16_bf16(values); vals = sveor_u16_x(ptrue, vals, mask); return svreinterpret_bf16_u16(vals); - }; + } Vectorized round() const; Vectorized tan() const; Vectorized tanh() const; @@ -349,47 +349,47 @@ Vectorized inline Vectorized::frac() const { return convert_float_bfloat16(v1, v2); \ } -DEFINE_BF16_FUNC_VIA_FLOAT(isnan); -DEFINE_BF16_FUNC_VIA_FLOAT(angle); -DEFINE_BF16_FUNC_VIA_FLOAT(acos); -DEFINE_BF16_FUNC_VIA_FLOAT(acosh); -DEFINE_BF16_FUNC_VIA_FLOAT(asin); -DEFINE_BF16_FUNC_VIA_FLOAT(atan); -DEFINE_BF16_FUNC_VIA_FLOAT(atanh); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign); -DEFINE_BF16_FUNC_VIA_FLOAT(erf); -DEFINE_BF16_FUNC_VIA_FLOAT(erfc); -DEFINE_BF16_FUNC_VIA_FLOAT(exp); -DEFINE_BF16_FUNC_VIA_FLOAT(exp2); -DEFINE_BF16_FUNC_VIA_FLOAT(expm1); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot); -DEFINE_BF16_FUNC_VIA_FLOAT(i0); -DEFINE_BF16_FUNC_VIA_FLOAT(i0e); -DEFINE_BF16_FUNC_VIA_FLOAT(digamma); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter); -DEFINE_BF16_FUNC_VIA_FLOAT(log); -DEFINE_BF16_FUNC_VIA_FLOAT(log2); -DEFINE_BF16_FUNC_VIA_FLOAT(log10); -DEFINE_BF16_FUNC_VIA_FLOAT(log1p); 
-DEFINE_BF16_FUNC_VIA_FLOAT(sin); -DEFINE_BF16_FUNC_VIA_FLOAT(sinh); -DEFINE_BF16_FUNC_VIA_FLOAT(cos); -DEFINE_BF16_FUNC_VIA_FLOAT(cosh); -DEFINE_BF16_FUNC_VIA_FLOAT(ceil); -DEFINE_BF16_FUNC_VIA_FLOAT(floor); -DEFINE_BF16_FUNC_VIA_FLOAT(round); -DEFINE_BF16_FUNC_VIA_FLOAT(tan); -DEFINE_BF16_FUNC_VIA_FLOAT(tanh); -DEFINE_BF16_FUNC_VIA_FLOAT(trunc); -DEFINE_BF16_FUNC_VIA_FLOAT(lgamma); -DEFINE_BF16_FUNC_VIA_FLOAT(sqrt); -DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal); -DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow); +DEFINE_BF16_FUNC_VIA_FLOAT(isnan) +DEFINE_BF16_FUNC_VIA_FLOAT(angle) +DEFINE_BF16_FUNC_VIA_FLOAT(acos) +DEFINE_BF16_FUNC_VIA_FLOAT(acosh) +DEFINE_BF16_FUNC_VIA_FLOAT(asin) +DEFINE_BF16_FUNC_VIA_FLOAT(atan) +DEFINE_BF16_FUNC_VIA_FLOAT(atanh) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign) +DEFINE_BF16_FUNC_VIA_FLOAT(erf) +DEFINE_BF16_FUNC_VIA_FLOAT(erfc) +DEFINE_BF16_FUNC_VIA_FLOAT(exp) +DEFINE_BF16_FUNC_VIA_FLOAT(exp2) +DEFINE_BF16_FUNC_VIA_FLOAT(expm1) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot) +DEFINE_BF16_FUNC_VIA_FLOAT(i0) +DEFINE_BF16_FUNC_VIA_FLOAT(i0e) +DEFINE_BF16_FUNC_VIA_FLOAT(digamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter) +DEFINE_BF16_FUNC_VIA_FLOAT(log) +DEFINE_BF16_FUNC_VIA_FLOAT(log2) +DEFINE_BF16_FUNC_VIA_FLOAT(log10) +DEFINE_BF16_FUNC_VIA_FLOAT(log1p) +DEFINE_BF16_FUNC_VIA_FLOAT(sin) +DEFINE_BF16_FUNC_VIA_FLOAT(sinh) +DEFINE_BF16_FUNC_VIA_FLOAT(cos) +DEFINE_BF16_FUNC_VIA_FLOAT(cosh) +DEFINE_BF16_FUNC_VIA_FLOAT(ceil) +DEFINE_BF16_FUNC_VIA_FLOAT(floor) +DEFINE_BF16_FUNC_VIA_FLOAT(round) +DEFINE_BF16_FUNC_VIA_FLOAT(tan) +DEFINE_BF16_FUNC_VIA_FLOAT(tanh) +DEFINE_BF16_FUNC_VIA_FLOAT(trunc) +DEFINE_BF16_FUNC_VIA_FLOAT(lgamma) +DEFINE_BF16_FUNC_VIA_FLOAT(sqrt) +DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal) +DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow) Vectorized inline Vectorized::operator==( const Vectorized& other) const { diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 7587988528ebb..73f8c136794ce 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -293,7 +293,7 @@ struct ComputeLocationBase { , empty(size <= 0) {} inline Vec unnormalize(const Vec &in) const { - return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5); + return (in + Vec(static_cast(1))) * Vec(scaling_factor) - Vec(static_cast(0.5)); } inline Vec clip_coordinates(const Vec &in) const { @@ -831,7 +831,7 @@ struct ApplyGridSample(-0.75)); ApplyGridSample(const TensorAccessor& input) : inp_H(input.size(2)) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0e86e826405c6..e1cc43350b2b6 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1307,7 +1307,7 @@ endif() if(USE_MKLDNN_ACL) find_package(ACL REQUIRED) - target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS}) + target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS}) endif() target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) From d4dcd0354c4affcd90417f213785fc762e1b2b2f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 5 Nov 2025 19:43:40 +0800 Subject: [PATCH 064/651] [pytree][dynamo] add test to ensure `tree_map` preserves `dict` order (#166236) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166236 Approved by: 
https://github.com/mlazos --- test/dynamo/test_misc.py | 24 ++++++++++++++++ test/test_pytree.py | 18 ++++++++++++ torch/_dynamo/polyfills/pytree.py | 47 ++++++++++++++++++++++++++++--- 3 files changed, 85 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b8727208a5bfa..b3e9df6a25cf3 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -13194,6 +13194,30 @@ def fn(x, y): self.assertEqual(actual, expected) + @parametrize_pytree_module + def test_pytree_tree_map_dict_order(self, pytree): + def fn(tree): + new_tree = pytree.tree_map(lambda x: x, tree) + return list(new_tree.keys()), list(new_tree.values()) + + x = torch.randn(3, 2) + fn_opt = torch.compile(fullgraph=True)(fn) + + tree1 = {"b": x + 2, "a": x, "c": x - 1} + expected1 = fn(tree1) + actual1 = fn_opt(tree1) + self.assertEqual(actual1, expected1) + + tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)]) + expected2 = fn(tree2) + actual2 = fn_opt(tree2) + self.assertEqual(actual2, expected2) + + tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1}) + expected3 = fn(tree3) + actual3 = fn_opt(tree3) + self.assertEqual(actual3, expected3) + @parametrize_pytree_module def test_pytree_tree_map_only(self, pytree): if not callable(getattr(pytree, "tree_map_only", None)): diff --git a/test/test_pytree.py b/test/test_pytree.py index 7cc3b8affc0ef..09cf0bbd47a43 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -601,6 +601,24 @@ def f(x, y, z): for case in cases: run_test(case) + @parametrize_pytree_module + def test_tree_map_dict_order(self, pytree): + d = {"b": 2, "a": 1, "c": 3} + od = OrderedDict([("b", 2), ("a", 1), ("c", 3)]) + dd = defaultdict(int, {"b": 2, "a": 1, "c": 3}) + for tree in (d, od, dd): + result = pytree.tree_map(lambda x: x, tree) + self.assertEqual( + list(result.keys()), + list(tree.keys()), + msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}", + ) + self.assertEqual( + list(result.values()), + list(tree.values()), + msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. 
{result!r}", + ) + @parametrize_pytree_module def test_tree_map_only(self, pytree): self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index d86fe054b2ebc..b4de3200e2960 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -6,7 +6,7 @@ from collections import deque from dataclasses import dataclass, field -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import TypeIs import torch.utils._pytree as python_pytree @@ -24,9 +24,15 @@ __all__: list[str] = [] +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + if python_pytree._cxx_pytree_dynamo_traceable: import optree import optree._C + import optree.utils import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 @@ -600,14 +606,47 @@ def tree_map_( __all__ += ["tree_map_"] - _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined] + _none_registration = optree.register_pytree_node.get(type(None)) + assert _none_registration is not None @substitute_in_graph( # type: ignore[arg-type] - _none_unflatten, + _none_registration.unflatten_func, can_constant_fold_through=True, skip_signature_check=True, ) - def none_unflatten(_: None, children: Iterable[Any], /) -> None: + def none_unflatten(_: None, children: Iterable[_T], /) -> None: if len(list(children)) != 0: raise ValueError("Expected no children.") return None + + with optree.dict_insertion_ordered(False, namespace="torch"): + _dict_registration = optree.register_pytree_node.get(dict) + assert _dict_registration is not None + + @substitute_in_graph( # type: ignore[arg-type] + _dict_registration.flatten_func, + can_constant_fold_through=True, + skip_signature_check=True, + ) + def dict_flatten( + dct: dict[_KT, _VT], / + ) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: + sorted_keys = optree.utils.total_order_sorted(dct) + values = [dct[key] for key in sorted_keys] + original_keys = list(dct) + return values, (original_keys, sorted_keys), tuple(sorted_keys) + + @substitute_in_graph( # type: ignore[arg-type] + _dict_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, + ) + def dict_unflatten( + metadata: tuple[list[_KT], list[_KT]], + values: Iterable[_VT], + /, + ) -> dict[_KT, _VT]: + original_keys, sorted_keys = metadata + d = dict.fromkeys(original_keys) + d.update(zip(sorted_keys, values)) + return d # type: ignore[return-value] From 9c2c3dbc156a0eae1212ec3e51109d83a4922c9b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 17:12:30 +0000 Subject: [PATCH 065/651] Revert "Update triton to 3.5.1 release (#166968)" This reverts commit b4e4ee81d386db922d8f63359f9870eff1f44052. 
Reverted https://github.com/pytorch/pytorch/pull/166968 on behalf of https://github.com/malfet due to It might have caused deadlock/test timeouts, see https://hud.pytorch.org/hud/pytorch/pytorch/d4dcd0354c4affcd90417f213785fc762e1b2b2f/1?per_page=50&name_filter=trunk%20%2F%20linux-jammy-cuda12.8-py3.10-gcc11%20%2F%20test&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/166968#issuecomment-3492399396)) --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 7aab8bed1c108..10f1207e60e6c 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 +7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index d5c0c99142898..1545d966571dc 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.1 +3.5.0 From f93ee16fb68b480a183348b99445fb089f4a5c30 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 5 Nov 2025 17:19:24 +0000 Subject: [PATCH 066/651] [CI] Parse xml and upload json while running (#166988) Then we can point an ClickHouse ingestor at this s3 path and get them into ClickHouse while the job is running. use filelock to make sure each json is uploaded once so we don't end up with dups in ClickHouse Pull Request resolved: https://github.com/pytorch/pytorch/pull/166988 Approved by: https://github.com/izaitsevfb --- test/run_test.py | 18 +++++++- tools/stats/upload_test_stats.py | 6 ++- tools/testing/upload_artifacts.py | 70 ++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 4 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 448fbc28751f3..aa6a6d04cde3e 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -73,7 +73,22 @@ ShardedTest, THRESHOLD, ) -from tools.testing.upload_artifacts import zip_and_upload_artifacts + + +try: + from tools.testing.upload_artifacts import ( + parse_xml_and_upload_json, + zip_and_upload_artifacts, + ) +except ImportError: + # some imports in those files might fail, e.g., boto3 not installed. These + # functions are only needed under specific circumstances (CI) so we can + # define dummy functions here. 
+ def parse_xml_and_upload_json(): + pass + + def zip_and_upload_artifacts(failed: bool): + pass # Make sure to remove REPO_ROOT after import is done @@ -1887,6 +1902,7 @@ def run_tests( def handle_complete(failure: Optional[TestFailure]): failed = failure is not None if IS_CI and options.upload_artifacts_while_running: + parse_xml_and_upload_json() zip_and_upload_artifacts(failed) if not failed: return False diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index 6c0232c5e5a17..b2b0869d48350 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -38,12 +38,14 @@ def parse_xml_report( report: Path, workflow_id: int, workflow_run_attempt: int, + job_id: int | None = None, ) -> list[dict[str, Any]]: """Convert a test report xml file into a JSON-serializable list of test cases.""" print(f"Parsing {tag}s for test report: {report}") - job_id = get_job_id(report) - print(f"Found job id: {job_id}") + if job_id is None: + job_id = get_job_id(report) + print(f"Found job id: {job_id}") test_cases: list[dict[str, Any]] = [] diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py index 07b62ec9a1b74..49d68fe9959ae 100644 --- a/tools/testing/upload_artifacts.py +++ b/tools/testing/upload_artifacts.py @@ -1,11 +1,16 @@ import glob import gzip +import json import os import time import zipfile from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, Optional + +from filelock import FileLock, Timeout + +from tools.stats.upload_test_stats import parse_xml_report REPO_ROOT = Path(__file__).resolve().parent.parent.parent @@ -140,3 +145,66 @@ def trigger_upload_test_stats_intermediate_workflow() -> None: }, ) print(x.text) + + +def parse_xml_and_upload_json() -> None: + """ + Parse xml test reports that do not yet have a corresponding json report + uploaded to s3, and upload the json reports to s3. Use filelock to avoid + uploading the same file from multiple processes. + """ + try: + job_id: Optional[int] = int(os.environ.get("JOB_ID", 0)) + if job_id == 0: + job_id = None + except (ValueError, TypeError): + job_id = None + + try: + for xml_file in glob.glob( + f"{REPO_ROOT}/test/test-reports/**/*.xml", recursive=True + ): + xml_path = Path(xml_file) + json_file = xml_path.with_suffix(".json") + lock = FileLock(str(json_file) + ".lock") + + try: + lock.acquire(timeout=0) # immediately fails if already locked + if json_file.exists(): + continue # already uploaded + test_cases = parse_xml_report( + "testcase", + xml_path, + int(os.environ.get("GITHUB_RUN_ID", "0")), + int(os.environ.get("GITHUB_RUN_ATTEMPT", "0")), + job_id, + ) + line_by_line_jsons = "\n".join([json.dumps(tc) for tc in test_cases]) + + gzipped = gzip.compress(line_by_line_jsons.encode("utf-8")) + s3_key = ( + json_file.relative_to(REPO_ROOT / "test/test-reports") + .as_posix() + .replace("/", "_") + ) + + get_s3_resource().put_object( + Body=gzipped, + Bucket="gha-artifacts", + Key=f"test_jsons_while_running/{os.environ.get('GITHUB_RUN_ID')}/{job_id}/{s3_key}", + ContentType="application/json", + ContentEncoding="gzip", + ) + + # We don't need to save the json file locally, but doing so lets us + # track which ones have been uploaded already. 
We could probably also + # check S3 + with open(json_file, "w") as f: + f.write(line_by_line_jsons) + except Timeout: + continue # another process is working on this file + finally: + if lock.is_locked: + lock.release() + except Exception as e: + print(f"Failed to parse and upload json test reports: {e}") From 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23 Mon Sep 17 00:00:00 2001 From: karthickai Date: Mon, 3 Nov 2025 13:25:27 -0800 Subject: [PATCH 067/651] [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890) When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error. Fixes: #166888 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890 Approved by: https://github.com/eellison --- test/inductor/test_torchinductor.py | 14 ++++++++++++++ torch/_inductor/codecache.py | 6 ++++++ torch/_inductor/codegen/common.py | 11 +++++++++-- torch/_inductor/codegen/triton_utils.py | 5 +++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ed8993a1c9a39..d0ff5799ac417 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14424,6 +14424,20 @@ def fn(x): self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),)) + @skip_if_halide + @requires_cuda_and_triton + def test_unbacked_float_item(self): + def fn(x, max_val): + return torch.clamp(x, 0, max_val.item()) + + self.common( + fn, + ( + torch.randn(10, 20, 30, device=self.device), + torch.tensor(5.0, device=self.device), + ), + ) + # end of class CommonTemplate - add new tests here diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index cf17bf2e9478b..85702057cbb43 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache): throw std::runtime_error("expected int arg"); return reinterpret_cast(result); }} + template <> inline float parse_arg(PyObject* args, size_t n) {{ + auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1.0 && PyErr_Occurred())) + throw std::runtime_error("expected float arg"); + return static_cast(result); + }} {extra_parse_arg} diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 730c03f1c813c..3e9f174c810c5 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1732,9 +1732,15 @@ def cpp_argdefs( call_args.append(self.wrap_ptr_arg(outer, dtype)) arg_types.append(f"{cpp_dtype}*") for outer, inner in self.sizevars.items(): - arg_defs.append(f"const {INDEX_TYPE} {inner}") + if isinstance(outer, sympy.Symbol) and symbol_is_type( + outer, (SymT.UNBACKED_FLOAT) + ): + arg_defs.append(f"const float {inner}") + arg_types.append("const float") + else: + arg_defs.append(f"const {INDEX_TYPE} {inner}") + arg_types.append(f"const {INDEX_TYPE}") call_args.append(self.wrap_size_arg(outer)) - arg_types.append(f"const {INDEX_TYPE}") if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) assert not self.workspace_args, "Workspace not supported on CPU " @@ -2353,6 +2359,7 @@ def rename_indexing( SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, + SymT.UNBACKED_FLOAT, ), ) } diff --git a/torch/_inductor/codegen/triton_utils.py 
b/torch/_inductor/codegen/triton_utils.py index 2a2706ad5720b..75a34813c876b 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -4,6 +4,7 @@ import sympy import torch +from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config from ..runtime.hints import AttrsDescriptorWrapper @@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: return "constexpr" elif isinstance(arg.expr, (float, sympy.Float)): return "fp32" + elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type( + arg.expr, (SymT.UNBACKED_FLOAT) + ): + return "fp32" elif isinstance(arg.expr, bool): return "i1" From 4ff068c33a0beda5df88cd373a4fb70b5a68e554 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Wed, 5 Nov 2025 17:59:12 +0000 Subject: [PATCH 068/651] [Code Clean] Replace `assert` with if statement and raise `AssertionError` (#166935) Including: - `torch/profiler/profiler.py` Fixes part of #164878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935 Approved by: https://github.com/fffrog, https://github.com/albanD --- torch/profiler/profiler.py | 43 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index ee0ea85e1694b..893b4078cb9ce 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -210,7 +210,8 @@ def prepare_trace(self) -> None: def start_trace(self) -> None: if self.execution_trace_observer: self.execution_trace_observer.start() - assert self.profiler is not None + if self.profiler is None: + raise AssertionError("Profiler must be initialized before starting trace") self.profiler._start_trace() if self.profile_memory: @@ -256,7 +257,8 @@ def start_trace(self) -> None: def stop_trace(self) -> None: if self.execution_trace_observer: self.execution_trace_observer.stop() - assert self.profiler is not None + if self.profiler is None: + raise AssertionError("Profiler must be initialized before stopping trace") self.profiler.__exit__(None, None, None) def export_chrome_trace(self, path: str): @@ -264,7 +266,10 @@ def export_chrome_trace(self, path: str): Exports the collected trace in Chrome JSON format. If kineto is enabled, only last cycle in schedule is exported. """ - assert self.profiler + if self.profiler is None: + raise AssertionError( + "Profiler must be initialized before exporting chrome trace" + ) if path.endswith(".gz"): fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) fp.close() @@ -284,7 +289,8 @@ def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): path (str): save stacks file to this location; metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total" """ - assert self.profiler + if self.profiler is None: + raise AssertionError("Profiler must be initialized before exporting stacks") return self.profiler.export_stacks(path, metric) def toggle_collection_dynamic( @@ -316,7 +322,7 @@ def toggle_collection_dynamic( print(p.key_averages().table( sort_by="self_cuda_time_total", row_limit=-1)) """ - if not self.profiler: + if self.profiler is None: return self.profiler.toggle_collection_dynamic(enable, activities) @@ -333,7 +339,10 @@ def key_averages( To use shape/stack functionality make sure to set record_shapes/with_stack when creating profiler context manager. 
""" - assert self.profiler + if self.profiler is None: + raise AssertionError( + "Profiler must be initialized before getting key averages" + ) return self.profiler.key_averages( group_by_input_shape, group_by_stack_n, group_by_overload_name ) @@ -343,7 +352,8 @@ def events(self): Returns the list of unaggregated profiler events, to be used in the trace callback or after the profiling is finished """ - assert self.profiler + if self.profiler is None: + raise AssertionError("Profiler must be initialized before accessing events") return self.profiler.function_events def add_metadata(self, key: str, value: str) -> None: @@ -395,7 +405,10 @@ def _memory_profile(self) -> MemoryProfile: if missing: raise ValueError(f"{', '.join(missing)} required for memory profiling.") - assert self.profiler is not None and self.profiler.kineto_results is not None + if self.profiler is None or self.profiler.kineto_results is None: + raise AssertionError( + "Profiler and kineto_results must be initialized for memory profiling" + ) return MemoryProfile(self.profiler.kineto_results) def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: @@ -485,7 +498,8 @@ def schedule( """ def schedule_fn(step: int) -> ProfilerAction: - assert step >= 0 + if step < 0: + raise AssertionError(f"Step must be non-negative. Got {step}.") if step < skip_first: return ProfilerAction.NONE else: @@ -508,9 +522,11 @@ def schedule_fn(step: int) -> ProfilerAction: else ProfilerAction.RECORD_AND_SAVE ) - assert ( - wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0 - ), "Invalid profiler schedule arguments" + if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0: + raise AssertionError( + f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), " + f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)." + ) if warmup == 0: warn( "Profiler won't be using warmup, this can skew profiler results", @@ -717,7 +733,8 @@ def __init__( activities_set.add(ProfilerActivity.CUDA) elif ProfilerActivity.CUDA in activities_set: activities_set.remove(ProfilerActivity.CUDA) - assert len(activities_set) > 0, "No valid profiler activities found" + if len(activities_set) == 0: + raise AssertionError("No valid profiler activities found") super().__init__( activities=activities, From c17aa0f11303bcd2cf617efd0cda6f3d38a1a34b Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 5 Nov 2025 18:03:59 +0000 Subject: [PATCH 069/651] [ROCm] Enable group gemm through CK (#166334) Fixes #161366 All the 4 types of dimension matrix are supported. 2d-2d, 2d-3d, 3d-3d, 3d-2d. The corresponding test cases in test_matmul_cuda are working for both forward and backward pass. The CK path is enabled for gfx942, gfx950. ToDo: Need to enable support on gfx90a since the ck kernel used in this commit produces gpu error, might require a different CK kernel config, based on the profiler result on gfx90a. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334 Approved by: https://github.com/atalman --- aten/src/ATen/native/cuda/GroupedBlas.cpp | 10 + aten/src/ATen/native/hip/ck_group_gemm.h | 19 + aten/src/ATen/native/hip/ck_group_gemm.hip | 462 +++++++++++++++++++++ test/test_matmul_cuda.py | 2 - 4 files changed, 491 insertions(+), 2 deletions(-) create mode 100644 aten/src/ATen/native/hip/ck_group_gemm.h create mode 100644 aten/src/ATen/native/hip/ck_group_gemm.hip diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index f64eb317d0cca..18ae048cfc968 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -22,6 +22,9 @@ #include #include #include +#ifdef USE_ROCM +#include +#endif #include #ifdef USE_FBGEMM_GENAI @@ -666,12 +669,19 @@ std::optional out_dtype) { // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used. // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm bool use_fast_path = false; + if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) { + use_fast_path = true; + } #endif const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); if (use_fast_path) { // fast path, no d2h sync needed +#ifndef USE_ROCM at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); +#else + at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out); +#endif } else { _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); } diff --git a/aten/src/ATen/native/hip/ck_group_gemm.h b/aten/src/ATen/native/hip/ck_group_gemm.h new file mode 100644 index 0000000000000..c50307c9f8ea3 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace hip { +namespace detail { +void group_gemm_ck( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + const std::optional& bias, + at::Tensor& out); + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/aten/src/ATen/native/hip/ck_group_gemm.hip b/aten/src/ATen/native/hip/ck_group_gemm.hip new file mode 100644 index 0000000000000..c436ad660c1c7 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.hip @@ -0,0 +1,462 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at { +namespace hip { +namespace detail { + +namespace CkTypes { + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using F32 = float; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; +} + +template +using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< + ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, + DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType, + CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough, + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + 1, 1, + S<1,32,1,8>, 4 +>; + +template +void launch_grouped_bgemm_ck_impl_dispatch( + const 
at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + at::Tensor& out) +{ + using DeviceOp = GroupedGemmKernel; + using PassThrough = CkTypes::PassThrough; + + std::vector gemm_descs; + std::vector p_a_ptrs, p_b_ptrs; + std::vector p_e_ptrs; + // Note: d_ptrs will be resized after we populate the other vectors + + const int mat_a_dim = mat_a.dim(); + const int mat_b_dim = mat_b.dim(); + + const char* a_ptr_base = reinterpret_cast(mat_a.data_ptr()); + const char* b_ptr_base = reinterpret_cast(mat_b.data_ptr()); + char* out_ptr_base = reinterpret_cast(out.data_ptr()); + const size_t a_element_size = mat_a.element_size(); + const size_t b_element_size = mat_b.element_size(); + const size_t out_element_size = out.element_size(); + + // for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses. + if (mat_a_dim == 2 && mat_b_dim == 2) { + // 2D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + const int M = mat_a.size(0); // number of rows in A + const int N = mat_b.size(1); // number of columns in B + const int K = mat_a.size(1); // columns in A == rows in B + // for 2d*2d input, output is 3d. + // for each group, A columns (K) are sliced. M and N dimensions are not sliced. + for (int i = 0; i < num_groups; ++i) { + int start_k = (i == 0) ? 0 : offs_accessor[i-1]; + int end_k = offs_accessor[i]; + int k = end_k - start_k; + + //K dimension are sliced, hence select stride(1) always. + //K dimension is always dimension 1, regardless of memory layout (row/column major) + const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size; + const void* group_b_ptr; + int ldb; + + if (std::is_same::value) { + // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size; + // Leading dimension = distance between rows = stride(0) + ldb = mat_b.stride(0); + } else { + // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size; + // Leading dimension = distance between columns = stride(1) + ldb = mat_b.stride(1); + } + + // Calculate output pointer for group i in 3D tensor [num_groups, M, N] + // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + int lda, ldc; + if (std::is_same::value) { + // Row-major A [M,K]: leading dimension = distance between rows = stride(0) + lda = mat_a.stride(0); + } else { + // Column-major A [M,K]: leading dimension = distance between columns = stride(1) + lda = mat_a.stride(1); + } + // Output is always row-major in 3D tensor [num_groups, M, N] + // Leading dimension for each group's [M,N] slice = stride(1) = N + ldc = out.stride(1); + size_t output_group_bytes = M * N * out_element_size; + void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes; + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(k), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 2 && mat_b_dim == 3) { + // 2D*3D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + + // 2d*3d input, output is 2d. 
+ // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n] + // Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B + const int K = mat_a.size(1); // columns in A + // For 2D-3D case: The output determines N (result width) + const int N = out.size(1); // N is the width of the output tensor + + for (int i = 0; i < num_groups; ++i) { + int start_m = (i == 0) ? 0 : offs_accessor[i - 1]; + int end_m = offs_accessor[i]; + int m = end_m - start_m; + + // Skip zero-sized groups but continue processing subsequent groups + if (m <= 0) { + continue; + } + + // Select A rows for group i: skip start_m rows + const void* group_a_ptr; + int lda; + if (std::is_same::value) { + // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + lda = mat_a.stride(0); // distance between rows + } else { + // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows) + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + + // Detect stride pattern for A tensor to determine appropriate lda calculation + bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0)); + + if (a_is_strided_tensor) { + // For strided A tensors: stride(0) gives the actual leading dimension + lda = mat_a.stride(0); + } else { + // For non-strided A tensors: use the M dimension (total rows) + lda = mat_a.size(0); // Total M dimension for column-major layout + } + } + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + int ldb; + + if (std::is_same::value) { + // Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed + ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N] + } else { + // Detect stride pattern to determine appropriate ldb calculation + bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2)); + + if (is_strided_tensor) { + // For strided tensors: stride(2) gives the actual leading dimension + ldb = mat_b.stride(2); + } else { + // For non-strided tensors: use the N dimension + ldb = mat_b.size(1); + } + } + + // Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N] + void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size; + int ldc = out.stride(0); // distance between rows in output (should be N for 2D case) + + gemm_descs.push_back({ + static_cast(m), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 3) { + // 3d*3d input, output is 3d - batched matrix multiplication + // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n] + // Each batch is processed as a separate GEMM operation + const int batch_size = mat_a.size(0); + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed) + + // Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout + int N; + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + N = mat_b.size(2); + } else if (mat_b.size(2) == K) { + // B is [batch, n, k] - transposed layout + N = mat_b.size(1); + } else { + 
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[", + batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]"); + } + + for (int i = 0; i < batch_size; ++i) { + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + + // Select output batch for group i: Output[i, :, :] + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + + int lda, ldb, ldc; + + if (std::is_same::value) { + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + } else { + // Column-major A: leading dimension = distance between columns = stride(2) + lda = mat_a.stride(2); + } + + if (std::is_same::value) { + // Row-major B: leading dimension = distance between rows + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(1); // stride between K rows + } else { + // B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM + ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n]) + } + } else { + // Column-major B: leading dimension = distance between columns + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(2); // stride between N columns + } else { + // B is [batch, n, k] - transposed layout + ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n]) + } + } + + // Output is typically row-major: leading dimension = distance between rows = stride(1) + ldc = out.stride(1); + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 2) { + // 3D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + // 3d*2d input, output is 3d. + // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both) + // Offset divides N dimension of B, each group gets different slice of B and different batch of A + const int batch_size = mat_a.size(0); // n_groups + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A + + // For row-major A and B case: B should be [K, total_N] + const int total_N = mat_b.size(1); // B is [K, total_N] for row-major + + for (int i = 0; i < num_groups; ++i) { + int start_n = (i == 0) ? 
0 : offs_accessor[i - 1]; + int end_n = offs_accessor[i]; + int n = end_n - start_n; + + // Skip zero-sized groups but continue processing subsequent groups + if (n <= 0) { + continue; + } + + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B slice for group i: B[:, start_n:end_n] (B[K, total_N]) + const void* group_b_ptr; + int ldb; + + // Check if B is row-major or column-major + if (std::is_same::value) { + // Row-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(0); // distance between rows (should be total_N) + } else { + // Column-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(1); // distance between columns (should be K) + } + + // Select output slice for group i: Output[:, start_n:end_n] + void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size; + + int lda, ldc; + + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + // Output is row-major: leading dimension = distance between rows = stride(0) + ldc = out.stride(0); + + gemm_descs.push_back({ + static_cast(M), + static_cast(n), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim); + } + + TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups"); + + // Initialize d_ptrs with the correct size + std::vector> d_ptrs(p_a_ptrs.size()); + + static DeviceOp gemm_instance; + auto argument = gemm_instance.MakeArgument( + p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, + gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} + ); + TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument), + "CK Group GEMM: argument unsupported (shape/strides/type config)"); + size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument); + size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument); + + void* gemm_arg_buf = nullptr; + void* ws_buf = nullptr; + + hipMalloc(&gemm_arg_buf, arg_buf_size); + hipMalloc(&ws_buf, ws_size); + + gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf); + gemm_instance.SetWorkSpacePointer(&argument, ws_buf); + + auto invoker = gemm_instance.MakeInvoker(); + hipStream_t stream = c10::hip::getCurrentHIPStream(); + invoker.Run(argument, {stream}); + hipFree(gemm_arg_buf); + hipFree(ws_buf); +} + +void group_gemm_ck( + const at::Tensor& input_a, + const at::Tensor& input_b_colmajor, + const std::optional& offs, + const std::optional& /*bias*/, + at::Tensor& out) +{ + // Detect if input_a is row-major based on stride pattern + bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1); + bool b_col_major = (input_b_colmajor.dim() == 3) ? 
(input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1); + // Ensure tensor A is row-major and contiguous if not already + at::Tensor mat_a = input_a; + if (!a_row_major) { + // If A is not row-major, make it contiguous (row-major) + mat_a = input_a.contiguous(); + } + // Force tensor B to be column-major using double transpose trick + // This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape + at::Tensor mat_b = input_b_colmajor; + if (!b_col_major) { + mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1); + } + + // For 3D tensors, check the last dimension stride for row-major detection + a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1); + bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1); + + if (mat_a.dtype() == at::kBFloat16) { + // bf16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kHalf) { + // fp16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kFloat) { + // fp32 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype"); + } + +} + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 1ba947befd9e7..10611d4f24673 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -490,8 +490,6 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): @parametrize("b_row_major", [False, True]) @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): - if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]: - self.skipTest("failed using hipblaslt on rocm 6.4.2") device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 64, 4 From c86540f12038ffc4a3c9ecdbecb01ce73e0967c9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 18:11:11 +0000 Subject: [PATCH 070/651] Revert "Add model code stack trace to torch.profile (#166677)" This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8. 
Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160)) --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ------------------ torch/autograd/profiler_util.py | 40 ---- torch/fx/graph.py | 23 --- torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +--------------- 6 files changed, 5 insertions(+), 425 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 12f6ba2228db8..a404e15a977ee 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index e12189dfea461..92d35fd8f49ad 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,12 +75,6 @@ ) from torch.testing._internal.jit_utils import JitTestCase -import json -import tempfile -from torch.profiler import profile, ProfilerActivity -from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace -from torch.autograd.profiler_util import _canonicalize_profiler_events - try: from torchvision import models as torchvision_models @@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor): print(x) -def _enrich_profiler_traces(prof): - """ - Helper function to extract and augment profiler events with stack traces. 
- - Args: - prof: A torch.profiler.profile object - - Returns: - A string representing enriched events - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: - trace_file = f.name - prof.export_chrome_trace(trace_file) - - with open(trace_file) as f: - trace_data = json.load(f) - - map_recorded_events_to_aten_ops_with_stack_trace( - trace_data - ) - - events = [] - for event in trace_data["traceEvents"]: - if "args" in event and "stack_trace" in event["args"]: - events.append(event) - - actual_traces = _canonicalize_profiler_events(events) - return actual_traces - - class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4248,150 +4212,6 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_stack_trace_augmentation(self): - """ - Test that map_recorded_events_to_aten_ops_with_stack_trace correctly - augments profiler events with stack traces from FX metadata registry. - """ - - # Simple test model - class TestModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(16, 10) - - def forward(self, x): - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - return x - - model = TestModel().cuda() - - # Compile the model - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda")) - - # Profile with the compiled model - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - - self.assertExpectedInline(actual_traces, """\ -event=aten::t node=t stack_trace=x = self.linear1(x) -event=aten::transpose node=t stack_trace=x = self.linear1(x) -event=aten::as_strided node=t stack_trace=x = self.linear1(x) -event=aten::addmm node=addmm stack_trace=x = self.linear1(x) -event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) -event=aten::relu node=relu stack_trace=x = self.relu(x) -event=aten::clamp_min node=relu stack_trace=x = self.relu(x) -event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) -event=aten::t node=t_1 stack_trace=x = self.linear2(x) -event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) -event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) -event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) -event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_multiple_modules(self): - """ - Test that multiple compiled modules under the same profiler session - have their events correctly augmented with stack traces. 
- """ - - class ModelA(torch.nn.Module): - def forward(self, x): - return x + 1 - - class ModelB(torch.nn.Module): - def forward(self, x): - return x - 1 - - model_a = ModelA().cuda() - model_b = ModelB().cuda() - - # Compile both models - compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) - compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_a(torch.randn(10, 10, device="cuda")) - _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - # Profile both models in the same session - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result_a = compiled_a(torch.randn(10, 10, device="cuda")) - result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::add node=add stack_trace=return x + 1 -event=cudaLaunchKernel node=add stack_trace=return x + 1 -event=aten::sub node=sub stack_trace=return x - 1 -event=cudaLaunchKernel node=sub stack_trace=return x - 1""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_nested_graph_modules(self): - """ - Test that nested graph modules (e.g., graph modules calling subgraphs) - have their events correctly augmented with stack traces. - """ - - # Model with nested structure - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.c = 5 - - @torch.compiler.nested_compile_region - def forward(self, x, y): - m = torch.mul(x, y) - s = m.sin() - a = s + self.c - return a - - model = Mod().cuda() - - # Compile the model (this may create nested graph modules) - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - # Profile - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::mul node=mul stack_trace=m = torch.mul(x, y) -event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) -event=aten::sin node=sin stack_trace=s = m.sin() -event=cudaLaunchKernel node=sin stack_trace=s = m.sin() -event=aten::add node=add stack_trace=a = s + self.c -event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" - ) - def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index a61aee321fcff..b2d6530049e61 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,43 +1224,3 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) - - -# Collect all events with stack traces and format them canonically -def _canonicalize_profiler_events(events): - """ - Extract and format all events with stack traces in a canonical way - for deterministic testing. 
- """ - events_with_traces = [] - - for event in events: - # Extract relevant fields - event_name = event.get("name", "") - node_name = event["args"].get("node_name", "") - stack_trace = event["args"].get("stack_trace", "") - - # Get the last non-empty line of the stack trace - lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] - stack_trace = lines[-1] if lines else "" - - events_with_traces.append( - { - "event_name": event_name[:20], - "node_name": node_name, - "stack_trace": stack_trace, - "start_time": event.get("ts", 0), - } - ) - - # Sort by node_name for deterministic ordering - events_with_traces.sort(key=lambda x: x["start_time"]) - - # Format as a string - lines: list[str] = [] - for evt in events_with_traces: - lines.append( - f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" - ) - - return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fd6835d2b301b..697b2f4084ca5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,7 +443,6 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -799,10 +798,6 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") - if record_func: - body.append( - "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" - ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -812,22 +807,8 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") - do_record = record_func and node.op in ( - "call_function", - "call_method", - "call_module", - ) - if do_record: - # The double hash ## convention is used by post-processing to find the fx markers - body.append( - f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" - ) emit_node(node) delete_unused_values(node) - if do_record: - body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") - if record_func: - body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1779,7 +1760,6 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. 
@@ -1847,7 +1827,6 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def _python_code( @@ -1860,7 +1839,6 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1871,7 +1849,6 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8360c96630d6c..297f76732584f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,18 +861,14 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - - from torch._dynamo import config as dynamo_config - - python_code = self._graph.python_code( - root_module="self", record_func=dynamo_config.enrich_profiler_metadata - ) + python_code = self._graph.python_code(root_module="self") self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -889,6 +885,7 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -908,13 +905,6 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) - # Replace the placeholder in generated code with actual filename - # The double hash ## convention is used by post-processing to find the fx markers - self._code = self._code.replace( - "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", - f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", - ) - cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 47df87ce1678d..2c6e06b2cb3c9 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import Any, Literal, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None: with profile(): pass - - -@dataclass -class TimelineEvent: - """Represents an event in the profiler timeline.""" - - timestamp: int - event_type: Literal["start", "end", "regular"] - marker_type: Optional[Literal["filename", "node"]] - identifier: Optional[str | int] - event: dict[str, Any] - - -@dataclass -class ContextStackEntry: - """Represents a context (filename or node) in the stack.""" - - context_type: Literal["filename", "node"] - identifier: str | int - metadata: Optional[dict] - tid: Optional[int] = None # Thread ID associated with this context - - 
-def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): - """ - Maps recorded profiler events to their corresponding fx nodes and adds stack traces. - - Builds a timeline of all events (regular ops and FX markers for filenames/nodes), - sorts by timestamp, then processes chronologically while maintaining a context stack of active - filename/node scopes. Regular events are augmented with stack traces and node names from the - innermost active context. Runtime is O(n log n) for n events. - - Args: - traced_data: Json of profiler events from Chrome trace - - Returns: - Dict mapping recorded event names to their aten operations with added stack traces - """ - from torch.fx.traceback import _FX_METADATA_REGISTRY - - trace_events = traced_data.get("traceEvents", []) - - # Create event timeline - event_timeline: list[TimelineEvent] = [] - - def is_fx_marker_event(event): - return ( - event.get("cat") == "cpu_op" - and event.get("name", "").startswith("## ") - and event.get("name", "").endswith(" ##") - ) - - def append_fx_marker_event(event_type, identifier, event): - start_ts = event["ts"] - end_ts = start_ts + event["dur"] - event_timeline.append( - TimelineEvent(start_ts, "start", event_type, identifier, event) - ) - event_timeline.append( - TimelineEvent(end_ts, "end", event_type, identifier, event) - ) - - for event in trace_events: - if "ts" not in event or "dur" not in event: - continue - - if is_fx_marker_event(event): - content = event["name"][3:-3] - - if content.endswith(".py"): - append_fx_marker_event("filename", content, event) - else: - try: - node_index = int(content) - except ValueError: - pass - append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] - - else: - # Regular event that needs augmentation - start_ts = event["ts"] - event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) - - # Sort by timestamp - event_timeline.sort(key=lambda x: x.timestamp) - - # Process events in chronological order with a stack - context_stack: list[ContextStackEntry] = [] - - # Invariant: all start event has a corresponding end event - for timeline_event in event_timeline: - match timeline_event.event_type: - case "start": - assert timeline_event.identifier is not None - - if timeline_event.marker_type == "filename": - assert isinstance(timeline_event.identifier, str) - # Push filename context - query metadata registry on-demand - metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) - tid = timeline_event.event.get("tid") - context_stack.append( - ContextStackEntry( - "filename", timeline_event.identifier, metadata, tid - ) - ) - elif timeline_event.marker_type == "node": - # Find the current filename from stack - current_file_metadata = None - tid = timeline_event.event.get("tid") - for ctx_entry in reversed(context_stack): - if ( - ctx_entry.context_type == "filename" - and ctx_entry.tid == tid - ): - current_file_metadata = ctx_entry.metadata - break - - if current_file_metadata: - node_metadata = current_file_metadata.get("node_metadata", {}) - if timeline_event.identifier in node_metadata: - node_meta: Optional[dict] = node_metadata[ - timeline_event.identifier - ] - context_stack.append( - ContextStackEntry( - "node", timeline_event.identifier, node_meta, tid - ) - ) - - case "end": - # Pop from stack - search backwards to find matching context - for i in range(len(context_stack) - 1, -1, -1): - ctx_entry = context_stack[i] - if ( - timeline_event.marker_type == ctx_entry.context_type - and timeline_event.identifier 
== ctx_entry.identifier - ): - context_stack.pop(i) - break - - case "regular": - # Apply metadata from current context stack - # Find the most specific context (node takes precedence over filename) - # Only augment events with the same tid as the file/node event matched - current_stack_trace = None - current_node_name = None - event_tid = timeline_event.event.get("tid") - - for ctx_entry in reversed(context_stack): - # Only apply metadata from contexts with matching tid - if ctx_entry.tid == event_tid: - if ctx_entry.context_type == "node" and ctx_entry.metadata: - current_stack_trace = ctx_entry.metadata.get( - "stack_trace", "No model stack trace available" - ) - current_node_name = ctx_entry.metadata.get("name", "") - # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes - # if nodes are nested, e.g. in nested graph modules - break - - # Augment the event - if current_stack_trace or current_node_name: - args = timeline_event.event.setdefault("args", {}) - if current_stack_trace: - args["stack_trace"] = current_stack_trace - if current_node_name: - args["node_name"] = current_node_name From ad5c7c20e0dd55baa23a597cf10ffe7422b5cabf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 18:13:57 +0000 Subject: [PATCH 071/651] Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)" This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578. Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312)) --- .ci/docker/common/install_cuda.sh | 2 +- .ci/pytorch/smoke_test/smoke_test.py | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index fe0cb8cc79c4f..fe2f9ae3185a3 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -129,7 +129,7 @@ function install_129 { } function install_128 { - CUDNN_VERSION=9.10.2.21 + CUDNN_VERSION=9.8.0.87 echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index 3642f29684cf0..675d58a3e283d 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -272,18 +272,6 @@ def smoke_test_cuda( torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") - torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion() - print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}") - torch_cudnn_runtime_version = tuple( - [int(x) for x in torch_cudnn_version.split(".")] - ) - if torch_cudnn_runtime_version != torch_cudnn_compile_version: - raise RuntimeError( - "cuDNN runtime version doesn't match comple version. " - f"Loaded: {torch_cudnn_runtime_version} " - f"Expected: {torch_cudnn_compile_version}" - ) - if sys.platform in ["linux", "linux2"]: torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) print(f"Torch nccl; version: {torch_nccl_version}") From dcc2ba4ca48512968e027e765695490476d717dc Mon Sep 17 00:00:00 2001 From: "Edward Z. 
Yang" Date: Wed, 5 Nov 2025 06:52:49 -0800 Subject: [PATCH 072/651] Add some code for exploring the space of accessible size/stride configs via plain views (#167076) We are working on a translation from as_strided to view operations, but only when the as_strided is representable as a plain view. A useful testing utility in this situation is the ability to enumerate all valid views on an original tensor. So we have a small test here that shows it is possible. To avoid an explosion of states, we don't handle permutes and size=1, which are degenerate cases (you can always do a single permute and a series of unsqueezes to get to the final desired state.) Authored with claude code assistance. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076 Approved by: https://github.com/albanD ghstack dependencies: #166868, #166867 --- test/test_as_strided.py | 176 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 test/test_as_strided.py diff --git a/test/test_as_strided.py b/test/test_as_strided.py new file mode 100644 index 0000000000000..a5bcb8e279247 --- /dev/null +++ b/test/test_as_strided.py @@ -0,0 +1,176 @@ +# Owner(s): ["oncall: pt2"] + +from collections import deque +from typing import Optional + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase + + +def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]: + """Extract (sizes, strides) tuple from a tensor.""" + return (tuple(t.size()), tuple(t.stride())) + + +def enumerate_reachable_states( + initial_size: int, +) -> set[tuple[tuple[int, ...], tuple[int, ...]]]: + """ + Use BFS with DP to enumerate all reachable (size, stride) states from + a 1D contiguous tensor via valid view operations. + + We only explore states with offset=0 (you can retroactively change the offset). + We reject states with size=0 or size=1 dimensions as they are degenerate. + """ + # Create initial 1D contiguous tensor + initial_tensor = torch.arange(initial_size) + + initial_state = get_state(initial_tensor) + + # Map from state to tensor for that state + state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = { + initial_state: initial_tensor + } + visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state} + queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state]) + + while queue: + state = queue.popleft() + t = state_to_tensor[state] + sizes, strides = state + ndim = len(sizes) + + def add_state(new_t: torch.Tensor) -> None: + new_state = get_state(new_t) + sizes, strides = new_state + # Skip if has size-0 or size-1 dimensions + if any(s == 0 or s == 1 for s in sizes): + return + # Only accept states where strides are in descending order + if list(strides) != sorted(strides, reverse=True): + return + if new_state not in visited: + visited.add(new_state) + queue.append(new_state) + state_to_tensor[new_state] = new_t + + # 1. Unflatten: try factoring each dimension + for dim in range(ndim): + size = sizes[dim] + assert size > 1 + # Try all factorizations x * y = size where both x, y >= 2 + # We only need to check x up to size // 2 since when x > size // 2, + # y = size // x < 2, which we reject + for x in range(2, size // 2 + 1): + if size % x == 0: + y = size // x + add_state(t.unflatten(dim, (x, y))) + + # 2. 
Slice: exhaustively check all possible slicing parameters + for dim in range(ndim): + size = sizes[dim] + for start in range(size): + for stop in range(start + 1, size + 1): + for step in range(1, size + 1): + slices = [slice(None)] * ndim + slices[dim] = slice(start, stop, step) + add_state(t[tuple(slices)]) + + # 3. Flatten: merge adjacent dimensions + for dim in range(ndim - 1): + add_state(t.flatten(dim, dim + 1)) + + return visited + + +class TestAsStrided(TestCase): + def test_size_10_exhaustive(self) -> None: + """Test that size 10 produces exactly the expected 54 states.""" + expected_states = { + ((2,), (1,)), + ((2,), (2,)), + ((2,), (3,)), + ((2,), (4,)), + ((2,), (5,)), + ((2,), (6,)), + ((2,), (7,)), + ((2,), (8,)), + ((2,), (9,)), + ((2, 2), (2, 1)), + ((2, 2), (3, 1)), + ((2, 2), (3, 2)), + ((2, 2), (4, 1)), + ((2, 2), (4, 2)), + ((2, 2), (4, 3)), + ((2, 2), (5, 1)), + ((2, 2), (5, 2)), + ((2, 2), (5, 3)), + ((2, 2), (5, 4)), + ((2, 2), (6, 1)), + ((2, 2), (6, 2)), + ((2, 2), (6, 3)), + ((2, 2), (8, 1)), + ((2, 2, 2), (4, 2, 1)), + ((2, 2, 2), (5, 2, 1)), + ((2, 3), (3, 1)), + ((2, 3), (4, 1)), + ((2, 3), (5, 1)), + ((2, 3), (5, 2)), + ((2, 3), (6, 1)), + ((2, 4), (4, 1)), + ((2, 4), (5, 1)), + ((2, 5), (5, 1)), + ((3,), (1,)), + ((3,), (2,)), + ((3,), (3,)), + ((3,), (4,)), + ((3, 2), (2, 1)), + ((3, 2), (3, 1)), + ((3, 2), (3, 2)), + ((3, 2), (4, 1)), + ((3, 3), (3, 1)), + ((4,), (1,)), + ((4,), (2,)), + ((4,), (3,)), + ((4, 2), (2, 1)), + ((5,), (1,)), + ((5,), (2,)), + ((5, 2), (2, 1)), + ((6,), (1,)), + ((7,), (1,)), + ((8,), (1,)), + ((9,), (1,)), + ((10,), (1,)), + } + + actual_states = enumerate_reachable_states(10) + + self.assertEqual(len(actual_states), 54) + self.assertEqual(actual_states, expected_states) + + def test_subset_property(self) -> None: + """ + Test that for sizes 2..10, each smaller tensor results in a strict + subset of possible states compared to the next one. 
+ """ + prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None + for size in range(2, 11): + current_states = enumerate_reachable_states(size) + + if prev_states is not None: + # Check that prev_states is a strict subset of current_states + self.assertTrue( + prev_states.issubset(current_states), + f"States from size {size - 1} are not a subset of size {size}", + ) + # Check that it's a strict subset (not equal) + self.assertTrue( + len(prev_states) < len(current_states), + f"States from size {size - 1} should be strictly fewer than size {size}", + ) + + prev_states = current_states + + +if __name__ == "__main__": + run_tests() From 89165c0a2b5d3c147c19a492437291c8ff18aa7f Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Wed, 5 Nov 2025 18:26:31 +0000 Subject: [PATCH 073/651] Update triton to 3.5.1 release (#166968) This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968 Approved by: https://github.com/Lucaskabela, https://github.com/njriasan --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 10f1207e60e6c..7aab8bed1c108 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd +bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 1545d966571dc..d5c0c99142898 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.0 +3.5.1 From 641de23c96e2c0d2848a7aa2aacb2f77843177a5 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 5 Nov 2025 17:05:14 +0000 Subject: [PATCH 074/651] ci: Add aarch64 docker builds for modern clang (#166416) Should enable us to build using some arm optimizations that are only available on the newest versions of clang. 
Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416 Approved by: https://github.com/malfet --- .ci/docker/build.sh | 10 ++++++++++ .ci/docker/common/install_clang.sh | 4 ++-- .ci/docker/common/install_openblas.sh | 1 + .github/workflows/docker-builds.yml | 2 ++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index d0500b89780ce..5257decb9d4d5 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -271,6 +271,16 @@ case "$tag" in # from pytorch/llvm:9.0.1 is x86 specific SKIP_LLVM_SRC_BUILD_INSTALL=yes ;; + pytorch-linux-jammy-aarch64-py3.10-clang21) + ANACONDA_PYTHON_VERSION=3.10 + CLANG_VERSION=21 + ACL=yes + VISION=yes + OPENBLAS=yes + # snadampal: skipping llvm src build install because the current version + # from pytorch/llvm:9.0.1 is x86 specific + SKIP_LLVM_SRC_BUILD_INSTALL=yes + ;; pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 diff --git a/.ci/docker/common/install_clang.sh b/.ci/docker/common/install_clang.sh index 1cb216edf1b38..93daeee919b3d 100755 --- a/.ci/docker/common/install_clang.sh +++ b/.ci/docker/common/install_clang.sh @@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then # work around ubuntu apt-get conflicts sudo apt-get -y -f install wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - if [[ $CLANG_VERSION == 18 ]]; then - apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main" + if [[ $CLANG_VERSION -ge 18 ]]; then + apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main" fi fi diff --git a/.ci/docker/common/install_openblas.sh b/.ci/docker/common/install_openblas.sh index 2f386c6bd523a..5a28068781245 100755 --- a/.ci/docker/common/install_openblas.sh +++ b/.ci/docker/common/install_openblas.sh @@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" - OPENBLAS_CHECKOUT_DIR="OpenBLAS" OPENBLAS_BUILD_FLAGS=" +CC=gcc NUM_THREADS=128 USE_OPENMP=1 NO_SHARED=0 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 6fbe2e846d40b..4d0940094f541 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -79,6 +79,8 @@ jobs: include: - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11 runner: linux.arm64.m7g.4xlarge + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21 + runner: linux.arm64.m7g.4xlarge - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 From 14b153bcf28efa7056f8b0ecf2e8c7def97aa2ea Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 5 Nov 2025 08:00:51 -0800 Subject: [PATCH 075/651] include DTensor metadata when pretty-printing fx.Graphs (#166750) Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. 
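As a quick reference, here is a condensed sketch of just the tracing/printing calls (it assumes the same `f`, `dA`, `dB` setup as the full example below), together with the kind of annotation the unit test added in this PR checks for; treat the exact shapes and placements as illustrative:

```
# Condensed sketch -- assumes f, dA, dB are defined as in the full example below.
from torch.fx.experimental.proxy_tensor import make_fx

gm = make_fx(f, tracing_mode="fake")(dA, dB)
gm_str = gm.print_readable(colored=False, print_output=False)

# With this change, node annotations carry the DTensor dtype/shape and placement.
# The unit test in this diff (which uses 8x32 / 32x32 inputs) asserts the printed
# graph contains a substring like:
#   '"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm'
assert "DTensor(" in gm_str
```
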
If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it: image ``` import torch from torch.testing._internal.distributed.fake_pg import FakeStore from torch.distributed.tensor import distribute_tensor, Shard, Replicate from torch.utils._debug_mode import DebugMode from torch.fx.experimental.proxy_tensor import make_fx from torch.utils._python_dispatch import TorchDispatchMode from torch.utils import _pytree as pytree world_size = 8 device_type = "cpu" fake_store = FakeStore() torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size) device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,)) dim = 128 A = torch.randn(8, dim) B = torch.randn(dim, dim) dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_() dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_() def f(dA, dB): dy = dA @ dB loss = dy.sum() loss.backward() return dA.grad, dB.grad # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode. # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph # so we won't stash our DTensors properly if they don't hold Fake inner tensors gm = make_fx(f, tracing_mode='fake')(dA, dB) # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph gm.graph.eliminate_dead_code() gm.recompile() gm.print_readable(colored=True) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750 Approved by: https://github.com/ezyang, https://github.com/wconstab, https://github.com/Skylion007 --- .../tensor/debug/test_debug_mode.py | 35 ++++++++++++++++++- torch/fx/graph.py | 19 ++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 9acfcb15804e5..abc37f17a74de 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -5,8 +5,16 @@ import torch import torch.distributed as dist from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard +from torch.distributed.tensor import ( + DeviceMesh, + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) from torch.distributed.tensor._dtensor_spec import ShardOrderEntry +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -426,6 +434,31 @@ def forward(self, x): ][-1] self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace) + def test_pretty_print_dtensor_make_fx(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + A = torch.randn(8, 32) + B = torch.randn(32, 32) + dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_() + dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_() + + def f(dA, dB): + dy = dA @ dB + loss = dy.sum() + loss.backward() + return dA.grad, dB.grad + + # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode. 
+ # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph + # so we won't stash our DTensors properly if they don't hold Fake inner tensors + gm = make_fx(f, tracing_mode="fake")(dA, dB) + # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph + gm.graph.eliminate_dead_code() + gm.recompile() + # Colored is nice for actual viewing, not using in this test though + gm_str = gm.print_readable(colored=False, print_output=False) + self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str) + instantiate_parametrized_tests(TestDTensorDebugMode) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 697b2f4084ca5..899a50f0f4142 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -647,6 +647,15 @@ def emit_node(node: Node): if verbose: # override annotation with more detailed information + try: + from torch.distributed.tensor._api import DTensor, DTensorSpec + + dtensorspec_format_shard_order_str = ( + DTensorSpec.format_shard_order_str + ) + except ModuleNotFoundError: + DTensor = None # type: ignore[assignment,misc] + dtensorspec_format_shard_order_str = None from torch.fx.experimental.proxy_tensor import py_sym_types from torch.fx.passes.shape_prop import TensorMetadata @@ -677,6 +686,16 @@ def _tensor_annotation(t: torch.Tensor) -> str: core = _tensor_annotation(meta_val) if is_plain: maybe_type_annotation = f': "{core}"' + elif type(meta_val) is DTensor: + assert dtensorspec_format_shard_order_str is not None + dtensor_meta = dtensorspec_format_shard_order_str( + meta_val._spec.placements, # type: ignore[attr-defined] + meta_val._spec.shard_order, # type: ignore[attr-defined] + ) + cls = meta_val.__class__.__name__ + maybe_type_annotation = ( + f': "{cls}({core}, {dim_green(dtensor_meta)})"' + ) else: cls = meta_val.__class__.__name__ maybe_type_annotation = f': "{cls}({core})"' From 6052a01b71277eb767d87daf47d109f8e0edd5c0 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Wed, 5 Nov 2025 19:18:35 +0000 Subject: [PATCH 076/651] [BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py (#167022) Provides type coverage to torch/_dynamo/variables/dicts.py Coverage report: `mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022 Approved by: https://github.com/Skylion007 --- torch/_dynamo/symbolic_convert.py | 6 +- torch/_dynamo/variables/dicts.py | 358 +++++++++++++++++------------- 2 files changed, 208 insertions(+), 156 deletions(-) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 53ec0ee412849..3943f90b0020a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3320,7 +3320,7 @@ def SET_ADD(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() - obj.call_method(self, "add", [v], {}) + obj.call_method(self, "add", [v], {}) # type: ignore[arg-type] def SET_UPDATE(self, inst: Instruction) -> None: v = self.pop() @@ -3329,7 +3329,7 @@ def SET_UPDATE(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() - obj.call_method(self, "update", [v], {}) + obj.call_method(self, "update", [v], {}) # type: ignore[arg-type] def LIST_APPEND(self, inst: Instruction) -> None: v = self.pop() @@ -3637,7 +3637,7 @@ def 
DICT_MERGE(self, inst: Instruction) -> None: obj = self.stack[-inst.arg].realize() assert isinstance(obj, ConstDictVariable) assert obj.is_mutable() - obj.call_method(self, "update", [v], {}) + obj.call_method(self, "update", [v], {}) # type: ignore[arg-type] DICT_UPDATE = DICT_MERGE diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index f70ba99c0c93d..fb212c3326222 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Dictionary-related variable tracking classes for PyTorch Dynamo. @@ -26,7 +24,7 @@ import operator import types from collections.abc import Hashable as py_Hashable -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union from torch._subclasses.fake_tensor import is_fake @@ -59,11 +57,13 @@ # - (perhaps) Define how it is compared in _HashableTracker._eq_impl -def was_instancecheck_override(obj): +def was_instancecheck_override(obj: Any) -> bool: return type(obj).__dict__.get("__instancecheck__", False) -def raise_unhashable(arg, tx=None): +def raise_unhashable( + arg: VariableTracker, tx: Optional["InstructionTranslator"] = None +) -> None: if tx is None: from torch._dynamo.symbolic_convert import InstructionTranslator @@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None): ) -def is_hashable(x): +def is_hashable(x: VariableTracker) -> bool: # NB - performing isinstance check on a LazVT realizes the VT, accidentally # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at # the underlying value without realizing the VT. Consider updating the @@ -143,7 +143,7 @@ class _HashableTracker: Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing """ - def __init__(self, vt) -> None: + def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) # TODO Temporarily remove to figure out what keys are we breaking on @@ -153,7 +153,7 @@ def __init__(self, vt) -> None: self.vt = vt @property - def underlying_value(self): + def underlying_value(self) -> Any: if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() @@ -178,7 +178,8 @@ def underlying_value(self): elif isinstance(self.vt, variables.FrozenDataClassVariable): Hashable = ConstDictVariable._HashableTracker fields_values = { - k: Hashable(v).underlying_value for k, v in self.vt.fields.items() + k: Hashable(v).underlying_value + for k, v in self.vt.fields.items() # type: ignore[attr-defined] } return variables.FrozenDataClassVariable.HashWrapper( self.vt.python_type(), fields_values @@ -187,16 +188,16 @@ def underlying_value(self): # The re module in Python 3.13+ has a dictionary (_cache2) with # an object as key (`class _ZeroSentinel(int): ...`): # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual - return self.vt.value + return self.vt.value # type: ignore[attr-defined,union-attr] else: x = self.vt.as_python_constant() return x - def __hash__(self): + def __hash__(self) -> int: return hash(self.underlying_value) @staticmethod - def _eq_impl(a, b): + def _eq_impl(a: Any, b: Any) -> bool: # TODO: Put this in utils and share it between variables/builtin.py and here type_a, type_b = type(a), type(b) if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): @@ -212,7 +213,7 @@ def _eq_impl(a, b): else: return a == b - def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: + def __eq__(self, 
other: object) -> bool: Hashable = ConstDictVariable._HashableTracker assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( type(other) @@ -226,8 +227,8 @@ def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: def __init__( self, items: dict[VariableTracker, VariableTracker], - user_cls=dict, - **kwargs, + user_cls: type = dict, + **kwargs: Any, ) -> None: # .clone() pass these arguments in kwargs but they're recreated a few # lines below @@ -247,18 +248,22 @@ def __init__( for x, v in items.items() ) - def make_hashable(key): + def make_hashable( + key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], + ) -> "ConstDictVariable._HashableTracker": return key if isinstance(key, Hashable) else Hashable(key) dict_cls = self._get_dict_cls_from_user_cls(user_cls) self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) # need to reconstruct everything if the dictionary is an intermediate value # or if a pop/delitem was executed - self.should_reconstruct_all = not is_from_local_source(self.source) + self.should_reconstruct_all = ( + not is_from_local_source(self.source) if self.source else True + ) self.original_items = items.copy() self.user_cls = user_cls - def _get_dict_cls_from_user_cls(self, user_cls): + def _get_dict_cls_from_user_cls(self, user_cls: type) -> type: accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) # avoid executing user code if user_cls is a dict subclass @@ -277,10 +282,10 @@ def _get_dict_cls_from_user_cls(self, user_cls): dict_cls = dict return dict_cls - def as_proxy(self): + def as_proxy(self) -> dict[Any, Any]: return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} - def debug_repr(self): + def debug_repr(self) -> str: return ( "{" + ", ".join( @@ -289,20 +294,20 @@ def debug_repr(self): + "}" ) - def as_python_constant(self): + def as_python_constant(self) -> dict[Any, Any]: return { k.vt.as_python_constant(): v.as_python_constant() for k, v in self.items.items() } - def keys_as_python_constant(self): + def keys_as_python_constant(self) -> dict[Any, VariableTracker]: self.install_dict_keys_match_guard() return {k.vt.as_python_constant(): v for k, v in self.items.items()} - def python_type(self): + def python_type(self) -> type: return self.user_cls - def __contains__(self, vt) -> bool: + def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( @@ -322,13 +327,15 @@ def has_new_items(self) -> bool: for key, value in self.items.items() ) - def is_new_item(self, value, other): + def is_new_item( + self, value: Optional[VariableTracker], other: VariableTracker + ) -> bool: # compare the id of the realized values if both values are not lazy VTs if value and value.is_realized() and other.is_realized(): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct_kvs_into_new_dict(self, codegen): + def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None: # Build a dictionary that contains the keys and values. 
num_args = 0 for key, value in self.items.items(): @@ -340,7 +347,7 @@ def reconstruct_kvs_into_new_dict(self, codegen): num_args += 1 codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: if self.user_cls is collections.OrderedDict: # emit `OrderedDict(constructed_dict)` codegen.add_push_null( @@ -358,19 +365,21 @@ def reconstruct(self, codegen: "PyCodegen"): def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker - ): + ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: raise_observed_exception(KeyError, tx) return self.items[key] - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: - msg = f"Dictionary key {arg.value} not found during tracing" + msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] unimplemented_v2( gb_type="key not found in dict", - context=f"Key {arg.value}", + context=f"Key {arg.value}", # type: ignore[attr-defined] explanation=msg, hints=[ "Check if the key exists in the dictionary before accessing it.", @@ -379,13 +388,13 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): ) return self.items[key] - def maybe_getitem_const(self, arg: VariableTracker): + def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: return None return self.items[key] - def realize_key_vt(self, arg: VariableTracker): + def realize_key_vt(self, arg: VariableTracker) -> None: # Realize the LazyVT on a particular index assert arg in self key = ConstDictVariable._HashableTracker(arg) @@ -394,11 +403,13 @@ def realize_key_vt(self, arg: VariableTracker): if isinstance(original_key_vt, variables.LazyVariableTracker): original_key_vt.realize() - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: if self.source: install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: # Key guarding - These are the cases to consider # 1) The dict has been mutated. In this case, we would have already # inserted a DICT_KEYS_MATCH guard, so we can skip. @@ -439,11 +450,11 @@ def install_dict_contains_guard(self, tx, args): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # NB - Both key and value are LazyVariableTrackers in the beginning. So, # we have to insert guards when a dict method is accessed. For this to # be simple, we are conservative and overguard. We skip guard only for @@ -462,7 +473,7 @@ def call_method( tx, *args, **kwargs ) tx.output.side_effects.mutation(self) - self.items.update(temp_dict_vt.items) + self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "__getitem__": # Key guarding - Nothing to do. LazyVT for value will take care. 
@@ -526,7 +537,7 @@ def call_method( return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) != 2: @@ -550,7 +561,7 @@ def call_method( raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) if args[0] not in self: self.install_dict_contains_guard(tx, args) @@ -565,7 +576,7 @@ def call_method( raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) if args[0] not in self: # missing item, return the default value. Install no DICT_CONTAINS guard. @@ -599,7 +610,7 @@ def call_method( last = v.value else: raise_args_mismatch(tx, name) - k, v = self.items.popitem(last=last) + k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] else: k, v = self.items.popitem() @@ -632,17 +643,17 @@ def call_method( # NB - Guard on all the keys of the other dict to ensure # correctness. args[0].install_dict_keys_match_guard() - dict_vt = args[0] + dict_vt: ConstDictVariable = args[0] else: - dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) - self.items.update(dict_vt.items) + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] + self.items.update(dict_vt.items) # type: ignore[attr-defined] if has_kwargs: # Handle kwargs - kwargs = { + kwargs_hashable = { Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() } - self.items.update(kwargs) + self.items.update(kwargs_hashable) return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) @@ -656,7 +667,7 @@ def call_method( ) if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_contains_guard(tx, args) contains = args[0] in self @@ -671,7 +682,7 @@ def call_method( ) if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) > 2: @@ -707,7 +718,7 @@ def call_method( and "last" in kwargs and isinstance(kwargs["last"], ConstantVariable) ): - last = kwargs.get("last").value + last = kwargs.get("last").value # type: ignore[union-attr] key = Hashable(args[0]) self.items.move_to_end(key, last=last) @@ -723,7 +734,7 @@ def call_method( ) elif name == "__ne__": return ConstantVariable.create( - not self.call_method(tx, "__eq__", args, kwargs).value + not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined] ) elif name == "__or__": if len(args) != 1: @@ -750,14 +761,14 @@ def call_method( if not istype( other, (ConstDictVariable, variables.UserDefinedDictVariable) ): - msg = ( + err_msg = ( f"unsupported operand type(s) for |: '{self.python_type().__name__}'" f"and '{other.python_type().__name__}'" ) - raise_observed_exception(TypeError, tx, args=[msg]) + raise_observed_exception(TypeError, tx, args=[err_msg]) # OrderedDict overloads __ror__ - ts = {self.user_cls, other.user_cls} + ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined] user_cls = ( collections.OrderedDict if any(issubclass(t, collections.OrderedDict) for t in ts) @@ -774,8 +785,8 @@ def call_method( # NB - Guard on all the keys of the other dict to ensure # correctness. 
- args[0].install_dict_keys_match_guard() - new_dict_vt.items.update(args[0].items) + args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] + new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] return new_dict_vt elif name == "__ior__": self.call_method(tx, "update", args, kwargs) @@ -789,11 +800,13 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: self.install_dict_keys_match_guard() return [x.vt for x in self.items.keys()] - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: # dict not allow setting arbitrary attributes. OrderedDict and # defaultdict allow arbitrary setattr, but not deletion of default attrs if any( @@ -816,25 +829,25 @@ def call_obj_hasattr(self, tx, name): ], ) - def clone(self, **kwargs): + def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt - def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict - def python_type(self): + def python_type(self) -> type: return types.MappingProxyType - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.dv_dict.unpack_var_sequence(tx) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: # load types.MappingProxyType if self.source: msg = ( @@ -863,11 +876,11 @@ def reconstruct(self, codegen: "PyCodegen"): def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.source and tx.output.side_effects.has_existing_dict_mutation(): msg = ( "A dict has been modified while we have an existing mappingproxy object. " @@ -892,7 +905,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is types.MappingProxyType: return ConstantVariable.create(name in types.MappingProxyType.__dict__) return super().call_obj_hasattr(tx, name) @@ -900,35 +913,44 @@ def call_obj_hasattr( class NNModuleHooksDictVariable(ConstDictVariable): # Special class to avoid adding any guards on the nn module hook ids. 
- def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: pass class DefaultDictVariable(ConstDictVariable): - def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type, + default_factory: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict if default_factory is None: default_factory = ConstantVariable.create(None) self.default_factory = default_factory - def is_python_constant(self): + def is_python_constant(self) -> bool: # Return false for unsupported defaults. This ensures that a bad handler # path is not taken in BuiltinVariable for getitem. if self.default_factory not in [list, tuple, dict] and not self.items: return False return super().is_python_constant() - def debug_repr(self): + def debug_repr(self) -> str: + assert self.default_factory is not None return ( f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" ) @staticmethod - def is_supported_arg(arg): + def is_supported_arg(arg: VariableTracker) -> bool: if isinstance(arg, variables.BuiltinVariable): return arg.fn in (list, tuple, dict, set) else: @@ -942,11 +964,11 @@ def is_supported_arg(arg): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") @@ -962,13 +984,13 @@ def call_method( else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( - tx, "__setitem__", (args[0], default_var), kwargs + tx, "__setitem__", [args[0], default_var], kwargs ) return default_var else: return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: # emit `defaultdict(default_factory, new_dict)` codegen.add_push_null( lambda: codegen.extend_output( @@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: + # pyrefly: ignore[bad-assignment] items = dict.fromkeys(items, SetVariable._default_value()) + # pyrefly: ignore[bad-argument-type] super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "set()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" @property - def set_items(self): + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return set(self.items.keys()) @staticmethod - def _default_value(): + def _default_value() -> VariableTracker: # Variable to fill in he keys of the dictionary return ConstantVariable.create(None) - def as_proxy(self): + def as_proxy(self) -> Any: return {k.vt.as_proxy() for k in self.set_items} - def python_type(self): + def python_type(self) -> type: return set - def as_python_constant(self): + def as_python_constant(self) -> Any: return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: 
"PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) - def _fast_set_method(self, tx, fn, args, kwargs): + def _fast_set_method( + self, + tx: "InstructionTranslator", + fn: Any, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: try: res = fn( *[x.as_python_constant() for x in [self, *args]], @@ -1037,15 +1067,16 @@ def _fast_set_method(self, tx, fn, args, kwargs): raise_observed_exception( type(exc), tx, args=list(map(ConstantVariable.create, exc.args)) ) + # pyrefly: ignore[unbound-name] return VariableTracker.build(tx, res) def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: # We forward the calls to the dictionary model from ..utils import check_constant_args @@ -1065,10 +1096,10 @@ def call_method( return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) if name == "__init__": - temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs) + temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs) tx.output.side_effects.mutation(self) self.items.clear() - self.items.update(temp_set_vt.items) + self.items.update(temp_set_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "add": if kwargs or len(args) != 1: @@ -1079,7 +1110,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) name = "__setitem__" - args = (args[0], SetVariable._default_value()) + args = [args[0], SetVariable._default_value()] elif name == "pop": if kwargs or args: raise_args_mismatch( @@ -1090,12 +1121,14 @@ def call_method( ) # Choose an item at random and pop it via the Dict.pop method try: - result = self.set_items.pop().vt + result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] except KeyError as e: raise_observed_exception( KeyError, tx, args=list(map(ConstantVariable.create, e.args)) ) - super().call_method(tx, name, (result,), kwargs) + # pyrefly: ignore[unbound-name] + super().call_method(tx, name, [result], kwargs) + # pyrefly: ignore[unbound-name] return result elif name == "isdisjoint": if kwargs or len(args) != 1: @@ -1217,6 +1250,7 @@ def call_method( f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" ) raise_observed_exception(TypeError, tx, args=[msg]) + assert m is not None return self.call_method(tx, m, args, kwargs) elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): @@ -1230,29 +1264,34 @@ def call_method( "__ixor__": "symmetric_difference_update", "__isub__": "difference_update", }.get(name) + assert m is not None self.call_method(tx, m, args, kwargs) return self elif name == "__eq__": if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(False) r = self.call_method(tx, "symmetric_difference", args, kwargs) - return ConstantVariable.create(len(r.set_items) == 0) + return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined] elif name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + 
cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: raise RuntimeError("Illegal to getitem on a set") - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: super().install_dict_contains_guard(tx, args) @@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "frozenset()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" @property - def set_items(self): + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return self.items.keys() - def python_type(self): + def python_type(self) -> type: return frozenset - def as_python_constant(self): + def as_python_constant(self) -> Any: return frozenset({k.vt.as_python_constant() for k in self.set_items}) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( @@ -1293,11 +1332,11 @@ def reconstruct(self, codegen: "PyCodegen"): def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a frozenset") elif name == "__init__": @@ -1316,7 +1355,7 @@ def call_method( "symmetric_difference", ): r = super().call_method(tx, name, args, kwargs) - return FrozensetVariable(r.items) + return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) @@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "dict_keys([])" else: @@ -1338,33 +1377,35 @@ def debug_repr(self): + "])" ) - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: # Already EQUALS_MATCH guarded pass @property - def set_items(self): + def set_items(self) -> Any: return self.items - def python_type(self): + def python_type(self) -> type: return dict_keys - def as_python_constant(self): + def as_python_constant(self) -> Any: return dict.fromkeys( {k.vt.as_python_constant() for k in self.set_items}, None ).keys() def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: if name in ["add", "pop", "update", "remove", 
"discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a dict_keys") return super().call_method(tx, name, args, kwargs) @@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker): kv: Optional[str] = None - def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert self.kv in ("keys", "values", "items") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @property - def view_items(self): + def view_items(self) -> Any: + assert self.kv is not None return getattr(self.dv_dict.items, self.kv)() @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items # Implement in the subclasses raise NotImplementedError - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.view_items_vt - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.kv is not None codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + assert self.kv is not None if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name == "__iter__": @@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable): kv = "keys" @property - def set_items(self): + def set_items(self) -> set[VariableTracker]: return set(self.view_items) @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [x.vt for x in self.view_items] - def python_type(self): + def python_type(self) -> type: return dict_keys def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name in ( @@ -1460,13 +1506,13 @@ def call_method( ): # These methods always returns a set m = getattr(self.set_items, name) - r = m(args[0].set_items) + r = m(args[0].set_items) # type: ignore[attr-defined] return SetVariable(r) if name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, DictKeysVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) @@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable): kv = "values" @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: return list(self.view_items) - def python_type(self): + def python_type(self) -> type: return dict_values @@ -1487,14 
+1533,20 @@ class DictItemsVariable(DictViewVariable): kv = "items" @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] - def python_type(self): + def python_type(self) -> type: return dict_items - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # TODO(guilhermeleobas): This should actually check if args[0] # implements the mapping protocol. if name == "__eq__": From 6c5db82584bf71f5b1db3b598bbd00f44140c28d Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:27:23 +0000 Subject: [PATCH 077/651] [Inductor] Naive foreach autotune support (#162053) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code. Before: triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 | triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 | triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 | After: triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 | triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 | triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 | num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053 Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily Co-authored-by: Nichols A. 
Romero --- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 3e58e95ef9e9c..1f531a5d99ef5 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -627,7 +627,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index cb43d55bc86b3..cdecd50927024 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -3586,13 +3586,24 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + + # Naive autotuning path for num_warps + if not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, From fbd70fb84e347b45db79eb24cc2c53e447a04147 Mon Sep 17 00:00:00 2001 From: Maggie Moss Date: Wed, 5 Nov 2025 19:35:34 +0000 Subject: [PATCH 078/651] Update typing docs to reference pyrefly (#166883) Replacing mypy documentation in the CONTRIBUTING.MD file with pyrefly references. I have made initial changes to https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883 Approved by: https://github.com/malfet --- CONTRIBUTING.md | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9df55ca6acd5c..bc0b0fc9bb00f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -18,7 +18,7 @@ aspects of contributing to PyTorch. - [Python Unit Testing](#python-unit-testing) - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest) - [Local linting](#local-linting) - - [Running `mypy`](#running-mypy) + - [Running `pyrefly`](#running-pyrefly) - [C++ Unit Testing](#c-unit-testing) - [Run Specific CI Jobs](#run-specific-ci-jobs) - [Merging your Change](#merging-your-change) @@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory. **Prerequisites**: The following packages should be installed with `pip`: - `expecttest` and `hypothesis` - required to run tests -- `mypy` - recommended for linting
[Pyrefly](https://pyrefly.org/) - `pytest` - recommended to run tests more selectively Running ``` @@ -350,15 +350,32 @@ make lint Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner) -#### Running `mypy` +#### Running `pyrefly` -`mypy` is an optional static type checker for Python. We have multiple `mypy` -configs for the PyTorch codebase that are automatically validated against whenever the linter is run. +[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback. + +PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository. + +**Getting Started with Pyrefly:** + +To run type checking on the PyTorch codebase: +```bash +pyrefly check +``` + +For more detailed error information with summaries: +```bash +pyrefly check --summarize-errors +``` + +**Learn More:** +- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options +- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking +- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations See [Guide for adding type annotations to PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch) -for more information on how to set up `mypy` and tackle type annotation -tasks. +for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase. ### C++ Unit Testing From 8e8cbb85ee927776210f7872e3d0286d5d40dc14 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 19:42:39 +0000 Subject: [PATCH 079/651] Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)" This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23. 
Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see https://hud.pytorch.org/hud/pytorch/pytorch/fbd70fb84e347b45db79eb24cc2c53e447a04147/1?per_page=50&name_filter=trunk%20%2F%20linux-jammy-cuda12&mergeEphemeralLF=true and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038)) --- test/inductor/test_torchinductor.py | 14 -------------- torch/_inductor/codecache.py | 6 ------ torch/_inductor/codegen/common.py | 11 ++--------- torch/_inductor/codegen/triton_utils.py | 5 ----- 4 files changed, 2 insertions(+), 34 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d0ff5799ac417..ed8993a1c9a39 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14424,20 +14424,6 @@ def fn(x): self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),)) - @skip_if_halide - @requires_cuda_and_triton - def test_unbacked_float_item(self): - def fn(x, max_val): - return torch.clamp(x, 0, max_val.item()) - - self.common( - fn, - ( - torch.randn(10, 20, 30, device=self.device), - torch.tensor(5.0, device=self.device), - ), - ) - # end of class CommonTemplate - add new tests here diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 85702057cbb43..cf17bf2e9478b 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2970,12 +2970,6 @@ class CppPythonBindingsCodeCache(CppCodeCache): throw std::runtime_error("expected int arg"); return reinterpret_cast(result); }} - template <> inline float parse_arg(PyObject* args, size_t n) {{ - auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n)); - if(unlikely(result == -1.0 && PyErr_Occurred())) - throw std::runtime_error("expected float arg"); - return static_cast(result); - }} {extra_parse_arg} diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 3e9f174c810c5..730c03f1c813c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1732,15 +1732,9 @@ def cpp_argdefs( call_args.append(self.wrap_ptr_arg(outer, dtype)) arg_types.append(f"{cpp_dtype}*") for outer, inner in self.sizevars.items(): - if isinstance(outer, sympy.Symbol) and symbol_is_type( - outer, (SymT.UNBACKED_FLOAT) - ): - arg_defs.append(f"const float {inner}") - arg_types.append("const float") - else: - arg_defs.append(f"const {INDEX_TYPE} {inner}") - arg_types.append(f"const {INDEX_TYPE}") + arg_defs.append(f"const {INDEX_TYPE} {inner}") call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) assert not self.workspace_args, "Workspace not supported on CPU " @@ -2359,7 +2353,6 @@ def rename_indexing( SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, - SymT.UNBACKED_FLOAT, ), ) } diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 75a34813c876b..2a2706ad5720b 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -4,7 +4,6 @@ import sympy import torch -from torch.utils._sympy.symbol import symbol_is_type, SymT from .. 
import config from ..runtime.hints import AttrsDescriptorWrapper @@ -72,10 +71,6 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: return "constexpr" elif isinstance(arg.expr, (float, sympy.Float)): return "fp32" - elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type( - arg.expr, (SymT.UNBACKED_FLOAT) - ): - return "fp32" elif isinstance(arg.expr, bool): return "i1" From 6d30666bc1cad94295f708f74ebaf161e291c273 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 20:02:47 +0000 Subject: [PATCH 080/651] Revert "[12/N] Apply ruff UP035 rule (#166929)" This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52. Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076. @cyyever Please re-merge after revert of #165076. ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596)) --- test/distributed/tensor/test_attention.py | 3 +-- test/higher_order_ops/test_local_map.py | 3 +-- test/inductor/test_caching.py | 3 +-- test/inductor/test_fx_fusion.py | 3 +-- test/inductor/test_native_matmul.py | 2 +- test/quantization/fx/test_quantize_fx.py | 3 +-- test/test_matmul_cuda.py | 2 +- torch/_dynamo/eval_frame.py | 3 +-- torch/_dynamo/graph_bytecode_inputs.py | 3 +-- torch/_dynamo/variables/distributed.py | 3 +-- torch/_dynamo/variables/iter.py | 4 ++-- torch/_dynamo/variables/optimizer.py | 3 +-- torch/_dynamo/variables/script_object.py | 4 ++-- torch/_dynamo/variables/sdpa.py | 3 +-- torch/_dynamo/variables/streams.py | 3 +-- torch/_dynamo/variables/torch_function.py | 4 ++-- torch/_functorch/_aot_autograd/aot_autograd_result.py | 3 +-- torch/_inductor/compile_worker/timer.py | 3 +-- torch/_inductor/fx_passes/bucketing.py | 3 +-- torch/_inductor/fx_passes/ddp_fusion.py | 4 ++-- torch/_inductor/fx_passes/fsdp.py | 2 +- torch/_inductor/fx_passes/memory_estimator.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 6 +----- torch/_inductor/fx_passes/overlap_scheduling.py | 4 ++-- torch/_inductor/fx_passes/pad_mm.py | 4 ++-- torch/_inductor/fx_passes/post_grad.py | 3 +-- torch/_inductor/fx_passes/reinplace.py | 4 ++-- torch/_inductor/fx_passes/split_cat.py | 5 +++-- torch/_inductor/kernel/custom_op.py | 3 +-- torch/_inductor/kernel/flex/flex_flash_attention.py | 3 +-- torch/_inductor/runtime/benchmarking.py | 4 ++-- torch/_inductor/runtime/caching/interfaces.py | 6 ++---- torch/_inductor/runtime/caching/locks.py | 5 ++--- torch/distributed/elastic/multiprocessing/tail_log.py | 3 +-- torch/utils/_cxx_pytree.py | 4 ++-- torch/utils/_debug_mode.py | 3 +-- torch/utils/_pytree.py | 3 +-- 37 files changed, 50 insertions(+), 76 deletions(-) diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index 6c3485f9d7025..eaf3a4042060d 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,8 +3,7 @@ import itertools import random import unittest -from collections.abc import Callable -from typing import Any, ClassVar, Optional +from typing import Any, Callable, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index fbb21633260e7..9d2870d3b5fdd 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,9 +4,8 @@ import functools import unittest -from collections.abc import Callable 
from contextlib import contextmanager, ExitStack -from typing import Any, Optional +from typing import Any, Callable, Optional import torch import torch._dynamo diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index aa4c3a1f229f1..bcb66beea700c 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, TYPE_CHECKING, Union +from typing import Any, Generator, Sequence, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,7 +37,6 @@ if TYPE_CHECKING: - from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index 63342502d3cd9..ebe98373e622a 100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,6 +1,5 @@ # Owner(s): ["module: inductor"] -from collections.abc import Callable -from typing import Any +from typing import Any, Callable import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index c37f844e41eae..1870a0e373be0 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from collections.abc import Callable +from typing import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index faba2f5edc6a7..cd922d94c60c3 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,8 +204,7 @@ import operator import unittest import io -from typing import Optional -from collections.abc import Callable +from typing import Callable, Optional class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 10611d4f24673..002c34c450756 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from collections.abc import Callable +from typing import Callable import torch diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 222647eeae9ab..e23e049e3bbb1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,11 +39,10 @@ import unittest import warnings import weakref -from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 16583b89201ec..979950cf3bd1b 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,6 +1,5 @@ import weakref -from collections.abc import Callable -from typing import Any +from typing import Any, Callable from torch._dynamo.source import Source diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 187055c26cd00..eb39dd8fa3e07 100644 --- a/torch/_dynamo/variables/distributed.py +++ 
b/torch/_dynamo/variables/distributed.py @@ -20,8 +20,7 @@ import functools import inspect -from collections.abc import Sequence -from typing import Any, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index be765cbbc8bf9..5970ba0e1dda7 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable, Sequence -from typing import Any, TYPE_CHECKING, Union +from collections.abc import Callable +from typing import Any, Sequence, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index c09cc2163a5f4..289cebbe8129b 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,8 +22,7 @@ import logging import weakref -from collections.abc import Iterable -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Iterable, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 644c269a23a34..85977104977fb 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable, Iterable -from typing import Any, TYPE_CHECKING, TypeVar +from collections.abc import Callable +from typing import Any, Iterable, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 629bf094dc951..75928842cf297 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,6 +1,5 @@ -from collections.abc import Sequence from inspect import getattr_static -from typing import Any, TYPE_CHECKING, TypeGuard +from typing import Any, Sequence, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fb5dd775bd636..c353181eb8029 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,6 +1,5 @@ import collections -from collections.abc import Callable -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4d0f0b4fae8ab..fa8412146a427 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Sequence from types import TracebackType -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index 7e608933b34c3..ce01e37f03243 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ 
b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,10 +22,9 @@ import json import logging from abc import ABC, abstractmethod -from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index 7c495403b3a55..7cfeb4217e26b 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,7 +1,6 @@ -from collections.abc import Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Optional, Union +from typing import Callable, Optional, Union class Timer: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 29f070564349c..ab831c96c94ba 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,8 +2,7 @@ import logging import operator from collections import defaultdict -from collections.abc import Callable -from typing import Any, Literal, TypeAlias +from typing import Any, Callable, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 44314b912786f..8a4de1a604869 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Callable, Generator +from collections.abc import Generator from dataclasses import dataclass from functools import partial -from typing import Any, cast +from typing import Any, Callable, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 1e71c350ed7b6..6b0c2ad2c94a7 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable +from typing import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index e887d4bf62c8e..c6b7c51b948e5 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict -from collections.abc import Callable from dataclasses import dataclass +from typing import Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 214d3bf02f7f4..70b3a3c355dde 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, TYPE_CHECKING +from typing import Any, Callable import torch from torch._dynamo.utils import counters @@ -35,10 +35,6 @@ ) -if TYPE_CHECKING: - from collections.abc import Callable - - if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index f383ab63dc261..a47aa960e58c5 100644 --- 
a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Callable, Iterable +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any +from typing import Any, Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index b511403d4874c..30768fda9bb72 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Callable, Sequence -from typing import Any +from collections.abc import Sequence +from typing import Any, Callable import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 91b4e10bf7238..7d995adec04ef 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,8 +5,7 @@ import logging import operator from collections import Counter, defaultdict -from collections.abc import Callable -from typing import Any, TypeVar +from typing import Any, Callable, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index e42e8a1139770..52222f3da8344 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Callable, Sequence +from collections.abc import Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Callable, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 0bad4fa7cc635..92e1e6f375f44 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,8 +4,9 @@ import operator import os from collections import defaultdict -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias +from collections.abc import Sequence +from typing import Any, Callable +from typing_extensions import TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index d35309c01d07c..303110a561b5e 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,8 +2,7 @@ import functools import logging -from collections.abc import Callable -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch._inductor.codegen.subgraph import SubgraphTemplate diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 0d3721aa730a4..c100df84d5a73 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,9 +3,8 @@ import functools import importlib -from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Optional +from typing import Any, Callable, Optional, Sequence import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/runtime/benchmarking.py 
b/torch/_inductor/runtime/benchmarking.py index d9d92e363879d..d592a8c8c00f9 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Concatenate, Optional, Union -from typing_extensions import ParamSpec, Self, TypeVar +from typing import Any, Optional, Union +from typing_extensions import Concatenate, ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 03d2957493679..0758e11134018 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, TYPE_CHECKING, TypeAlias -from typing_extensions import override +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import override, TypeAlias from filelock import FileLock @@ -21,8 +21,6 @@ if TYPE_CHECKING: - from collections.abc import Callable - from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index 8e8cd011e2d44..e7e1f1adc3622 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import TYPE_CHECKING, TypeAlias -from typing_extensions import Protocol +from typing import Generator, TYPE_CHECKING +from typing_extensions import Protocol, TypeAlias from filelock import FileLock, Timeout @@ -21,7 +21,6 @@ if TYPE_CHECKING: - from collections.abc import Generator from threading import Lock diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 034740810dcdd..7ad35115cd34a 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,10 +10,9 @@ import logging import os import time -from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Optional, TextIO, TYPE_CHECKING, Union +from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 897279bd39b1e..603625ed97c12 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, Optional, overload, TypeAlias, TypeVar, Union -from typing_extensions import deprecated, Self, TypeIs +from typing import Any, Optional, overload, TypeVar, Union +from typing_extensions import deprecated, Self, TypeAlias, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5a6ee246abf7e..5e24ce086e1aa 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -3,8 +3,7 @@ import functools import traceback import weakref -from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING 
import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 147340f58d66e..56704bb3f8024 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,11 +36,10 @@ Optional, overload, Protocol, - TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self +from typing_extensions import deprecated, NamedTuple, Self, TypeAlias from torch.torch_version import TorchVersion as _TorchVersion From a74fe75c450277eb88a95c764e8b0a664a550a86 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Nov 2025 08:21:40 -0800 Subject: [PATCH 081/651] Don't hardcode double argument for reduction base (#166951) Fixes https://github.com/pytorch/pytorch/issues/43254 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951 Approved by: https://github.com/ngimel, https://github.com/Skylion007 ghstack dependencies: #166813 --- aten/src/ATen/native/cpu/Reduce.h | 4 ++-- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 22 +------------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 6c9efbb0f6e7f..ab9051ca8d2a2 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { }); } -template -void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast(0)) { using traits = binary_function_traits; static_assert( all_same< diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 3bad49a32d98c..053db7b4eda00 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) { } } -template -struct MinValuesOps: public at::native::MinOps { - using arg_t = typename MinOps::arg_t; - static scalar_t project(arg_t arg) { - return arg.first; - } -}; - void min_values_kernel_impl(TensorIterator& iter) { - // This case is special because of Vectorized does not - // handle upper_bound(). - // See: https://github.com/pytorch/pytorch/issues/43254 - if (iter.dtype() == kLong || iter.dtype() == kUInt64) { - AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { - binary_kernel_reduce( - iter, - MinValuesOps{}, - std::pair(upper_bound(), -1)); - }), kLong, kUInt64); - return; - } AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, - static_cast(upper_bound())); + upper_bound()); }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } From ea44f12bce3eb05eaa9fa34943a3ffae04647fa5 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 20:51:47 +0000 Subject: [PATCH 082/651] [13/N] Apply ruff UP035 rule (#167048) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. 
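For reference, UP035 flags imports of names that `typing` now only re-exports from `collections.abc`; the mechanical rewrite applied throughout this series looks roughly like the sketch below (an illustrative standalone module — the function and its names are hypothetical, not taken from the diff):

```python
# Illustrative example only (not part of this patch).
# Before the rule is applied, the ABCs come from typing:
#   from typing import Callable, Optional, Sequence
# After, the moved names come from collections.abc; names that still live in
# typing (Optional, Any, TYPE_CHECKING, ...) keep their existing import.
from collections.abc import Callable, Sequence
from typing import Optional


def apply_all(
    fns: Sequence[Callable[[int], int]], x: int, limit: Optional[int] = None
) -> int:
    # Apply each callable in order, stopping early once the optional limit is hit.
    for fn in fns:
        x = fn(x)
        if limit is not None and x >= limit:
            break
    return x
```
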
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048 Approved by: https://github.com/Skylion007 --- test/dynamo/test_install_free_tensors.py | 4 ++-- test/dynamo/test_python_autograd.py | 6 +++++- test/typing/pass/arithmetic_ops.py | 4 ++-- torch/_C/_distributed_c10d.pyi | 3 ++- torch/_dynamo/variables/ctx_manager.py | 4 ++-- torch/_inductor/codegen/pallas.py | 4 +++- torch/_inductor/runtime/caching/config.py | 2 +- torch/distributed/_local_tensor/_c10d.py | 3 +-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_install_free_tensors.py b/test/dynamo/test_install_free_tensors.py index 3858b827bd598..fd9e14c4c3f76 100644 --- a/test/dynamo/test_install_free_tensors.py +++ b/test/dynamo/test_install_free_tensors.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] import unittest -from collections.abc import Sequence -from typing import Any, Callable, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import torch import torch._dynamo diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index a615c653f56c3..a6117bb4093a7 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,5 +1,5 @@ # Owner(s): ["module: dynamo"] -from typing import Callable, NamedTuple, Optional +from typing import NamedTuple, Optional, TYPE_CHECKING import torch import torch._dynamo @@ -7,6 +7,10 @@ from torch._dynamo.testing import CompileCounter, same +if TYPE_CHECKING: + from collections.abc import Callable + + """ This is an example of a pure-python version of autograd implemented by @zdevito. It represents a rather challenging test case for TorchDynamo diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index f0d6cc6fd9f97..14dda1cf39772 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,5 +1,5 @@ -from typing import Union -from typing_extensions import assert_type, TypeAlias +from typing import TypeAlias, Union +from typing_extensions import assert_type from torch import randn, Tensor diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f3d96860f5584..b659be9ee119e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" +from collections.abc import Callable from datetime import timedelta from enum import Enum -from typing import Any, Callable, Optional, overload, Union +from typing import Any, Optional, overload, Union import torch from torch import Tensor diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 4eac189b65fdd..3f52c19ff0a90 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -21,9 +21,9 @@ import inspect import sys import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable, Sequence, Sized from contextlib import ExitStack -from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union import torch._C from torch._guards import Guard diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 1fc8e40724bc0..da437a4e8ee3c 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -2,7 +2,7 @@ from __future__ import annotations import hashlib -from typing import Any, 
Callable, Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import sympy # noqa: TC002 @@ -17,6 +17,8 @@ if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from ..ir import IRNode from ..scheduler import BaseSchedulerNode diff --git a/torch/_inductor/runtime/caching/config.py b/torch/_inductor/runtime/caching/config.py index 748715d1631ad..14e13f937dbb7 100644 --- a/torch/_inductor/runtime/caching/config.py +++ b/torch/_inductor/runtime/caching/config.py @@ -1,6 +1,6 @@ import os +from collections.abc import Callable from functools import cache, partial -from typing import Callable import torch from torch._environment import is_fbcode diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index c9256543e8977..0b63330dfafce 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -1,9 +1,8 @@ import functools import math import operator -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import timedelta -from typing import Callable import torch from torch._C import ScriptObject From ef3f953966d94ce11ced06f8e468b2fa69c1b3cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 20:52:41 +0000 Subject: [PATCH 083/651] Revert "[DebugMode] output, tensor id annotations for DebugMode (#165076)" This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f. Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://l.workplace.com/l.php?u=https%3A%2F%2Fwww.internalfb.com%2Fdiff%2FD86245252&h=AT1oPbS1XTv6HjYeYdxmDMW1-jlT0pS8yBO2iSfbPfUB9ydsEjFXBNT56QhV1v5TKc4_QaQNxykNowSKmb4fgenjOyCv20NuL7oV_Id5fhh32hhv1IpjgsDJYK-PBFfSfv_miLIWfNgj902KcgXojbBgDcDzQeS9lNt0GQ) for details. 
To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159)) --- .../tensor/debug/test_debug_mode.py | 22 ++- torch/utils/_debug_mode.py | 126 +++--------------- 2 files changed, 31 insertions(+), 117 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index abc37f17a74de..18cc702cbbc7a 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -50,24 +50,22 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode( - record_torchfunction=True, record_ids=True, record_output=True - ) as debug_mode: + with DebugMode(record_torchfunction=True) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) - aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) + torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t$2: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] - _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] - aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] - (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P - aten::sum(dt$6: f32[8, 32]| S(0)) - aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", + redistribute_input(t: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) + _c10d_functional::wait_tensor(t: f32[8, 32]) + aten::mm(t: f32[1, 8], t: f32[8, 32]) + (dt: f32[8, 32]| S(0)) + aten::sum(dt: f32[8, 32]| S(0)) + aten::sum(t: f32[1, 32])""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5e24ce086e1aa..09435aa07e68b 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,7 +2,6 @@ import contextlib import functools import traceback -import weakref from typing import Any, Callable, Optional, TYPE_CHECKING import torch @@ -15,7 +14,6 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback -from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -58,48 +56,29 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -class TensorIdTracker: - def __init__(self): - self.tensor_memo: dict[WeakIdRef, int] = {} - self.next_tensor_id = 0 - - def _id(self, tensor) -> int: - with torch._C._DisablePythonDispatcher(): - o = WeakIdRef(tensor) - - def del_memo(): - self.tensor_memo.pop(o, None) - - weakref.finalize(tensor, del_memo) - if o not in self.tensor_memo: - self.tensor_memo[o] = self.next_tensor_id - self.next_tensor_id += 1 - return self.tensor_memo[o] - - -def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: +def _tensor_debug_string(tensor, attributes) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - id_str = 
f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" + if isinstance(tensor, torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft{id_str}: {tensor_debug_str}" + return f"ft: {tensor_debug_str}" else: - return f"t{id_str}: {tensor_debug_str}" + return f"t: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes, tensor_memo=None) -> str: +def _arg_to_str(arg, attributes) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes, tensor_memo) + return _tensor_debug_string(x, attributes) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -165,11 +144,8 @@ def __init__( # results from dispatch hooks self.record = record self.log = log - self.output_str: Optional[str] = None - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: + def stringify_args(self, attributes: list[str]) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -177,18 +153,6 @@ def stringify_args( "Subclasses must implement stringify_args(), even if no-op" ) - def stringify_output( - self, - output: Any, - attributes: list[str], - tensor_memo: Optional[TensorIdTracker] = None, - ) -> None: - """Store stringified version of call output in self.output_str""" - if tree_all(lambda x: x is None, output): - return - output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) - self.output_str = f" -> {str(output_str)}" - def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -215,16 +179,11 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: - self.args_str = ", ".join( - _arg_to_str(arg, attributes, tensor_memo) for arg in self.args - ) + def stringify_args(self, attributes: list[str]) -> None: + self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes, tensor_memo)}" - for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -256,8 +215,6 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" - if self.output_str: - base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -290,10 +247,8 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" + def stringify_args(self, attributes: list[str]) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -308,11 +263,7 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) 
placement_str = f"{src_placement_str} -> {dst_placement_str}" - - base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" - if self.output_str: - base_str += self.output_str - return base_str + return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -337,9 +288,7 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: + def stringify_args(self, attributes: list[str]) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -392,8 +341,6 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, - record_output=False, - record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -431,24 +378,8 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace - # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) - self.record_output: bool = record_output - - # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. - self.record_ids: bool = record_ids - - self.reset() - - def reset(self): self.operators = [] self.call_depth = 0 - self._tensor_memo = TensorIdTracker() - self._output_info: dict[int, object] = {} - - def _track_op_output(self, op_index, result): - """Assign IDs to output tensors and store in output_info""" - # self._track_tensor_ids(result) - self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the “eager” backend @@ -459,35 +390,20 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args( - self.record_tensor_attributes, - self._tensor_memo if self.record_ids else None, - ) + call.stringify_args(self.record_tensor_attributes) self.operators.append(call) - def _record_call_output(self, call, output): - if not self.record_output: - return - call.stringify_output( - output, - self.record_tensor_attributes, - self._tensor_memo if self.record_ids else None, - ) - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - call = _OpCall( - func, args, kwargs, self.call_depth, stack=self.record_stack_trace + self._record_call( + _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) ) - self._record_call(call) try: self.call_depth += 1 - result = func(*args, **kwargs) - self._record_call_output(call, result) - return result + return func(*args, **kwargs) finally: self.call_depth -= 1 @@ -529,13 +445,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: - self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.reset() + self.operators = [] + self.call_depth = 0 if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) From c6c913d18e8c40ade1523cc0dd08f095217a2fdf Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 4 Nov 2025 15:36:02 -0800 Subject: [PATCH 084/651] Add torch::stable::Tensor sizes and strides (#165153) Pull Request resolved: 
https://github.com/pytorch/pytorch/pull/165153 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #164991, #165152 --- .../libtorch_agnostic/csrc/kernel.cpp | 18 ++++-------- torch/csrc/stable/tensor_struct.h | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 87aaa46e64c95..7154322641c32 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -47,20 +47,10 @@ Tensor sgd_out_of_place( STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1"); STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1"); - int64_t *param_sizes; - int64_t *param_strides; - aoti_torch_get_sizes(param.get(), ¶m_sizes); - aoti_torch_get_strides(param.get(), ¶m_strides); + // testing Tensor strides + stride + STD_TORCH_CHECK(param.strides()[0] == param.stride(0)); - int32_t param_dtype; - aoti_torch_get_dtype(param.get(), ¶m_dtype); - - int32_t param_device_type; - aoti_torch_get_device_type(param.get(), ¶m_device_type); - - AtenTensorHandle out_ath; - aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); - auto out = Tensor(out_ath); + auto out = new_empty(param, param.sizes()); sgd_math( reinterpret_cast(param.data_ptr()), @@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) { // Still using a std::vector below even though people can just pass in an // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef) // directly. + // This is to test that passing in a std::vector works for BC. (It gets + // implicitly converted to HeaderOnlyArrayRef too!) std::vector sizes = {2, 5}; auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); return new_empty(t, sizes, dtype); diff --git a/torch/csrc/stable/tensor_struct.h b/torch/csrc/stable/tensor_struct.h index 88cc167e59770..0d44ffd075170 100644 --- a/torch/csrc/stable/tensor_struct.h +++ b/torch/csrc/stable/tensor_struct.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ HIDDEN_NAMESPACE_BEGIN(torch, stable) using accelerator::DeviceIndex; +using torch::headeronly::IntHeaderOnlyArrayRef; using torch::headeronly::ScalarType; // The torch::stable::Tensor class is a highlevel C++ wrapper around @@ -93,6 +95,32 @@ class Tensor { return numel; } + // note: this API is, for all intents and purposes, the same as the one in + // TensorBase.h: it returns a borrowed reference of the dimension sizes of + // a Tensor. + // + // The only difference is that it returns a header-only IntHeaderOnlyArrayRef, + // which has slightly less functionality than a regular IntArrayRef. See + // [HeaderOnlyArrayRef vs ArrayRef note] for more details. + IntHeaderOnlyArrayRef sizes() const { + int64_t* sizes; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes)); + return IntHeaderOnlyArrayRef(sizes, dim()); + } + + // note: this API is, for all intents and purposes, the same as the one in + // TensorBase.h: it returns a borrowed reference of the strides of a + // Tensor. + // + // The only difference is that it returns a header-only IntHeaderOnlyArrayRef, + // which has slightly less functionality than a regular IntArrayRef. 
See + // [HeaderOnlyArrayRef vs ArrayRef note] for more details. + IntHeaderOnlyArrayRef strides() const { + int64_t* strides; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides)); + return IntHeaderOnlyArrayRef(strides, dim()); + } + // note: this is a subset of the original TensorBase API. It takes no // arguments whereas the original API takes in a kwarg of memory format. // Here, we assume the default contiguous memory format. From 13d2cc7bd26e32cafff0377dda1c5ddc8d04c4ce Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 20:55:59 +0000 Subject: [PATCH 085/651] Remove python workaround for ContextDecorator (#167049) This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049 Approved by: https://github.com/Skylion007 --- torch/autograd/profiler.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index fa43af2701171..9e2a7b5046dee 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -52,26 +52,7 @@ "MemRecordsAcc", ] -try: - # Available in Python >= 3.2 - from contextlib import ContextDecorator as _ContextDecorator -except ImportError: - import functools - - class _ContextDecorator: # type: ignore[no-redef] - def __enter__(self): - raise NotImplementedError - - def __exit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError - - def __call__(self, func): - @functools.wraps(func) - def wrapped(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapped +from contextlib import ContextDecorator # global python state - whether profiler is currently enabled @@ -744,8 +725,7 @@ def createFunctionEventForMemoryEvents(evt): return all_function_events -# pyrefly: ignore [invalid-inheritance] -class record_function(_ContextDecorator): +class record_function(ContextDecorator): """Context manager/function decorator that adds a label to a code block/function when running autograd profiler. Label will only appear if CPU activity tracing is enabled. 
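A minimal usage sketch of what this buys: because `record_function` now derives from the standard `contextlib.ContextDecorator`, the same label can be applied either as a decorator or as a context manager. The function and label names below are illustrative, not taken from the patch.

```python
import torch
from torch.autograd.profiler import record_function


@record_function("labelled_matmul")  # decorator form, provided by ContextDecorator
def labelled_matmul(x, y):
    return x @ y


with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
    with record_function("inline_block"):  # classic context-manager form
        out = labelled_matmul(torch.randn(8, 8), torch.randn(8, 8))

# Both labels show up in the profiler table when CPU activity tracing is enabled.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```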
From fd8f368d31d622355275cfe0283ab582cd2ee903 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 4 Nov 2025 17:45:11 -0800 Subject: [PATCH 086/651] [user-streams] Add graph annotation checks (#167019) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_graph_deduplication.py | 14 +- test/dynamo/test_streams.py | 230 ++++++++++++++++++++++-- torch/_dynamo/testing.py | 6 + 3 files changed, 225 insertions(+), 25 deletions(-) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 004aee88a8633..fc9284a3c9542 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -8,21 +8,11 @@ from torch._dynamo.graph_utils import _detect_cycles from torch._dynamo.output_graph import FakeRootModule from torch._dynamo.test_case import TestCase -from torch._dynamo.testing import ( - AotEagerAndRecordGraphs, - extract_graph_and_tracker, - normalize_gm, -) +from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm from torch.compiler import allow_in_graph from torch.utils._ordered_set import OrderedSet -def extract_graph(fn, *args, **kwargs): - backend = AotEagerAndRecordGraphs() - result = torch.compile(backend=backend)(fn)(*args, **kwargs) - return result, backend.graphs, backend.fw_graphs - - def graph_str(gm): return normalize_gm(gm.print_readable(print_output=False)) @@ -40,7 +30,7 @@ def tearDown(self): super().tearDown() def run_and_return_graphs(self, fn, *args, **kwargs): - return extract_graph(fn, *args, **kwargs) + return extract_graph(fn, *args, **kwargs)[0:3] def run_and_get_simple_graph(self): def fn(x, y): diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index e05e1304d2860..0a49a21cca42b 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -1,11 +1,13 @@ # Owner(s): ["module: dynamo"] import functools +import re import unittest import weakref import torch import torch._dynamo.test_case import torch._dynamo.testing +from torch._dynamo.testing import extract_graph, remove_trailing_space from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda @@ -15,6 +17,14 @@ ) +def remove_file_comment(gm_str: str) -> str: + return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str)) + + +def print_graph(graph: torch.fx.GraphModule) -> str: + return remove_file_comment(graph.print_readable()) + + class TestStreams(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -36,9 +46,7 @@ def test_event_weakref(self): @requires_cuda def test_stream_enter_exit(self): - def fn(x, y): - s2 = torch.Stream() - s1 = torch.Stream() + def fn(x, y, s1, s2): with s1: z1 = torch.add(x, y) with s2: @@ -47,13 +55,36 @@ def fn(x, y): return y - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream()) expected = fn(*inp) - fn_opt = torch.compile(fn, fullgraph=True) - actual = fn_opt(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': None} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': None} + add_1: "f32[2, 
2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': None} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None + return (add_3,) +""", + ) @requires_cuda + @unittest.skip("Needs graph break support with annotation context") def test_stream_context_graph_break(self): def fn(x, y): s2 = torch.Stream() @@ -70,9 +101,16 @@ def fn(x, y): inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) expected = fn(*inp) - fn_opt = torch.compile(fn) - actual = fn_opt(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) self.assertEqual(expected, actual) + self.assertEqual(len(fw_graphs), 2) + self.assertExpectedInline(print_graph(fw_graphs[0]), """""") + self.assertExpectedInline(print_graph(fw_graphs[1]), """""") @requires_cuda def test_stream_input(self): @@ -155,22 +193,188 @@ def fn(x, s0, s1): self.assertEqual(s_act, s_exp) def test_nested_stream_enter_exit(self): - pass - + def fn(x, y, s0, s1, s2): + with s1: + with s2: + z1 = torch.add(x, y) + with s0: + z0 = torch.add(x, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = ( + torch.ones(2, 2) + 1, + torch.ones(2, 2), + torch.Stream(), + torch.Stream(), + torch.Stream(), + ) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': None} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': None} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': None} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None + return (add_1, add_2) +""", + ) + + @unittest.skip("Needs graph break support with annotation context") def test_stream_enter_exit_graph_break(self): pass + @unittest.skip("Needs graph break support with annotation context") def test_nested_stream_enter_exit_graph_break(self): pass def test_local_stream_enter_exit(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + with s1: + z1 = torch.add(x, y) + with s2: + z = torch.add(x, y) + y = z + 2 + z1 + + return y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 1} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 0} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': 0} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None + return (add_3,) +""", + ) def test_local_stream_nested_enter_exit(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + s0 = torch.Stream() + with s1: + with s2: + z1 = torch.add(x, y) + with s0: + z0 = torch.add(x, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = 
fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 2} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': 0} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None + return (add_1, add_2) +""", + ) def test_stream_with_mutation(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + s0 = torch.Stream() + with s1: + with s2: + x.add_(y) + with s0: + z1 = torch.add(y, y) + z0 = torch.add(z1, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 2} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1) + + # Annotation: {'stream': 2} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None + + # Annotation: {'stream': 0} + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + + # + copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None + return (add_2, add_3) +""", + ) @requires_cuda def test_run_opcheck(self): diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 9206f2598afc2..3eeedfb65da20 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -87,6 +87,12 @@ def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def] return gm.graph, region_tracker # type: ignore[union-attr] +def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def] + backend = AotEagerAndRecordGraphs() + result = torch.compile(backend=backend)(fn)(*args, **kwargs) + return result, backend.graphs, backend.fw_graphs, backend.bw_graphs + + def collect_results( model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any ) -> list[Any]: From e69aaaf45a8018004aa91d58bef77199acbb888e Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 4 Nov 2025 17:45:12 -0800 Subject: [PATCH 087/651] [user-streams] Add backward test (#167021) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167019 --- test/dynamo/test_streams.py | 60 +++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 0a49a21cca42b..b9a3855f6ddbb 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -376,6 +376,66 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): """, ) + def test_stream_backward(self) -> None: + def fn(x, y): + s2 = torch.Stream() + s0 = torch.Stream() + with s0: + y0 = 2 * x + y + with s2: + z = 2 * x + y + + return y0, z + + inp = ( + torch.ones(2, 2, requires_grad=True) + 1, + torch.ones(2, 2, requires_grad=True), + ) + expected = 
fn(*inp) + ( + actual, + _, + fw_graphs, + bw_graphs, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): + # Annotation: {'stream': 1} + mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2) + + # Annotation: {'stream': 0} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None + return (add, add_1) +""", + ) + + actual[1].sum().backward() + self.assertExpectedInline( + print_graph(bw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): + # Annotation: {'stream': 0} + mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2) + + # + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None + + # Annotation: {'stream': 1} + mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + + # + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + return (add_3, add_2) +""", + ) + @requires_cuda def test_run_opcheck(self): from torch._dynamo.variables.streams import fork_stream, join_stream From e9a688f02ee742af2c1e24d7b2109beced35465f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 5 Nov 2025 22:00:11 +0000 Subject: [PATCH 088/651] [DebugMode] output, tensor id annotations for DebugMode (#165076) Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)` Example output for `test_debug_mode_mm`, with both enabled: ``` torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$12: f32[8, 32]| S(0) aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) redistribute_input(t$4: f32[1, 32], trace: S(0)->R) _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0) -> t$6: f32[8, 32] _c10d_functional::wait_tensor(t$7: f32[8, 32]) -> t$8: f32[8, 32] aten::mm(t$9: f32[1, 8], t$10: f32[8, 32]) -> t$11: f32[1, 32] (dt$13: f32[8, 32]| S(0)) -> dt$17: f32[]| P aten::sum(dt$14: f32[8, 32]| S(0)) aten::sum(t$15: f32[1, 32]) -> t$16: f32[]""" ``` Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic. 
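A minimal sketch of how these options are exercised outside of DTensor (plain CPU tensors with arbitrary shapes; no distributed setup assumed):

```python
import torch
from torch.utils._debug_mode import DebugMode

x, y = torch.randn(8, 8), torch.randn(8, 32)
with DebugMode(
    record_torchfunction=True, record_ids=True, record_output=True
) as debug_mode:
    torch.mm(x, y).sum()

# Each logged call is annotated with graph-style tensor ids ($0, $1, ...) and,
# with record_output=True, with its stringified output as well.
print(debug_mode.debug_string())
```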
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076 Approved by: https://github.com/zpcore --- .../tensor/debug/test_debug_mode.py | 22 +-- torch/utils/_debug_mode.py | 126 +++++++++++++++--- 2 files changed, 117 insertions(+), 31 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 18cc702cbbc7a..abc37f17a74de 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -50,22 +50,24 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_ids=True, record_output=True + ) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) - aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) + aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) - _c10d_functional::wait_tensor(t: f32[8, 32]) - aten::mm(t: f32[1, 8], t: f32[8, 32]) - (dt: f32[8, 32]| S(0)) - aten::sum(dt: f32[8, 32]| S(0)) - aten::sum(t: f32[1, 32])""", + redistribute_input(t$2: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] + _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] + aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + aten::sum(dt$6: f32[8, 32]| S(0)) + aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 09435aa07e68b..5e24ce086e1aa 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,6 +2,7 @@ import contextlib import functools import traceback +import weakref from typing import Any, Callable, Optional, TYPE_CHECKING import torch @@ -14,6 +15,7 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -56,29 +58,48 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -def _tensor_debug_string(tensor, attributes) -> str: +class TensorIdTracker: + def __init__(self): + self.tensor_memo: dict[WeakIdRef, int] = {} + self.next_tensor_id = 0 + + def _id(self, tensor) -> int: + with torch._C._DisablePythonDispatcher(): + o = WeakIdRef(tensor) + + def del_memo(): + self.tensor_memo.pop(o, None) + + weakref.finalize(tensor, del_memo) + if o not in self.tensor_memo: + self.tensor_memo[o] = self.next_tensor_id + self.next_tensor_id += 1 + return self.tensor_memo[o] + + +def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - + id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" if isinstance(tensor, 
torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft: {tensor_debug_str}" + return f"ft{id_str}: {tensor_debug_str}" else: - return f"t: {tensor_debug_str}" + return f"t{id_str}: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes) -> str: +def _arg_to_str(arg, attributes, tensor_memo=None) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes) + return _tensor_debug_string(x, attributes, tensor_memo) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -144,8 +165,11 @@ def __init__( # results from dispatch hooks self.record = record self.log = log + self.output_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -153,6 +177,18 @@ def stringify_args(self, attributes: list[str]) -> None: "Subclasses must implement stringify_args(), even if no-op" ) + def stringify_output( + self, + output: Any, + attributes: list[str], + tensor_memo: Optional[TensorIdTracker] = None, + ) -> None: + """Store stringified version of call output in self.output_str""" + if tree_all(lambda x: x is None, output): + return + output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) + self.output_str = f" -> {str(output_str)}" + def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -179,11 +215,16 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.args_str = ", ".join( + _arg_to_str(arg, attributes, tensor_memo) for arg in self.args + ) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes, tensor_memo)}" + for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -215,6 +256,8 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" + if self.output_str: + base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -247,8 +290,10 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes)}" + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -263,7 +308,11 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) placement_str = f"{src_placement_str} -> 
{dst_placement_str}" - return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + + base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + if self.output_str: + base_str += self.output_str + return base_str def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -288,7 +337,9 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -341,6 +392,8 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, + record_output=False, + record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -378,8 +431,24 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace + # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) + self.record_output: bool = record_output + + # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. + self.record_ids: bool = record_ids + + self.reset() + + def reset(self): self.operators = [] self.call_depth = 0 + self._tensor_memo = TensorIdTracker() + self._output_info: dict[int, object] = {} + + def _track_op_output(self, op_index, result): + """Assign IDs to output tensors and store in output_info""" + # self._track_tensor_ids(result) + self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the “eager” backend @@ -390,20 +459,35 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args(self.record_tensor_attributes) + call.stringify_args( + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) self.operators.append(call) + def _record_call_output(self, call, output): + if not self.record_output: + return + call.stringify_output( + output, + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - self._record_call( - _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) + call = _OpCall( + func, args, kwargs, self.call_depth, stack=self.record_stack_trace ) + self._record_call(call) try: self.call_depth += 1 - return func(*args, **kwargs) + result = func(*args, **kwargs) + self._record_call_output(call, result) + return result finally: self.call_depth -= 1 @@ -445,13 +529,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: + self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.operators = [] - self.call_depth = 0 + self.reset() if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) From 711a7758788ccdbb85bc20e9dd8146f5a7bafb24 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 5 Nov 2025 08:54:22 -0800 Subject: [PATCH 089/651] fix nccl estimations (#167093) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093 Approved by: https://github.com/kwen2501, 
https://github.com/eellison --- torch/_inductor/comm_analysis.py | 2 +- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 61af576772c16..74a58acb84ff3 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node( fx_node: torch.fx.Node, override_size: Optional[int] = None, # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix. - use_nccl_estimator: bool = False, + use_nccl_estimator: bool = True, ) -> float: """ Returns estimated NCCL collective runtime in nanoseconds (ns). diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d051803aa7376..3416bc336d34a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -3593,6 +3593,7 @@ float ProcessGroupNCCL::endTimeEstimate() { #ifdef NCCL_SIM_INFO_INITIALIZER ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER; C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt); + --ncclActiveGroupCounter_; return simInfo.estimatedTime; #else TORCH_CHECK( From ad7a57262c8f3ce6a2d724af533f09437495100f Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 22:06:19 +0000 Subject: [PATCH 090/651] [12/N] Apply ruff UP035 rule (#166929) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929 Approved by: https://github.com/Lucaskabela --- test/distributed/tensor/test_attention.py | 3 ++- test/higher_order_ops/test_local_map.py | 3 ++- test/inductor/test_caching.py | 3 ++- test/inductor/test_fx_fusion.py | 3 ++- test/inductor/test_native_matmul.py | 2 +- test/quantization/fx/test_quantize_fx.py | 3 ++- test/test_matmul_cuda.py | 2 +- torch/_dynamo/eval_frame.py | 3 ++- torch/_dynamo/graph_bytecode_inputs.py | 3 ++- torch/_dynamo/variables/distributed.py | 3 ++- torch/_dynamo/variables/iter.py | 4 ++-- torch/_dynamo/variables/optimizer.py | 3 ++- torch/_dynamo/variables/script_object.py | 4 ++-- torch/_dynamo/variables/sdpa.py | 3 ++- torch/_dynamo/variables/streams.py | 3 ++- torch/_dynamo/variables/torch_function.py | 4 ++-- torch/_functorch/_aot_autograd/aot_autograd_result.py | 3 ++- torch/_inductor/compile_worker/timer.py | 3 ++- torch/_inductor/fx_passes/bucketing.py | 3 ++- torch/_inductor/fx_passes/ddp_fusion.py | 4 ++-- torch/_inductor/fx_passes/fsdp.py | 2 +- torch/_inductor/fx_passes/memory_estimator.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 6 +++++- torch/_inductor/fx_passes/overlap_scheduling.py | 4 ++-- torch/_inductor/fx_passes/pad_mm.py | 4 ++-- torch/_inductor/fx_passes/post_grad.py | 3 ++- torch/_inductor/fx_passes/reinplace.py | 4 ++-- torch/_inductor/fx_passes/split_cat.py | 5 ++--- torch/_inductor/kernel/custom_op.py | 3 ++- torch/_inductor/kernel/flex/flex_flash_attention.py | 3 ++- torch/_inductor/runtime/benchmarking.py | 4 ++-- torch/_inductor/runtime/caching/interfaces.py | 6 ++++-- torch/_inductor/runtime/caching/locks.py | 5 +++-- torch/distributed/elastic/multiprocessing/tail_log.py | 3 ++- torch/utils/_cxx_pytree.py | 4 ++-- torch/utils/_debug_mode.py | 3 ++- torch/utils/_pytree.py | 3 ++- 37 files changed, 76 insertions(+), 50 deletions(-) diff --git a/test/distributed/tensor/test_attention.py 
b/test/distributed/tensor/test_attention.py index eaf3a4042060d..6c3485f9d7025 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,7 +3,8 @@ import itertools import random import unittest -from typing import Any, Callable, ClassVar, Optional +from collections.abc import Callable +from typing import Any, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 9d2870d3b5fdd..fbb21633260e7 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,8 +4,9 @@ import functools import unittest +from collections.abc import Callable from contextlib import contextmanager, ExitStack -from typing import Any, Callable, Optional +from typing import Any, Optional import torch import torch._dynamo diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index bcb66beea700c..aa4c3a1f229f1 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, Generator, Sequence, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,6 +37,7 @@ if TYPE_CHECKING: + from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index ebe98373e622a..63342502d3cd9 100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index 1870a0e373be0..c37f844e41eae 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from typing import Callable +from collections.abc import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index cd922d94c60c3..faba2f5edc6a7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,7 +204,8 @@ import operator import unittest import io -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 002c34c450756..10611d4f24673 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from typing import Callable +from collections.abc import Callable import torch diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e23e049e3bbb1..222647eeae9ab 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,10 +39,11 @@ import unittest import warnings import weakref +from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, Sized, 
TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 979950cf3bd1b..16583b89201ec 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,5 +1,6 @@ import weakref -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from torch._dynamo.source import Source diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index eb39dd8fa3e07..187055c26cd00 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -20,7 +20,8 @@ import functools import inspect -from typing import Any, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 5970ba0e1dda7..be765cbbc8bf9 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable -from typing import Any, Sequence, TYPE_CHECKING, Union +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 289cebbe8129b..c09cc2163a5f4 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,7 +22,8 @@ import logging import weakref -from typing import Any, Iterable, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 85977104977fb..644c269a23a34 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable -from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from collections.abc import Callable, Iterable +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 75928842cf297..629bf094dc951 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from inspect import getattr_static -from typing import Any, Sequence, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index c353181eb8029..fb5dd775bd636 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,5 +1,6 @@ import collections -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 
fa8412146a427..4d0f0b4fae8ab 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Sequence +from collections.abc import Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index ce01e37f03243..7e608933b34c3 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,9 +22,10 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index 7cfeb4217e26b..7c495403b3a55 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Callable, Optional, Union +from typing import Optional, Union class Timer: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index ab831c96c94ba..29f070564349c 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,7 +2,8 @@ import logging import operator from collections import defaultdict -from typing import Any, Callable, Literal, TypeAlias +from collections.abc import Callable +from typing import Any, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 8a4de1a604869..44314b912786f 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass from functools import partial -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 6b0c2ad2c94a7..1e71c350ed7b6 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index c6b7c51b948e5..e887d4bf62c8e 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import torch import torch.fx as fx diff --git 
a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 70b3a3c355dde..214d3bf02f7f4 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, Callable +from typing import Any, TYPE_CHECKING import torch from torch._dynamo.utils import counters @@ -35,6 +35,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable + + if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index a47aa960e58c5..f383ab63dc261 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 30768fda9bb72..b511403d4874c 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Sequence -from typing import Any, Callable +from collections.abc import Callable, Sequence +from typing import Any import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7d995adec04ef..91b4e10bf7238 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,7 +5,8 @@ import logging import operator from collections import Counter, defaultdict -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 52222f3da8344..e42e8a1139770 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 92e1e6f375f44..0bad4fa7cc635 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,9 +4,8 @@ import operator import os from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Callable -from typing_extensions import TypeAlias +from collections.abc import Callable, Sequence +from typing import Any, TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 303110a561b5e..d35309c01d07c 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,7 +2,8 @@ import functools import logging -from typing import Any, Callable, Optional, Union +from 
collections.abc import Callable +from typing import Any, Optional, Union import torch from torch._inductor.codegen.subgraph import SubgraphTemplate diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index c100df84d5a73..0d3721aa730a4 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,8 +3,9 @@ import functools import importlib +from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index d592a8c8c00f9..d9d92e363879d 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Optional, Union -from typing_extensions import Concatenate, ParamSpec, Self, TypeVar +from typing import Any, Concatenate, Optional, Union +from typing_extensions import ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 0758e11134018..03d2957493679 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import override, TypeAlias +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override from filelock import FileLock @@ -21,6 +21,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index e7e1f1adc3622..8e8cd011e2d44 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import Generator, TYPE_CHECKING -from typing_extensions import Protocol, TypeAlias +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol from filelock import FileLock, Timeout @@ -21,6 +21,7 @@ if TYPE_CHECKING: + from collections.abc import Generator from threading import Lock diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 7ad35115cd34a..034740810dcdd 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,9 +10,10 @@ import logging import os import time +from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union +from typing import Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 603625ed97c12..897279bd39b1e 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping 
-from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, Self, TypeAlias, TypeIs +from typing import Any, Optional, overload, TypeAlias, TypeVar, Union +from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5e24ce086e1aa..5a6ee246abf7e 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -3,7 +3,8 @@ import functools import traceback import weakref -from typing import Any, Callable, Optional, TYPE_CHECKING +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 56704bb3f8024..147340f58d66e 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,10 +36,11 @@ Optional, overload, Protocol, + TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeAlias +from typing_extensions import deprecated, NamedTuple, Self from torch.torch_version import TorchVersion as _TorchVersion From 08200280ce3c7b5bfbf3997517254565b2d6f162 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 14:37:54 -0800 Subject: [PATCH 091/651] [CP][BE][3/N] Add _templated_ring_attention to the backward compatility stub (#166991) While `_templated_ring_attention` is a private API, it is unfortunatelly used by some packages. Add it to __all__ so that people can still use it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991 Approved by: https://github.com/XilunWu ghstack dependencies: #166456, #166501 --- torch/distributed/tensor/experimental/_attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 2444467a3595f..f238739ddd5cf 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -10,6 +10,7 @@ _enable_context_parallel_dispatcher, _is_causal_behavior, _RotateMethod, + _templated_ring_attention, context_parallel, context_parallel_unshard, set_rotate_method, @@ -22,6 +23,7 @@ ) +# TODO(fegin): add deprecation message once the final interfaces are concluded. __all__ = [ "_CausalBehavior", "_context_parallel_shard", @@ -31,6 +33,7 @@ "_enable_context_parallel_dispatcher", "_is_causal_behavior", "_RotateMethod", + "_templated_ring_attention", "context_parallel", "context_parallel_unshard", "set_rotate_method", From 47eb34b7ac4359d281d1bfc3626feec184aec8b6 Mon Sep 17 00:00:00 2001 From: YyWangCS Date: Wed, 5 Nov 2025 22:34:16 +0000 Subject: [PATCH 092/651] [ATEN][CUDA] Reduce register pressure in radix_sort_pairs to improve torch.sort performance (#167094) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Summary This PR improves `torch.sort` and `torch.unique` performance by **15% to 50%** on NVIDIA GPUs by optimizing CUDA register allocation in radix sort operations. The key change: specialize `OpaqueType` to use native integer types (uint8_t, uint16_t, uint32_t, uint64_t) for common sizes (1, 2, 4, 8 bytes) instead of `char data[N]`. This enables more efficient register allocation while preserving the template deduplication strategy. 
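For context, a minimal timing sketch (not part of this patch; the exact benchmark harness used for the numbers below is not included here, so this is only an assumed setup that requires a CUDA device) simply times `torch.sort` along the last dimension:

```python
import torch

def time_sort_ms(shape, dtype, iters=20):
    # random integer keys, sorted along the last dimension as in the table below
    x = torch.randint(0, 2**31 - 1, shape, device="cuda", dtype=dtype)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        torch.sort(x, dim=-1)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per sort

print(time_sort_ms((16, 1_000_000), torch.int32))   # compare with the (16, 1e6) int32 rows
print(time_sort_ms((1, 100_000_000), torch.int64))  # compare with the (1, 1e8) int64 rows
```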
The following table shows the speedup on various input shapes and GPUs. Sorting is performed on the last dimension, and the baseline torch version is 2.9.0.

| GPU | input shape | input dtype | Before (ms) | After (ms) | Speedup |
| ---- | ----------- | ----------- | ----------- | ---------- | ------- |
| H100 | (16, 1e6) | int32 | 1.61 | 1.37 | 1.18× |
| H100 | (1, 1e8) | int32 | 6.6 | 5.0 | 1.3× |
| H20 | (16, 1e6) | int64 | 3.57 | 3.03 | 1.18× |
| H20 | (1, 1e8) | int64 | 19.3 | 13.0 | 1.48× |

# Analysis

`torch.sort` and `torch.unique` use `radix_sort_pairs`, which internally calls `cub::DeviceRadixSort::SortPairs`. Since values are only copied (never compared), we cast them to `OpaqueType` to minimize template instantiations. For example, both `int32` and `float32` values map to the same `OpaqueType<4>`.

## The Problem

The previous `char data[N]` implementation causes inefficient register allocation. Here is one reason I found from the SASS code. For 8-byte types:

- `char data[8]`: the compiler may allocate 8 registers (one per byte)
- `uint64_t data`: the compiler allocates 2 registers (standard 64-bit handling)

This happens because the compiler doesn't recognize char[8] as a cohesive 64-bit value, treating each byte independently, which increases register pressure and reduces GPU occupancy.

From Nsight Compute, when using `char data[8]`, register usage is 166 registers per thread, and the corresponding theoretical occupancy is 18.75%. When using native `uint64_t`, register usage is 80 registers per thread, and the corresponding theoretical occupancy is 37.5%.

## The Solution

Specialize `OpaqueType` for common sizes using native integer types:

```
// Before
template struct alignas(N) OpaqueType { char data[N]; };

// After
template struct alignas(N) OpaqueType { char data[N]; }; // fallback
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
```

This preserves the template deduplication strategy (all 8-byte types still use the same `OpaqueType<8>` instantiation) while enabling better register allocation.

# Testing & Compatibility

## Testing:

✅ Correctness tests pass for various input types (bfloat16, int32, float32, int64), shapes, and dimensions (1, 2, 3)
✅ Register usage reduction verified with Nsight Compute
✅ Linter passes

## Compatibility:

✅ No API/ABI changes
✅ Template instantiation count unchanged

# Reference

For detailed analysis, please refer to my previous blog post: [Performance Optimization of torch.sort on GPU](https://yywangcs.notion.site/Performance-Optimization-of-torch-sort-on-GPU-192fc9f5d8058018a1bec1efa35da3f9)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167094
Approved by: https://github.com/ngimel, https://github.com/Skylion007
---
 aten/src/ATen/cuda/cub.h | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/aten/src/ATen/cuda/cub.h b/aten/src/ATen/cuda/cub.h
index 7430edaf8a3dc..bca9b1faff523 100644
--- a/aten/src/ATen/cuda/cub.h
+++ b/aten/src/ATen/cuda/cub.h
@@ -24,7 +24,13 @@ namespace detail {
 // radix_sort_pairs doesn't interact with value_t other than to copy
 // the data, so we can save template instantiations by reinterpreting
 // it as an opaque type.
+// We use native integer types for 1/2/4/8-byte values to reduce
+// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template struct alignas(N) OpaqueType { char data[N]; }; +template <> struct alignas(1) OpaqueType<1> { uint8_t data; }; +template <> struct alignas(2) OpaqueType<2> { uint16_t data; }; +template <> struct alignas(4) OpaqueType<4> { uint32_t data; }; +template <> struct alignas(8) OpaqueType<8> { uint64_t data; }; template void radix_sort_pairs_impl( From 3869aa115b1d513cb83ad89889f8c3af7921b0ce Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 5 Nov 2025 23:05:56 +0000 Subject: [PATCH 093/651] fix fr reset api (#166970) Summary: - there are various places that access fr's `entries_` field - if we empty the entries_ on reset, the accesses can result in an error - so we only perform a soft delete instead of clearing out the entries copletely - only reset id_ on the reset - keep track of a reset_epoch which increments everytime reset is called - dump_entries only returns entries from the latest epoch - api's that access entries also check if the reset epoch matches - make the `next_` always track the index in the circular buffer - this change was needed to make the soft delete's implementation easier --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/166970). * #166972 * #166971 * __->__ #166970 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166970 Approved by: https://github.com/fduwjj --- test/distributed/test_c10d_nccl.py | 223 ++++++++++++++++++ .../csrc/distributed/c10d/FlightRecorder.hpp | 42 +++- .../distributed/c10d/FlightRecorderDetail.hpp | 136 ++++++++--- .../distributed/c10d/ProcessGroupGloo.cpp | 7 +- .../distributed/c10d/ProcessGroupGloo.hpp | 1 + .../distributed/c10d/ProcessGroupNCCL.cpp | 18 +- .../distributed/c10d/ProcessGroupNCCL.hpp | 1 + 7 files changed, 389 insertions(+), 39 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index cf53896187c20..d764dfbbebbb1 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -5789,6 +5789,229 @@ def test_coalescing_manager_collective(self, timing_enabled): else: self.assertTrue("duration_ms" not in t["entries"][0]) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_circular_buffer_full(self, timing_enabled): + """ + Test that when the circular buffer in entries_ is full and we call reset, + then fill the buffer with new entries, dump_entries returns only the new + entries and not the old ones. 
+ """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill the buffer completely with 10 entries + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify buffer is full with 10 entries + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 10) + + # Now reset the flight recorder + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Add new entries after reset - fill the buffer completely again + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify we get exactly 10 new entries, not 20 + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 10) + + # Verify all entries have the expected properties (from after reset) + # After reset, record IDs should start from 0 again + for i, entry in enumerate(t["entries"]): + self.assertIn("profiling_name", entry) + self.assertEqual(entry["profiling_name"], "nccl:all_reduce") + self.assertIn("record_id", entry) + # Record IDs should be sequential starting from 0 after reset + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_partial_overwrite(self, timing_enabled): + """ + Test that when the circular buffer is full, we reset, and then add fewer + entries than the buffer size, we only get the new entries. + This tests that old entries at the end of the circular buffer are properly + filtered out based on reset_epoch. 
+ """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill the buffer completely + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Reset the flight recorder + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Add only 3 new entries (much less than buffer size) + for _ in range(3): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify we only get the 3 new entries, not 10 + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 3) + + # Verify record IDs start from 0 after reset + for i, entry in enumerate(t["entries"]): + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_wraparound(self, timing_enabled): + """ + Test that when we reset in the middle of the circular buffer and then + wrap around, dump_entries correctly returns only entries from the current + epoch in the correct order. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill half the buffer + for _ in range(5): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Reset at this point (reset happens at index 5) + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Now add 8 entries, which will wrap around + # (5->9 fills rest of buffer, then 0->2 wraps around) + for _ in range(8): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Should get exactly 8 entries, properly ordered + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 8) + + # Entries should be in chronological order + # The dump_entries() method returns entries from next_ to end, then 0 to next_ + # After filtering old entries, we should have 8 entries in order + # Verify record IDs start from 0 after reset (id_ is reset in reset_all()) + for i, entry in enumerate(t["entries"]): + self.assertIn("profiling_name", entry) + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_multiple_resets(self, timing_enabled): + """ + Test multiple consecutive resets to ensure each reset properly increments + the epoch and filters out entries from previous epochs. 
+ """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # First batch: 2 entries + for _ in range(2): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # First reset + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Second batch: 3 entries + for _ in range(3): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Second reset + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Third batch: 4 entries + for _ in range(4): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Should only see the last 4 entries + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 4) + + # Verify record IDs start from 0 after the last reset + for i, entry in enumerate(t["entries"]): + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + def check_if_test_is_skipped(fn): def wrapper(self, *args, **kwargs): diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index 23b8893c54f2c..bdb4ad045ff2a 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -108,12 +108,14 @@ struct FlightRecorder { capture_cpp_stack_ = getCvarBool( {"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false); enabled_ = max_entries_ > 0; + reset_epoch_start_idx_[0] = 0; } struct Entry { size_t id_; // incremented id in the trace buffer // used to figure out where in the circular entries // buffer this entry will be located to // update state information + size_t reset_epoch_; // epoch when this entry was created size_t pg_id_; std::tuple pg_name_; // @@ -183,11 +185,34 @@ struct FlightRecorder { size_t max_entries_ = 0; size_t next_ = 0; size_t id_ = 0; + size_t reset_epoch_ = 0; + std::unordered_map + reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts std::map> all_pg_status_; std::map, std::vector> pg_name_to_ranks_; std::string comm_lib_version_; + struct TraceIdentifier { + std::optional id; + std::optional reset_epoch; + }; + + TraceIdentifier recordWithResetEnabled( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P); + std::optional record( size_t pg_id, const std::tuple& pg_name, @@ -213,8 +238,16 @@ struct FlightRecorder { std::vector dump_entries(); - // Returns the entry with the given id, if it exists. Otherwise, returns - // std::nullopt. + // Returns the index in entries_ for the given id and reset_epoch. + // Caller must hold mutex_lock before calling this method. + size_t getIdxFromId(size_t id, size_t reset_epoch) const; + + // Returns the entry with the given id and reset_epoch, if it exists. + // Otherwise, returns std::nullopt. 
+ TORCH_API std::optional getEntry( + std::optional id, + std::optional reset_epoch); + TORCH_API std::optional getEntry(std::optional id); /* @@ -227,6 +260,11 @@ struct FlightRecorder { never hang. (timing must also be enabled for compute_duration - see TORCH_NCCL_ENABLE_TIMING). */ + TORCH_API void retire_id( + std::optional id, + std::optional reset_epoch, + bool compute_duration = true); + TORCH_API void retire_id( std::optional id, bool compute_duration = true); diff --git a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp index 8813c95158460..88205c171941c 100644 --- a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp @@ -53,8 +53,41 @@ std::optional FlightRecorder::record( std::chrono::milliseconds timeout_ms, std::shared_ptr pg_status, bool isP2P) { + auto result = recordWithResetEnabled( + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + inputs, + outputs, + start, + end, + timeout_ms, + std::move(pg_status), + isP2P); + return result.id; +} + +template +typename FlightRecorder::TraceIdentifier FlightRecorder:: + recordWithResetEnabled( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P) { if (!enabled_) { - return std::nullopt; + return TraceIdentifier{std::nullopt, std::nullopt}; } if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { // Current pg_status is not in FR. @@ -64,8 +97,13 @@ std::optional FlightRecorder::record( torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); std::lock_guard guard(mutex_); + TORCH_CHECK( + reset_epoch_start_idx_.find(reset_epoch_) != + reset_epoch_start_idx_.end()); + auto te = Entry{ id_, + reset_epoch_, pg_id, pg_name, collective_seq_id, @@ -104,15 +142,20 @@ std::optional FlightRecorder::record( te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); } + const auto next = next_++; + if (entries_.size() < max_entries_) { entries_.emplace_back(std::move(te)); } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } + entries_[next] = std::move(te); } - return id_++; + + if (next_ == max_entries_) { + next_ = 0; + } + + const auto id = id_++; + return TraceIdentifier{id, reset_epoch_}; } template @@ -163,15 +206,20 @@ std::vector::Entry> FlightRecorder< std::vector result; { std::lock_guard guard(mutex_); - result.reserve(entries_.size()); - result.insert( - result.end(), + // Filter entries during insertion - only keep entries from current epoch + auto filter = [this](const Entry& e) { + return e.reset_epoch_ == reset_epoch_; + }; + std::copy_if( entries_.begin() + static_cast(next_), - entries_.end()); - result.insert( - result.end(), + entries_.end(), + std::back_inserter(result), + filter); + std::copy_if( entries_.begin(), - entries_.begin() + static_cast(next_)); + entries_.begin() + static_cast(next_), + std::back_inserter(result), + filter); } // query any remaining events for (auto& r : result) { @@ -182,28 +230,47 @@ std::vector::Entry> FlightRecorder< } template -// Returns the entry with the given id, if it exists. Otherwise, returns -// std::nullopt. +// Returns the index in entries_ for the given id and reset_epoch. 
+// Caller must hold mutex_lock before calling this method. +size_t FlightRecorder::getIdxFromId(size_t id, size_t reset_epoch) + const { + // Look up the starting idx for the given reset epoch + auto it = reset_epoch_start_idx_.find(reset_epoch); + TORCH_CHECK(it != reset_epoch_start_idx_.end()); + // Calculate idx based on where the epoch started + return (it->second + id) % max_entries_; +} + +template +// Returns the entry with the given id and reset_epoch, if it exists. Otherwise, +// returns std::nullopt. std::optional::Entry> FlightRecorder< - EventType>::getEntry(std::optional id) { - if (!enabled_ || !id) { + EventType>:: + getEntry(std::optional id, std::optional reset_epoch) { + if (!enabled_ || !id || !reset_epoch) { return std::nullopt; } std::unique_lock guard(mutex_); - Entry entry = entries_.at(*id % max_entries_); - if (entry.id_ == *id) { + Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch)); + if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) { return entry; - } else { - return std::nullopt; } + return std::nullopt; +} + +template +std::optional::Entry> FlightRecorder< + EventType>::getEntry(std::optional id) { + return getEntry(id, 0); } template void FlightRecorder::retire_id( std::optional id, + std::optional reset_epoch, bool compute_duration) { - if (!enabled_ || !id) { + if (!enabled_ || !id || !reset_epoch) { return; } @@ -214,8 +281,8 @@ void FlightRecorder::retire_id( std::unique_lock guard(mutex_); - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { + Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch)); + if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) { update_state(*entry); if (compute_duration) { @@ -237,8 +304,8 @@ void FlightRecorder::retire_id( guard.lock(); // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { + entry = &entries_.at(getIdxFromId(*id, *reset_epoch)); + if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) { LOG(INFO) << "retire_id abandoned for id " << *id << ", event was overwritten while waiting to compute duration."; return; @@ -249,12 +316,23 @@ void FlightRecorder::retire_id( } } +template +void FlightRecorder::retire_id( + std::optional id, + bool compute_duration) { + retire_id(id, 0, compute_duration); +} + template void FlightRecorder::reset_all() { std::lock_guard guard(mutex_); - next_ = 0; - id_ = 0; - entries_.clear(); + if (!entries_.empty()) { + // Soft delete: increment epoch to mark all existing entries as old + // Store where the new epoch starts in the circular buffer + reset_epoch_++; + reset_epoch_start_idx_[reset_epoch_] = next_; + id_ = 0; + } } template diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index a9612ce759733..c1d28b2787cda 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) { // TODO: We need to have numel of tensors for gloo as well. 
pgStatus_->lastCompletedNumelIn = 0; pgStatus_->lastCompletedNumelOut = 0; - FlightRecorder::get()->retire_id(work->trace_id_, false); + FlightRecorder::get()->retire_id( + work->trace_id_, work->trace_reset_epoch_, false); lock.lock(); workInProgress_[workerIndex].reset(); } @@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { pgStatus_->lastEnqueuedNumelOut = 0; // using c10d::FlightRecorder; // TODO: We need to have a way to use c10::Event inside gloo as well. - work->trace_id_ = FlightRecorder::get()->record( + auto traceId = FlightRecorder::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), collectiveCounter_, @@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { work->getTimeout(), pgStatus_, false); + work->trace_id_ = traceId.id; + work->trace_reset_epoch_ = traceId.reset_epoch; workQueue_.push_back(std::move(work)); lock.unlock(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index b2cc6993528bf..1a0b7c41b3857 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend { // unique id used to tell the trace buffer that this // work has completed std::optional trace_id_; + std::optional trace_reset_epoch_; std::shared_ptr context_; const std::chrono::milliseconds timeout_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3416bc336d34a..29ccc115cc94d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), + trace_reset_epoch_(w.trace_reset_epoch_), distDebugLevel_(w.distDebugLevel_) { exception_ = w.exception_; } @@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout( // Print the traceback of the collective at call time std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const { // First step we get the corresponding record entry from FR, based on work's - // trace_id_ + // trace_id_ and trace_reset_epoch_ std::optional entry = - FlightRecorderCUDA::get()->getEntry(trace_id_); + FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_); if (entry.has_value()) { auto entryVal = entry.value(); // Get stack trace from FR entry, in string format @@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_; pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_; - FlightRecorderCUDA::get()->retire_id(work.trace_id_, true); + FlightRecorderCUDA::get()->retire_id( + work.trace_id_, work.trace_reset_epoch_, true); if (pg_->onCompletionHook_) { // Move Work object to completedWorkList_ to be consumed by the hook // thread @@ -3360,7 +3362,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( // these objects to the Work because it has implications for keeping those // tensors alive longer and adds overhead when copying Work objects // between threads - r->trace_id_ = FlightRecorderCUDA::get()->record( + auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -3374,6 +3376,8 @@ c10::intrusive_ptr 
ProcessGroupNCCL::initWork( options_->timeout, pgStatus_, isP2P); + r->trace_id_ = traceId.id; + r->trace_reset_epoch_ = traceId.reset_epoch; } return r; } @@ -3677,7 +3681,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // later in endCoalescing we record a 'coalesced' Work which has // timing/state updates via watchdog thread, but lacks op metadata such as // input/output sizes and profilingTitle per-op in the group. - FlightRecorderCUDA::get()->record( + FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -4169,7 +4173,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // TODO(whc) because we don't pass output {tensor} to initWork, we tell // initWork to not record, and then we manually call record passing all the // information it wants. - work->trace_id_ = FlightRecorderCUDA::get()->record( + auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -4183,6 +4187,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( options_->timeout, pgStatus_, /*isP2P=*/true); + work->trace_id_ = traceId.id; + work->trace_reset_epoch_ = traceId.reset_epoch; } // Only check for NaN for send ops, for recv ops `tensor` can be a random diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ead1a107394d..d8f324dbd8edf 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // unique id used to tell the trace buffer that this // work has completed std::optional trace_id_; + std::optional trace_reset_epoch_; DebugLevel distDebugLevel_; friend class ProcessGroupNCCL; }; From af829c0dade306762213d25506a83da850e30a3c Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 5 Nov 2025 23:15:17 +0000 Subject: [PATCH 094/651] [ROCm] Skip nvfp4 tests on ROCm (#167066) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167066 Approved by: https://github.com/jeffdaily, https://github.com/slayton58 --- test/test_scaled_matmul_cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 9738ac4ac6fbf..fd09afc11cecf 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -1864,6 +1864,8 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}") @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None: + if torch.version.hip and recipe == "nvfp4": + raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping") if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum: raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping") From a344069f2aba6a87c0ab3fab488e83b80691457f Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 23:16:48 +0000 Subject: [PATCH 095/651] Add missing skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) to test/test_transformers.py (#166969) This PR adds missing skips for efficient attention tests. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166969 Approved by: https://github.com/jeffdaily --- test/test_transformers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_transformers.py b/test/test_transformers.py index 56e1365d33c44..cc82cbff2a46f 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1914,6 +1914,7 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): q, k, v, None, 0.0, is_causal=True)) @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): batch_size = 2**16 query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True) @@ -1935,6 +1936,7 @@ def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4) @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) @@ -1948,6 +1950,7 @@ def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): @largeTensorTest("15GB", "cuda") @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_large_seq_len_uniform_attention(self): device = torch.device("cuda") dtype = torch.bfloat16 From d29efba8fa83215465c5dd8914769593b69ed304 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 6 Nov 2025 00:34:40 +0000 Subject: [PATCH 096/651] Move almalinux docker image to DEVTOOLSET 13 (#167018) 1. Update general Almalinux image to Devtoolset 13. 2. 
Fix ROCm images, missing devtoolset-13 This image used by Linux Job in test-infra Pull Request resolved: https://github.com/pytorch/pytorch/pull/167018 Approved by: https://github.com/sudharssun, https://github.com/d4l3k --- .ci/docker/almalinux/Dockerfile | 25 +++++++++++++++++++++---- .ci/docker/almalinux/build.sh | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index ce7803cf9acd2..3bc3fd8badc6d 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8 ENV LANG en_US.UTF-8 ENV LANGUAGE en_US.UTF-8 -ARG DEVTOOLSET_VERSION=11 +ARG DEVTOOLSET_VERSION=13 RUN yum -y update RUN yum -y install epel-release # install glibc-langpack-en make sure en_US.UTF-8 locale is available RUN yum -y install glibc-langpack-en -RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain +RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb # Just add everything as a safe.directory for git since these will be used in multiple places with git RUN git config --global --add safe.directory '*' ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH @@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh # Install CUDA FROM base as cuda ARG CUDA_VERSION=12.6 +ARG DEVTOOLSET_VERSION=13 RUN rm -rf /usr/local/cuda-* ADD ./common/install_cuda.sh install_cuda.sh COPY ./common/install_nccl.sh install_nccl.sh @@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION} # Preserve CUDA_VERSION for the builds ENV CUDA_VERSION=${CUDA_VERSION} # Make things in our path by default -ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH +ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH + FROM cuda as cuda12.6 RUN bash ./install_cuda.sh 12.6 @@ -68,8 +70,22 @@ FROM cuda as cuda13.0 RUN bash ./install_cuda.sh 13.0 ENV DESIRED_CUDA=13.0 -FROM ${ROCM_IMAGE} as rocm +FROM ${ROCM_IMAGE} as rocm_base +ARG DEVTOOLSET_VERSION=13 +ENV LC_ALL en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 +# Install devtoolset on ROCm base image +RUN yum -y update && \ + yum -y install epel-release && \ + yum -y install glibc-langpack-en && \ + yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb +RUN git config --global --add safe.directory '*' +ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH + +FROM rocm_base as rocm ARG PYTORCH_ROCM_ARCH +ARG DEVTOOLSET_VERSION=13 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} ADD ./common/install_mkl.sh install_mkl.sh RUN bash ./install_mkl.sh && rm install_mkl.sh @@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0 # Final step FROM ${BASE_TARGET} as final +ARG DEVTOOLSET_VERSION=13 COPY --from=openssl /opt/openssl /opt/openssl COPY --from=patchelf /patchelf /usr/local/bin/patchelf COPY --from=conda /opt/conda /opt/conda 
diff --git a/.ci/docker/almalinux/build.sh b/.ci/docker/almalinux/build.sh index ad234ce1ffb93..885c4440e0e6f 100755 --- a/.ci/docker/almalinux/build.sh +++ b/.ci/docker/almalinux/build.sh @@ -63,7 +63,7 @@ docker build \ --target final \ --progress plain \ --build-arg "BASE_TARGET=${BASE_TARGET}" \ - --build-arg "DEVTOOLSET_VERSION=11" \ + --build-arg "DEVTOOLSET_VERSION=13" \ ${EXTRA_BUILD_ARGS} \ -t ${tmp_tag} \ $@ \ From 6cd57e6fc275e8d53665aab4d4fbaa71e29eb9ea Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 6 Nov 2025 00:50:42 +0000 Subject: [PATCH 097/651] [cuBLAS] Force tensor-core-no-reduction algo in `cuBLASLt` for `n=1` cases (#166735) Ostensibly useful for batch-invariance purposes Pull Request resolved: https://github.com/pytorch/pytorch/pull/166735 Approved by: https://github.com/ngimel --- aten/src/ATen/cuda/CUDABlas.cpp | 51 ++++++++++++++++++++++++--------- test/test_matmul_cuda.py | 23 +++++++++++++++ 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index aaed431064611..20f235076220f 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D #ifndef USE_ROCM at::Half halpha; at::Half hbeta; + uint32_t mask = -1; #endif void * alpha_ptr = α void * beta_ptr = β @@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS(); if (fp16_reduction != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { - uint32_t mask = + mask = fp16_reduction == at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | @@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS(); if (bf16_reduction != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { - uint32_t mask = + mask = bf16_reduction == at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK ? 
(CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | @@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; - TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( - ltHandle, - computeDesc.descriptor(), - Adesc.descriptor(), - Bdesc.descriptor(), - Cdesc.descriptor(), - Cdesc.descriptor(), - preference.descriptor(), - 1, - &heuristicResult, - &returnedResult)); + // on Blackwell+, we fake a n > 1 matmul when querying heuristics + // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance +#ifndef USE_ROCM + const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10; +#else + const bool lie_to_cublaslt = false; +#endif + if (lie_to_cublaslt) { + CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T); + CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc); + + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + FakeBdesc.descriptor(), + FakeCdesc.descriptor(), + FakeCdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + } else { + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + Bdesc.descriptor(), + Cdesc.descriptor(), + Cdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + } if (returnedResult == 0) { cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; } diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 10611d4f24673..a8e9be4c972a1 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -359,6 +359,29 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist) self.assertEqual(agrad, a.grad) self.assertEqual(bgrad, b.grad) + @onlyCUDA + @skipIfRocm + @dtypes(torch.half, torch.bfloat16) + @unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell") + @serialTest() + def test_cublas_batch_invariance_blackwell(self, device, dtype): + orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False) + with blas_library_context('cublaslt'): + N = 2048 + K = 6144 + M_max = 32 + x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16) + w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t() + full = x @ w + xx = x[:1] + out = xx @ w + self.assertEqual(full[:1], out, atol=0., rtol=0.) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) From 872d1daec2726e8915a4d38427fa0d1c938e5905 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 5 Nov 2025 12:50:02 -0800 Subject: [PATCH 098/651] Avoid DDE in narrow with unbacked start (#166361) Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice. 
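For intuition, here is a small eager-mode illustration (not part of the patch) of how `narrow` with a negative start lines up with the equivalent slice, including the corner case discussed next where start < 0 and start + length == 0:

```python
import torch

x = torch.arange(6)

# A negative start agrees with slicing whenever start + length != 0:
torch.narrow(x, 0, -3, 2)   # tensor([3, 4])
x[-3:-1]                    # tensor([3, 4])

# Corner case: start < 0 and start + length == 0.
# narrow still returns `length` elements, but the literal slice x[-2:0] is empty,
# which is why the change below passes the dimension size instead of 0 as the end.
torch.narrow(x, 0, -2, 2)   # tensor([4, 5])
x[-2:0]                     # tensor([], dtype=torch.int64)
```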
The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, for that case we shall pass dim_size instead of start+length Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361 Approved by: https://github.com/aorenste --- aten/src/ATen/native/TensorShape.cpp | 58 ++++++++++++++++++++++-- c10/core/SymBool.cpp | 14 ++++++ c10/core/SymBool.h | 6 +++ test/export/test_export.py | 31 ++++++++----- test/test_dynamic_shapes.py | 51 +++++++++++++++++++++ test/test_torchfuzz_repros.py | 28 ------------ torch/_inductor/codegen/wrapper.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 19 +++++++- torch/utils/_sympy/printers.py | 36 +++++++++++++++ 9 files changed, 200 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6df7761d822db..daa8a86da253b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,5 +1,6 @@ #include #include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1710,11 +1711,37 @@ Tensor narrow_symint( "], but got ", start, ")") - if (start < 0) { - start = start + cur_size; - } + + auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0)); + auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0)); + + if (cond1 || cond2) { + if (cond1) { + start = start + cur_size; + } + + TORCH_SYM_CHECK( + start.sym_le(cur_size - length), + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + return at::slice_symint(self, dim, start, start + length, 1); + } + + // Unbacked start handling! + + // Bounds check without converting start: + // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + + // length <= 0 + // - If start >= 0: need start + length <= cur_size + auto end = start + length; TORCH_SYM_CHECK( - start.sym_le(cur_size - length), + (start.sym_lt(0).sym_and((end).sym_le(0))) + .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), "start (", start, ") + length (", @@ -1722,7 +1749,28 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - return at::slice_symint(self, dim, start, start + length, 1); + + if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) { + return at::slice_symint(self, dim, start, end, 1); + } else { + // Cannot statically determine the condition due to unbacked. + // This is an interesting situation; when start is negative and + // start + length == 0, slice and narrow do different things. + // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to + // pass curr_size instead of 0. Otherwise, they would do the same thing. + // This says at runtime: if start < 0 and end == 0, then pass curr_size + // instead of 0. + + auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); + auto result = + at::slice_symint(self, dim, start, end + use_different * cur_size, 1); + + // Ensure slice allocated unbacked size is specialized to length. 
+ SymInt new_size = result.sym_size(dim); + TORCH_SYM_CHECK(new_size.sym_eq(length), "") + + return result; + } } // This overload exists purely for XLA, because they wanted to pass in diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index d804eb9d27409..48c407b8b069c 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace c10 { @@ -111,4 +112,17 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } +SymInt SymBool::toSymInt() const { + // If concrete bool, return concrete SymInt + if (auto ma = maybe_as_bool()) { + return SymInt(*ma ? 1 : 0); + } + + // Symbolic case: use sym_ite to convert bool to int (0 or 1) + auto node = toSymNodeImpl(); + auto one_node = node->wrap_int(1); + auto zero_node = node->wrap_int(0); + return SymInt(node->sym_ite(one_node, zero_node)); +} + } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index d5d509e239b1d..a27a28a5bf8a3 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,6 +12,8 @@ namespace c10 { +class SymInt; + class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -80,6 +82,10 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } + // Convert SymBool to SymInt (0 or 1) + // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless + SymInt toSymInt() const; + bool is_heap_allocated() const { return ptr_; } diff --git a/test/export/test_export.py b/test/export/test_export.py index 3908f03b11e55..cdc18b1d4c564 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,26 +6093,19 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) - # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_tensorsplit(torch.nn.Module): @@ -6166,7 +6159,12 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + if y.item() < 0: + return ( + torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() + ) + else: + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6196,7 +6194,18 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + if box.content < 0: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) + else: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) with self.assertRaisesRegex( error_type, diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index fb1d22805d50a..b63e0427c26c3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,6 +4401,57 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) + @fresh_cache() + 
@torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_narrow_unbacked_start(self): + def func(x, start, length): + # unbacked start + u0 = start.item() + return torch.narrow(x, 0, u0, length) + + compiled_func = torch.compile(func, fullgraph=True, backend="inductor") + + x = torch.tensor([1, 2, 3, 4, 5, 6]) + + # Test cases: (start, length) + test_cases = [ + # Negative starts + (-2, 2), # Start from second-to-last element + (-1, 1), # Start from last element + (-3, 3), # Start from third-to-last element + (-6, 2), # Start from beginning (negative) + (-4, 1), # Start from fourth-to-last element + # Positive starts + (0, 2), # Start from beginning + (1, 3), # Start from second element + (2, 2), # Start from third element + (4, 2), # Start near end + # Edge cases + (0, 6), # Full tensor + (0, 1), # Single element from start + (5, 1), # Single element from end + ] + + for start_val, length in test_cases: + with self.subTest(start=start_val, length=length): + start = torch.tensor([start_val]) + + # Test with compiled function + result_compiled = compiled_func(x, start, length) + + # Test with eager function (expected behavior) + result_eager = func(x, start, length) + + # Compare results + self.assertEqual(result_compiled, result_eager) + + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_narrow_unbacked_start_cpp_wrapper(self): + """Test narrow with unbacked start with cpp_wrapper""" + self.test_narrow_unbacked_start() + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 3b864aae4f477..988bcf8de273c 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -257,34 +257,6 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #163971") - def test_fuzzer_issue_163971(self): - torch.manual_seed(0) - - def foo(arg0): - t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda - t1 = torch.softmax( - t0, dim=0 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - t2 = torch.nn.functional.gelu( - t1 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - t3 = torch.softmax( - t2, dim=0 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - output = t3 - return output - - arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True) - - out_eager = foo(arg0) - out_eager.sum().backward() - print("Eager Success! ✅") - compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True) - out_compiled = compiled_foo(arg0) - out_compiled.sum().backward() - print("Compile Success! 
✅") - @pytest.mark.xfail(reason="Issue #164059") def test_fuzzer_issue_164059(self): torch.manual_seed(0) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e629d9c7bdebd..947166cf216cd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,7 +2063,8 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - return f"{pos} if {x} >= 0 else {neg}" + x_cond = self.codegen_sizevar(x) + return f"{pos} if {x_cond} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index aeccdfbe000db..693d25aea6130 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,6 +547,7 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. # @@ -602,7 +603,23 @@ def rebind_unbacked( if u1.node.hint is not None: continue - raw_u1 = u1.node.expr + # unbacked symbols bindings might be replaced to other backed or + # unbacked replacements. + # + # Example: + # u = x.item() + # torch._check(u == 5) + # + # The safest approach is to retrieve raw_u1 from u1.node._expr + # and perform the rebinding on the original unbacked symbol, + # even if it’s no longer directly referenced. + # + # In other words, we should always rebind the original symbol + # before any replacements are applied. + # u0 -> u0 == s1 + raw_u1 = u1.node._expr + + # TODO Do we still need this logic below? # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 526443577b3f8..915d0e5461f1e 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,6 +306,24 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary expressions + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: e1 if c1 else (e2 if c2 else (... else eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self._print(expr_i) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self._print(cond_i) + if result is None: + result = expr_str + else: + result = f"({expr_str} if {cond_str} else {result})" + return result if result else "0" + class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -327,6 +345,24 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary operators + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: c1 ? e1 : (c2 ? e2 : (... 
: eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) + if result is None: + result = expr_str + else: + result = f"{cond_str} ? {expr_str} : {result}" + return f"({result})" if result else "0" + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) From fd5edda1edd3f2c7ad555a626351e359af164fb4 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 6 Nov 2025 01:14:25 +0000 Subject: [PATCH 099/651] Reland "Add model code stack trace to torch.profile (#166677)" (#167110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```python python test/test_fx.py -k profiler ``` Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen. We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace. `map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry. One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove. `aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True. Screenshot 2025-10-31 at 4 40 52 PM Example code gen'd. ``` def forward(self, args_list): args_iter = iter(args_list) arg0_1 = next(args_iter) arg1_1 = next(args_iter) args_list.clear() _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__() repeated_subgraph0 = self.repeated_subgraph0 _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__() invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None _rf_invoke_subgraph.__exit__(None, None, None) _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__() getitem = invoke_subgraph[0]; invoke_subgraph = None _rf_getitem.__exit__(None, None, None) return (getitem,) _rf.__exit__(None, None, None) def forward(self, arg0_1, arg1_1): _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__() _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__() mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None _rf_mul.__exit__(None, None, None) _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__() sin = torch.ops.aten.sin.default(mul); mul = None _rf_sin.__exit__(None, None, None) _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__() add = torch.ops.aten.add.Tensor(sin, 5); sin = None _rf_add.__exit__(None, None, None) return (add,) _rf.__exit__(None, None, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167110 Approved by: https://github.com/pianpwk --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 184 ++++++++++++++++++ 
torch/autograd/profiler_util.py | 40 ++++ torch/fx/graph.py | 23 +++ torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +++++++++++++++- 6 files changed, 429 insertions(+), 5 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977ee..12f6ba2228db8 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index 92d35fd8f49ad..f728187fd85f5 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -72,9 +72,16 @@ IS_WINDOWS, run_tests, skipIfTorchDynamo, + skipIfRocm, ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +208,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. 
+ + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4212,6 +4249,153 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. + """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. 
+ """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. + """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e61..a61aee321fcff 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. 
+ """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines: list[str] = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 899a50f0f4142..d924eac24d3c2 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -817,6 +818,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -826,8 +831,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1779,6 +1798,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. 
@@ -1846,6 +1866,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1858,6 +1879,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1868,6 +1890,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 297f76732584f..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,14 +861,18 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -885,7 +889,6 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" - filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -905,6 +908,13 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c9..47df87ce1678d 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + 
+def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier 
== ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name From 7432676187178fcdb41a0685b078e97e436fc561 Mon Sep 17 00:00:00 2001 From: inventshah <39803835+inventshah@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:55:38 +0000 Subject: [PATCH 100/651] [MPS] Fix crash in BCELoss backwards with reduction="none" and inputs with trailing 1s in shape (#166786) Fixes #166746 by removing squeezes that caused shape mismatches when calling backwards through `BCELoss(reduction='none')`. Based on running these tests, it seems MPSGraph can handle inputs without squeezing. ``` python test/test_mps.py TestMPS -k test_bce python test/test_mps.py TestConsistency -k binary_cross ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166786 Approved by: https://github.com/malfet --- .../src/ATen/native/mps/operations/LossOps.mm | 19 +++++++------------ test/test_mps.py | 8 ++++++++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index c995b8fc237f3..f0bbcdabfa5cd 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -212,17 +212,12 @@ loss.resize_((reduction == Reduction::None || grad_output.defined()) ? 
target.sizes() : IntArrayRef({})); TORCH_CHECK(loss.is_mps()); - Tensor loss_squeezed = loss.squeeze(); - Tensor input_squeezed = input.squeeze(); - Tensor target_squeezed = target.squeeze(); - @autoreleasepool { - std::string key = - op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight}); + std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed); - newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed); + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); MPSGraphTensor* bceLossUnweighted = nil; // if grad_output is defined, then it's a backward pass @@ -252,12 +247,12 @@ newCachedGraph->gradInputTensor = bceLoss; } } else { - newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size()); + newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size()); } }); - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed); - Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); + Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss); NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; diff --git a/test/test_mps.py b/test/test_mps.py index fad09c2f5eb28..cb0db4d96d334 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -4470,6 +4470,14 @@ def test_bce_loss_broadcasts_weights(self): self.assertEqual(out1, out2) + def test_bce_backward_with_no_reduction_and_one_in_shape(self): + # Regression test for https://github.com/pytorch/pytorch/issues/166746 + output = torch.zeros(3, 2, 1, requires_grad=True, device='mps') + target = torch.zeros(3, 2, 1, device='mps') + torch.sum(nn.BCELoss(reduction='none')(output, target)).backward() + expected_grad = torch.zeros(3, 2, 1, device='mps') + self.assertEqual(output.grad, expected_grad) + def test_cross_entropy_loss(self): # Regression test for https://github.com/pytorch/pytorch/issues/116095 loss = nn.CrossEntropyLoss() From 69af74972b30d748266323c8099be5743b4c0b72 Mon Sep 17 00:00:00 2001 From: Samuel Park Date: Thu, 6 Nov 2025 01:59:48 +0000 Subject: [PATCH 101/651] Bugfix to forward autodiff causing different datatype 2 (#165784) Fixes #160513 ## The Problem Summary The issue boiled down to data type promotion logic. The code base has two different functions that deal with dtype promotion logic. If it is purely multi-dimensional tensor operations, the cpp code gets triggered and that follows the numpy dtype promotion logic. That is why in #160513 NDim tensors are fine as NDim dtypes gets precedence. The issue came with python scalars and 0Dim tensors. When it detects "scalars", a python implementation of dtype promotion logic gets triggered (torch/_prims_common/__init__.py:1544). 
Since this is in Python, the implementation cannot tell a wrapped number apart from a true 0-dim tensor, so it simply takes the highest dtype, which is the double-precision wrapped Python number.

## The Fix

The Python dtype-promotion logic needs to know where a scalar came from; once wrapped numbers can be distinguished from 0-dim tensors, the appropriate dtype can be chosen. The first approach was to expose the `is_wrapped_number` method, but that hit a snag: under forward AD the derivative of such a scalar turns out to be a `ZeroTensor`, and `ZeroTensor` internally uses a shortcut that initializes a meta-dtype tensor to skip expensive dispatch operations. That copy does not carry everything over, in particular the `is_wrapped_number_` property. Modifying the copy seemed to go against what it was intended for, and the checks around `is_wrapped_number_` require `dim > 0`, while a scalar `ZeroTensor` is a 0-dim meta tensor, which complicates things. So I instead added a new property, `was_wrapped_number`, exposed it to the Python tensor API, and modified the autograd code generation to set it in the mul, add, and div operations in `VariableType.cpp`. Once that property was set, the dtype-promotion logic could take both wrapped numbers and 0-dim tensors into account, and with that hierarchy in place the buggy behavior was fixed.

I also wrote a new ops test module, `TestForwardADWithScalars`, since this bug needed its own testing pattern. It only covers multiply, add, and divide, because the affected operations boil down to these three.

[edit]: The final version simply uses an `efficientzerotensor` on the meta device and converts it back to a Python number. Since the wrapped number is converted back to a Python number, dtype promotion is preserved. This is achieved by flagging the forward-grad zero tensor of a wrapped number as a wrapped number itself (the tangent of a wrapped number should still be a wrapped number); that particular zero tensor is then sent through `BinaryOps.cpp` as a meta tensor so the resulting arithmetic gets the appropriate dtype.
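
A minimal sketch of the failure mode this addresses, mirroring what the new `TestForwardADWithScalars` test checks (not the exact test code, which lives in `test/test_ops.py`):

```python
import torch
import torch.autograd.forward_ad as fwAD

# 0-dim float32 primal and tangent
primal = torch.ones((), dtype=torch.float32)
tangent = torch.ones((), dtype=torch.float32)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 2.0  # Python scalar on the RHS
    p, t = fwAD.unpack_dual(out)
    # Before this change the tangent could come back as float64 while the
    # primal stayed float32; with the fix both dtypes agree.
    assert p.dtype == t.dtype == torch.float32
```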
@ezyang @OihanJoyot Pull Request resolved: https://github.com/pytorch/pytorch/pull/165784 Approved by: https://github.com/ezyang --- aten/src/ATen/native/BinaryOps.cpp | 24 +++++++++++++--- test/test_ops.py | 38 +++++++++++++++++++++++++ tools/autograd/gen_variable_type.py | 13 +++++++++ torch/csrc/autograd/FunctionsManual.cpp | 6 ++++ torch/csrc/autograd/FunctionsManual.h | 1 + torch/csrc/jit/python/pybind_utils.cpp | 26 ++++++++++++----- 6 files changed, 97 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f5d5edb6439a6..2fa6bcc6dc9ac 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -1009,12 +1009,25 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) { } } +static Tensor send_to_meta(const Tensor& self, const Device& device) { + Tensor out_meta; + if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) { + out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device)); + out_meta.unsafeGetTensorImpl()->set_wrapped_number(true); + } else { + out_meta = self.to(device); + } + return out_meta; +} + Tensor mul_zerotensor(const Tensor& self, const Tensor& other) { auto out_device = correct_out_device(self, other); // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_)); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta); return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device)); } @@ -1023,7 +1036,9 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) { // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_)); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta); if (self._is_zerotensor()) { if (other._is_zerotensor()) { @@ -1052,8 +1067,9 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::add_Tensor::redispatch( - meta_dks, self.to(device_), other.to(device_), alpha); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha); auto get_out_like = [&] (const Tensor& tensor) { diff --git a/test/test_ops.py b/test/test_ops.py index 165b284b76d5c..5f44a3ba0841b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2992,12 +2992,50 @@ def test_strided_layout(self, device, dtype, op): self.assertEqual(strided_result.layout, torch.strided) +class TestForwardADWithScalars(TestCase): + @ops( + [op for op in op_db if op.name in ["mul", "add", "div"]], + 
allowed_dtypes=(torch.float32,), + ) + def test_0d_tensor_with_python_scalar(self, device, dtype, op): + """Test that forward AD preserves dtype when combining 0D tensors with Python scalars.""" + if torch.float not in op.supported_backward_dtypes(device): + raise unittest.SkipTest("Does not support autograd") + + # skip if operator doesnt support forward AD + if not op.supports_forward_ad: + raise unittest.SkipTest("Does not support forward_ad") + + # create 0D tensors + primal0d = torch.ones((), device=device, dtype=dtype) + tangent0d = torch.ones((), device=device, dtype=dtype) + + with torch.autograd.forward_ad.dual_level(): + dual0d = torch.autograd.forward_ad.make_dual(primal0d, tangent0d) + + # Test with scalar on RHS + if op.supports_rhs_python_scalar: + result = op(dual0d, 2.0) + p, t = torch.autograd.forward_ad.unpack_dual(result) + self.assertEqual( + p.dtype, t.dtype, f"{op.name} and scalar on RHS - dtype mismatch" + ) + # Test with scalar on LHS + if op.supports_one_python_scalar: + result = op(2.0, dual0d) + p, t = torch.autograd.forward_ad.unpack_dual(result) + self.assertEqual( + p.dtype, t.dtype, f"{op.name} and scalar on LHS - dtype mismatch" + ) + + instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True) instantiate_device_type_tests(TestCompositeCompliance, globals()) instantiate_device_type_tests(TestMathBits, globals()) instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu") instantiate_device_type_tests(TestFakeTensor, globals()) instantiate_device_type_tests(TestTags, globals()) +instantiate_device_type_tests(TestForwardADWithScalars, globals()) if __name__ == "__main__": TestCase._default_dtype_check_enabled = True diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 13ca3e1389ac1..4796153f24f05 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -763,6 +763,12 @@ """ ) +FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE = CodeTemplate( + """\ +update_wrapped_number(${inp_name}_tensor, ${inp_name}_t); +""" +) + FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate( """\ auto ${inp_name}_p = toNonOptPrimal(${inp}); @@ -1911,6 +1917,13 @@ def emit_fw_derivatives() -> list[str]: zeros_fn=zeros_fn, ) ) + if zeros_fn == "_efficientzerotensor_symint": + unpacked_arguments += ( + FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE.substitute( + inp_name=inp.name + ) + ) + if inp.name in (derivative.required_inputs_primal or []): unpacked_arguments += ( FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 42d701298b0d1..b3cb07ac1cf9f 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -79,6 +79,12 @@ Tensor toNonOptPrimal(const std::optional& t) { return Tensor(); } +void update_wrapped_number(Tensor& input, Tensor& output) { + if (input.unsafeGetTensorImpl()->is_wrapped_number()) { + output.unsafeGetTensorImpl()->set_wrapped_number(true); + } +} + void copy_range(variable_list& out, IndexRange range, const Tensor& t) { TORCH_CHECK(range.second <= out.size()); TORCH_CHECK( diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 4dc0425d426ec..ee0f919c44012 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -43,6 +43,7 @@ inline std::optional wrap_opt_if(const Tensor& t, const bool cond) { TORCH_API Tensor apply_loss_reduction(const Tensor& unreduced, 
int64_t reduction); TORCH_API bool any_variable_defined(const variable_list& variables); +TORCH_API void update_wrapped_number(Tensor& input, Tensor& output); TORCH_API void copy_range( variable_list& out, IndexRange range, diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index d60a6a0990082..9f7c2756d0d73 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -587,7 +587,9 @@ py::object toPyObject(IValue ivalue) { } else if (ivalue.isTensor()) { auto tensor = std::move(ivalue).toTensor(); if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { - TORCH_INTERNAL_ASSERT(tensor.device().is_cpu()); + TORCH_INTERNAL_ASSERT( + tensor.device().is_cpu() || + (tensor._is_zerotensor() && tensor.dim() == 0)); auto py_tensor = py::cast(tensor); if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) { return py_tensor.attr("_wrapped_number"); @@ -595,17 +597,27 @@ py::object toPyObject(IValue ivalue) { auto scalar_type = tensor.scalar_type(); switch (scalar_type) { case at::ScalarType::Bool: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(false) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::Long: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(int64_t(0)) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::UInt64: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(uint64_t(0)) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::Double: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(0.0) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::ComplexDouble: // TODO: https://github.com/pytorch/pytorch/issues/77134 - return py::cast(static_cast>( - *tensor.const_data_ptr>())); + return (tensor._is_zerotensor()) + ? 
py::cast(std::complex(0.0, 0.0)) + : py::cast(static_cast>( + *tensor.const_data_ptr>())); default: TORCH_CHECK( false, From 3a2d75a0869b3ef2344ab0501c19787924442c3e Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Thu, 6 Nov 2025 02:01:57 +0000 Subject: [PATCH 102/651] Change template 'Release highlight for proposed Feature'->'New Feature for Release' (#167145) Makes it simpler and more clear Pull Request resolved: https://github.com/pytorch/pytorch/pull/167145 Approved by: https://github.com/huydhn --- .github/ISSUE_TEMPLATE/release-feature-request.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/release-feature-request.yml b/.github/ISSUE_TEMPLATE/release-feature-request.yml index 80f10807ae56b..090a41d1942f6 100644 --- a/.github/ISSUE_TEMPLATE/release-feature-request.yml +++ b/.github/ISSUE_TEMPLATE/release-feature-request.yml @@ -1,11 +1,11 @@ -name: 🚀 Release highlight for proposed Feature +name: 🚀 New Feature for Release description: Submit a Release highlight for proposed Feature labels: ["release-feature-request"] body: - type: textarea attributes: - label: Release highlight for proposed Feature + label: New Feature for Release description: > Example: “A torch.special module, analogous to SciPy's special module.” - type: input From 943227f57bcd638ab288331442748769f907d8c1 Mon Sep 17 00:00:00 2001 From: "Junjie Wang (PyTorch)" Date: Thu, 6 Nov 2025 02:08:01 +0000 Subject: [PATCH 103/651] [c10d] Fix split_group bug by having the parent pg option deep copied (#167125) Summary: Inside group_split api, we share the reference of PG option with parent PG if a PG option is not explicitly specified. This is bad because if we split parent pg multiple times, we will run into errors. Test Plan: UT + internal test. Differential Revision: D86225394 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167125 Approved by: https://github.com/Skylion007 --- test/distributed/test_c10d_nccl.py | 8 ++++++++ torch/distributed/distributed_c10d.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index d764dfbbebbb1..ef7ed5282816f 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -6343,6 +6343,14 @@ def test_comm_recursive_split_group(self): if self.rank == 6 or self.rank == 7: dist.broadcast(tensor2, 6, group=ng2) self.assertEqual(tensor2, torch.full((1,), 6)) + + # Test the case when the split changes the pg option of split group + # while the parent pg option is not changed. + new_pg = c10d.new_group([0, 1, 2, 3, 4, 5, 6, 7], device_id=device) + backend_new_pg = new_pg._get_backend(torch.device(device)) + self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8) + c10d.split_group(new_pg, [[0, 2, 4, 6], [1, 3, 5, 7]]) + self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8) # a barrier and a cuda sync before destroying all pgs. 
dist.barrier(pg) torch.cuda.synchronize() diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9e4ec1483e960..415cbacc177a8 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3,6 +3,7 @@ import collections.abc import contextlib +import copy import ctypes import hashlib import io @@ -5212,7 +5213,9 @@ def split_group( if pg_options is None: # default pg_options same as the parent process group - pg_options = parent_backend.options + # A deep copy is needed because if the option will be modified inside split + # and if we split parent pg multiple times, we will run into device out of bound error. + pg_options = copy.deepcopy(parent_backend.options) # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, # which may just pass their timeout value (or None) From e1a1aeaf5b951e4eb9ce49756311e8f59cf29eb8 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 02:25:10 +0000 Subject: [PATCH 104/651] [1/N] Use `key in dict` for existence checks (#167035) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167035 Approved by: https://github.com/janeyx99 --- torch/_dynamo/backends/registry.py | 2 +- torch/_dynamo/output_graph.py | 2 +- torch/_export/converter.py | 2 +- torch/_guards.py | 2 +- torch/_inductor/augmented_graph_helper.py | 2 +- torch/_inductor/bounds.py | 2 +- torch/_inductor/codecache.py | 2 +- torch/_inductor/codegen/cpp_wrapper_gpu.py | 4 ++-- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/cpu_vec_isa.py | 2 +- torch/_inductor/cudagraph_utils.py | 2 +- torch/_inductor/graph.py | 2 +- torch/_inductor/ir.py | 2 +- torch/_inductor/memory.py | 2 +- torch/_inductor/select_algorithm.py | 4 +--- torch/_inductor/tiling_utils.py | 2 +- torch/_library/infer_schema.py | 2 +- torch/_numpy/_dtypes.py | 2 +- torch/_numpy/_ndarray.py | 6 +++--- .../backend_config/_common_operator_config_utils.py | 2 +- torch/ao/quantization/pt2e/prepare.py | 4 ++-- torch/autograd/profiler_legacy.py | 4 ++-- torch/cuda/_device_limits.py | 4 ++-- torch/distributed/_serialization.py | 2 +- torch/distributed/checkpoint/_consolidate_hf_safetensors.py | 2 +- torch/distributed/checkpoint/quantized_hf_storage.py | 2 +- torch/distributed/checkpoint/state_dict.py | 2 +- torch/distributed/elastic/multiprocessing/tail_log.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 2 +- torch/distributed/pipelining/schedules.py | 4 ++-- torch/distributed/tensor/_ops/_pointwise_ops.py | 2 +- torch/export/dynamic_shapes.py | 2 +- torch/fx/experimental/sym_node.py | 2 +- torch/fx/node.py | 2 +- torch/nested/_internal/ops.py | 2 +- torch/optim/swa_utils.py | 4 ++-- torch/serialization.py | 4 ++-- torch/testing/_internal/distributed/distributed_test.py | 6 +++--- torch/testing/_internal/distributed/rpc/rpc_test.py | 2 +- 39 files changed, 50 insertions(+), 52 deletions(-) diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 706ec1768cd35..1469ca478a386 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -146,7 +146,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: backends = [ name - for name in _BACKENDS.keys() + for name in _BACKENDS if name not in _COMPILER_FNS or not 
exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined] ] diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 77f5d6cb05a01..50a2667c12a25 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2587,7 +2587,7 @@ def update_used_symbols( real_script_obj ): flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] - for attr in flat_dict.keys(): + for attr in flat_dict: fake_attr_val = getattr( fake_script_obj.wrapped_obj, attr ) diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 89b6e3297933f..58de4fd20c953 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -443,7 +443,7 @@ def __init__( self.blocks_to_lifted_attrs = blocks_to_lifted_attrs # Populate methods for the standard operators. - for k in kind_to_standard_operators.keys(): + for k in kind_to_standard_operators: handler_func_name = ir_name_to_func_name(k) # Create an indirect function call: # convert__ --> lambda node: _convert_standard_operator(node) diff --git a/torch/_guards.py b/torch/_guards.py index b321c5f968b16..32b796d71eea7 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -904,7 +904,7 @@ def patch(**kwargs: Any) -> Generator[None, None, None]: prior = {} ctx = TracingContext.get() - for key in kwargs.keys(): + for key in kwargs: # KeyError on invalid entry prior[key] = getattr(ctx, key) for key, val in kwargs.items(): diff --git a/torch/_inductor/augmented_graph_helper.py b/torch/_inductor/augmented_graph_helper.py index 81dca605940e5..5a70a34f7b64b 100644 --- a/torch/_inductor/augmented_graph_helper.py +++ b/torch/_inductor/augmented_graph_helper.py @@ -164,7 +164,7 @@ def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> No self.extra_uses[new_node].add(updated_use) # Clean up erased nodes - for old_node in erased_merge_sets.keys(): + for old_node in erased_merge_sets: self.extra_deps[old_node].clear() self.extra_uses[old_node].clear() del self.merge_sets[old_node] diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index a227239356a61..bc8dba5119252 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -86,7 +86,7 @@ def swap_submodules( self, submodules: dict[str, Callable[..., Any]] ) -> dict[str, Callable[..., ValueRanges[Expr]]]: result: dict[str, Callable[..., ValueRanges[Expr]]] = {} - for key in submodules.keys(): + for key in submodules: if key == "get_index": result[key] = self.get_index elif "masked_subblock" in key: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index cf17bf2e9478b..9583494299265 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1681,7 +1681,7 @@ def set( if config.aot_inductor.emit_multi_arch_kernel: bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"} - assert bin_type in bin_type_to_ext.keys(), ( + assert bin_type in bin_type_to_ext, ( "multi_arch_kernel_binary only supported in CUDA/XPU" ) base_path, _ = os.path.splitext(bin_path) diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 02129fff24160..fad4ce84f2971 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -337,7 +337,7 @@ def process_args_for_input_shape(arg, arg_type, arg_signature=None): elif ( isinstance(arg_type, type(SymbolicCallArg)) and arg_signature is not None - and arg_signature in signature2dtype.keys() + and arg_signature in 
signature2dtype ) or arg_type in (sympy.Integer, int, sympy.Float, float): write_dummy_scalar_ivalue(arg_name) elif arg_signature and arg_signature.startswith("tensordesc<"): @@ -719,7 +719,7 @@ def process_args(arg, arg_type, arg_signature=None): elif ( isinstance(arg_type, type(SymbolicCallArg)) and arg_signature is not None - and arg_signature in signature2dtype.keys() + and arg_signature in signature2dtype ): code.writeline( f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 1f531a5d99ef5..41b12d05cd32e 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -699,7 +699,7 @@ def get_block_args(self) -> list[ConstexprArg]: block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix self.block_args = list(block_names.keys()) - return [ConstexprArg(x) for x in block_names.keys()] + return [ConstexprArg(x) for x in block_names] def add_numel_to_args( self, argdefs: list[ArgName], signature: list[Any] diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 515f628c9938c..1c4a394d1eb28 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -430,7 +430,7 @@ def get_isa_from_cpu_capability( "avx2": "avx2", "avx512": "avx512", } - if capability in capability_to_isa_str.keys(): + if capability in capability_to_isa_str: # pyrefly: ignore [index-error] isa_str = capability_to_isa_str[capability] if isa_str == "INVALID_VEC_ISA": diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 668becdded469..50d986d48e6c2 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -192,7 +192,7 @@ def check_multiple_devices_or_any_cpu_nodes( ): return None - keys_repr = (repr(key) for key in device_node_mapping.keys()) + keys_repr = (repr(key) for key in device_node_mapping) return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 2e89ea5ca461b..28e7f88d33986 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1590,7 +1590,7 @@ def maybe_propagate( schema_kwargs = {arg.name: arg for arg in schema.arguments} - for key in old_kwargs.keys(): + for key in old_kwargs: old_arg = old_kwargs[key] new_arg = new_kwargs[key] schema_arg = schema_kwargs[key] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b1a3071cb7ba4..53c12d0726044 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1534,7 +1534,7 @@ def py_cnst(val: object) -> Union[bool, float, int]: # "all" is desugared to `!any(!val)` } - assert reduction_type in rtypes_to_inits.keys(), ( + assert reduction_type in rtypes_to_inits, ( f"{reduction_type} not supported for zero-dimension tensors!" 
) diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 6f58b683ac22b..ed223de71c079 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -229,7 +229,7 @@ def assign_memory_planning_info_for_scheduler_buffers( # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) - for buf_name in name_to_buf.keys(): + for buf_name in name_to_buf: name_to_buf[buf_name].mpi_buffer = MemoryPlanningInfoForBuffer( size_alloc=sched_buf_to_size[buf_name][0], size_free=sched_buf_to_size[buf_name][1], diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index dc4be650eccb4..41021b0fc8ed1 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3719,9 +3719,7 @@ def get_choice_info(choice): M, K = input_nodes[-2].get_size()[:2] N = input_nodes[-1].get_size()[-1] - out_dict = { - str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] - } + out_dict = {str((M, K, N)): [get_choice_info(choice) for choice in timings]} append_to_log(mm_filename, out_dict) diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 0c9305dc721dd..5b394b9ea9914 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -165,7 +165,7 @@ def find_coalesced_var( variables[v] = get_hint(v) zero_index = sympy_subs(index, variables) - for v in var_ranges.keys(): + for v in var_ranges: variables[v] = 1 try: new_val = sympy_subs(index, variables) diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 62bd70f65a510..cb3cfd1d6029f 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -291,7 +291,7 @@ def parse_return(annotation, error_fn): origin = typing.get_origin(annotation) if origin is not tuple: - if annotation not in SUPPORTED_RETURN_TYPES.keys(): + if annotation not in SUPPORTED_RETURN_TYPES: error_fn( f"Return has unsupported type {annotation}. " f"The valid types are: {SUPPORTED_RETURN_TYPES}." 
diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index a429d28f30cc3..134f7617b758a 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -248,7 +248,7 @@ def sctype_from_string(s): """Normalize a string value: a type 'name' or a typecode or a width alias.""" if s in _names: return _names[s] - if s in _name_aliases.keys(): + if s in _name_aliases: return _name_aliases[s] if s in _typecodes: return _typecodes[s] diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index f192a39dd0296..e3f3836754017 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -49,7 +49,7 @@ class Flags: def __init__(self, flag_to_value: dict): - assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check + assert all(k in FLAGS for k in flag_to_value) # sanity check self._flag_to_value = flag_to_value def __getattr__(self, attr: str): @@ -59,7 +59,7 @@ def __getattr__(self, attr: str): raise AttributeError(f"No flag attribute '{attr}'") def __getitem__(self, key): - if key in SHORTHAND_TO_FLAGS.keys(): + if key in SHORTHAND_TO_FLAGS: key = SHORTHAND_TO_FLAGS[key] if key in FLAGS: try: @@ -76,7 +76,7 @@ def __setattr__(self, attr, value): super().__setattr__(attr, value) def __setitem__(self, key, value): - if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys(): + if key in FLAGS or key in SHORTHAND_TO_FLAGS: raise NotImplementedError("Modifying flags is not implemented") else: raise KeyError(f"No flag key '{key}'") diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index ab44cfa09197d..25672e7e6ced9 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -678,7 +678,7 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf torch.nn.BatchNorm2d: nni.BNReLU2d, torch.nn.BatchNorm3d: nni.BNReLU3d, } - for bn in bn_to_fused_bn.keys(): + for bn in bn_to_fused_bn: fused_bn = bn_to_fused_bn[bn] # bn module + relu module fusion config bn_configs.append( diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 6eac69a96ba42..9f7767101aba6 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -217,7 +217,7 @@ def _get_edge_or_node_to_group_id( # means the observer of key should be shared with observer with value, by default it will # be shared with itself shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { - k: k for k in edge_or_node_to_qspec.keys() + k: k for k in edge_or_node_to_qspec } for edge_or_node, qspec in edge_or_node_to_qspec.items(): if isinstance(edge_or_node, torch.fx.Node): @@ -282,7 +282,7 @@ def _get_edge_or_node_to_group_id( # now that we get the sharing relations between all edges and nodes, we can assign group ids cur_group_id = 0 edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} - for edge_or_node in shared_with_map.keys(): + for edge_or_node in shared_with_map: root = _find_root_edge_or_node(edge_or_node, shared_with_map) if root not in edge_or_node_to_group_id: edge_or_node_to_group_id[root] = cur_group_id diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index 9f60295655ddb..5dd26c0881370 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -296,9 +296,9 @@ def _get_record_key(record): f"Expected CPU and CUDA memory allocation handles to match, " f"but 
got {num_open_handles_cpu} CPU and {num_open_handles_cuda} CUDA" ) - for handle in cpu_memory_allocs.keys(): + for handle in cpu_memory_allocs: cpu_memory_allocs[handle] += record.cpu_memory_usage() - for handle in cuda_memory_allocs.keys(): + for handle in cuda_memory_allocs: cuda_memory_allocs[handle] += record.cuda_memory_usage() if num_open_handles_cpu == 0: # output event as a top-level memory event diff --git a/torch/cuda/_device_limits.py b/torch/cuda/_device_limits.py index 808d748c8f6eb..60aeedc8053ab 100644 --- a/torch/cuda/_device_limits.py +++ b/torch/cuda/_device_limits.py @@ -53,7 +53,7 @@ def get_fma_per_cycle_per_sm_cuda_cores(self, data_type: dtype) -> int: else: dict_key = "unknown" - if dict_key not in hardcoded_device_values.keys(): + if dict_key not in hardcoded_device_values: raise RuntimeError( f"No data for sm_{self.compute_capability} and {data_type}." ) @@ -96,7 +96,7 @@ def get_fma_per_cycle_per_sm_tensor_cores(self, data_type: dtype) -> int: else: dict_key = "unknown" - if dict_key not in hardcoded_device_values.keys(): + if dict_key not in hardcoded_device_values: raise RuntimeError( f"No data for sm_{self.compute_capability} and {data_type}." ) diff --git a/torch/distributed/_serialization.py b/torch/distributed/_serialization.py index c13ba46ba5757..8f7043453be76 100644 --- a/torch/distributed/_serialization.py +++ b/torch/distributed/_serialization.py @@ -145,7 +145,7 @@ def _streaming_load( if pickle_module is None: pickle_module = pickle - if "encoding" not in pickle_load_args.keys(): + if "encoding" not in pickle_load_args: pickle_load_args["encoding"] = "utf-8" zip_file = _PseudoZipFile() diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 9db89d038658a..9d70ab7c7400d 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -257,7 +257,7 @@ def _process_output_file( ) # Process each input safetensors file - for safetensors_file in input_files_data.keys(): + for safetensors_file in input_files_data: file_metadata = input_files_data[safetensors_file].metadata input_metadata_size = input_files_data[safetensors_file].metadata_size diff --git a/torch/distributed/checkpoint/quantized_hf_storage.py b/torch/distributed/checkpoint/quantized_hf_storage.py index 2cb189d515a8a..36f4ddf937fee 100644 --- a/torch/distributed/checkpoint/quantized_hf_storage.py +++ b/torch/distributed/checkpoint/quantized_hf_storage.py @@ -82,7 +82,7 @@ def _build_weight_scale_mapping(self, weight_map: dict[str, str]): # Store the complete weight map for file location lookups self._weight_map = weight_map - for tensor_name in weight_map.keys(): + for tensor_name in weight_map: if tensor_name.endswith(".weight_scale_inv"): weight_name = tensor_name.replace(".weight_scale_inv", ".weight") if weight_name in weight_map: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 9202851537fba..54a29c0bb3588 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -443,7 +443,7 @@ def _verify_state_dict( f"or load but optim state_dict is empty. {optim_state_dict}" ) - for key in model_state_dict.keys(): + for key in model_state_dict: if _FLAT_PARAM in key: raise RuntimeError( f"{key} contains {_FLAT_PARAM}. 
This can happen if the model " diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 034740810dcdd..a34ec1408be57 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -130,7 +130,7 @@ def __init__( self._log_line_prefixes = log_line_prefixes self._log_line_filter = log_line_filter self._finished_events: dict[int, Event] = { - local_rank: Event() for local_rank in log_files.keys() + local_rank: Event() for local_rank in log_files } self._futs: list[Future] = [] self._interval_sec = interval_sec diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 60e3f37a99919..96657eeea4106 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1549,7 +1549,7 @@ def _allgather_orig_param_states( fsdp_state._device_handle.memory_summary(), ) - output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states} dtype, state_buffers = _convert_all_state_info( fsdp_param_info, gathered_state_info, input_states, output_states diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 39da483fe002b..e60ae3b93ba63 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -1637,7 +1637,7 @@ def _step_microbatches( # the stages in the pipeline_order all_prev_ranks: set[int] = set() all_next_ranks: set[int] = set() - for stage_index in stage_index_to_stage.keys(): + for stage_index in stage_index_to_stage: # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) if stage_index > 0: all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) @@ -3176,7 +3176,7 @@ def get_schedule_class(schedule_name: str): "ZBVZeroBubble": ScheduleZBVZeroBubble, "DualPipeV": ScheduleDualPipeV, } - lowercase_keys = {k.lower(): k for k in schedule_map.keys()} + lowercase_keys = {k.lower(): k for k in schedule_map} lowercase_schedule_name = schedule_name.lower() if lowercase_schedule_name not in lowercase_keys: raise ValueError( diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 084fa62706e0d..53b759e993c0d 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -618,7 +618,7 @@ def common_pointwise_strategy( return pointwise_strategy -for op in linear_pointwise_ops.keys(): +for op in linear_pointwise_ops: register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( linear_pointwise_strategy ) diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 1e1f1f409857b..a9a018468cef1 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1333,7 +1333,7 @@ def refine_dynamic_shapes_from_suggested_fixes( roots.add(c.root.__name__) # type: ignore[attr-defined] # check keys are existing dims or new roots - for k in shape_fixes.keys(): + for k in shape_fixes: assert k in name_to_dim or k in roots # cache so we don't produce multiple derived dim objects diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index d07d235e51321..e01cab57775c5 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1871,7 +1871,7 @@ 
def round_magic_impl(self, ndigits=None): setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl) -for method in magic_methods.keys(): # type: ignore[assignment] +for method in magic_methods: # type: ignore[assignment] if method in only_bool_magic_methods: _make_user_magic(method, SymBool) continue diff --git a/torch/fx/node.py b/torch/fx/node.py index 1d72a75a6ccf4..272676a4e3a94 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -496,7 +496,7 @@ def insert_arg(self, idx: int, arg: Argument) -> None: _new_input_nodes: dict[Node, None] = {} _fx_map_arg(arg, _new_input_nodes.setdefault) - for new_use in _new_input_nodes.keys(): + for new_use in _new_input_nodes: if new_use not in self._input_nodes: self._input_nodes.setdefault(new_use) new_use.users.setdefault(self) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index a84a5b681d638..69c324ab726ec 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -143,7 +143,7 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None: name, arg_type = named_arg_type.split(": ") is_optional = arg_type.endswith("?") normalized_arg_type = arg_type[:-1] if is_optional else arg_type - if normalized_arg_type not in arg_type_check_fns.keys(): + if normalized_arg_type not in arg_type_check_fns: raise AssertionError(f"Unknown arg type: {normalized_arg_type}") if i >= len(args): diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 1ab915d27cd66..254560d8751ce 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -367,7 +367,7 @@ def update_bn( was_training = model.training model.train() - for module in momenta.keys(): + for module in momenta: module.momentum = None for input in loader: @@ -378,7 +378,7 @@ def update_bn( model(input) - for bn_module in momenta.keys(): + for bn_module in momenta: bn_module.momentum = momenta[bn_module] model.train(was_training) diff --git a/torch/serialization.py b/torch/serialization.py index ce5a74d92384e..ffa77cec732ed 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1250,7 +1250,7 @@ def persistent_id(self, obj): zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder)) # Write each tensor to a file named tensor/the_tensor_key in the zip archive - for key in serialized_storages.keys(): + for key in serialized_storages: name = f"data/{key}" storage = serialized_storages[key] num_bytes = storage.nbytes() @@ -1494,7 +1494,7 @@ def _get_wo_message(message: str) -> str: _check_dill_version(pickle_module) - if "encoding" not in pickle_load_args.keys(): + if "encoding" not in pickle_load_args: pickle_load_args["encoding"] = "utf-8" with _open_file_like(f, "rb") as opened_file: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index a14f670d788be..8cb9c929d8545 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -7050,8 +7050,8 @@ def _validate_execution_trace_nccl(self, et_file: str) -> None: self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0) self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0) - self.assertTrue("in_split_size" in attrs.keys()) - self.assertTrue("out_split_size" in attrs.keys()) + self.assertTrue("in_split_size" in attrs) + self.assertTrue("out_split_size" in attrs) self.assertEqual(attrs.get("global_rank_start", -1), 0) self.assertEqual(attrs.get("global_rank_stride", -1), 1) @@ -9306,7 +9306,7 @@ def 
get_loss(model_output): "tuple": tuple, "dict": dict, } - for output_type in type_mapping.keys(): + for output_type in type_mapping: for _ in range(6): out = model(inp, output_type=output_type) loss = get_loss(out) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index b7c0dd17a1164..21464e514742c 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3282,7 +3282,7 @@ def test_debug_info(self): expected.update(autograd_info) # NB: Key ordering is only preserved in python 3.6+. So here, we # manually check keys are equal. - for key in expected.keys(): + for key in expected: self.assertIn(key, info.keys()) for key in info.keys(): From c08ce30d18303ff4e43d53ccb0c0c6e6b8bd1dae Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Wed, 5 Nov 2025 22:37:35 +0000 Subject: [PATCH 105/651] [ci][cpu] Update compiler to GCC-13 in jammy-aarch64 (#166849) This is needed because manylinux uses GCC-13 since #152825 As a result of the current compiler version mismatches, we've seen tests passing jammy-aarch64 pre-commit CI, but failing for wheels built in manylinux Related to: #166736 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166849 Approved by: https://github.com/robert-hardwick, https://github.com/malfet, https://github.com/Skylion007, https://github.com/atalman --- .ci/docker/build.sh | 8 ++++---- .ci/docker/common/install_gcc.sh | 4 ++-- .github/workflows/docker-builds.yml | 4 ++-- .github/workflows/inductor-perf-test-nightly-aarch64.yml | 2 +- .github/workflows/linux-aarch64.yml | 2 +- .github/workflows/operator_benchmark.yml | 2 +- test/cpp/aoti_abi_check/CMakeLists.txt | 4 ++++ test/cpp/api/CMakeLists.txt | 7 +++++++ 8 files changed, 22 insertions(+), 11 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 5257decb9d4d5..f0b9a788758ca 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -261,9 +261,9 @@ case "$tag" in PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 ;; - pytorch-linux-jammy-aarch64-py3.10-gcc11) + pytorch-linux-jammy-aarch64-py3.10-gcc13) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + GCC_VERSION=13 ACL=yes VISION=yes OPENBLAS=yes @@ -281,9 +281,9 @@ case "$tag" in # from pytorch/llvm:9.0.1 is x86 specific SKIP_LLVM_SRC_BUILD_INSTALL=yes ;; - pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks) + pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + GCC_VERSION=13 ACL=yes VISION=yes OPENBLAS=yes diff --git a/.ci/docker/common/install_gcc.sh b/.ci/docker/common/install_gcc.sh index 3b96bf6e0ed2f..df1c059bc3869 100644 --- a/.ci/docker/common/install_gcc.sh +++ b/.ci/docker/common/install_gcc.sh @@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then # Need the official toolchain repo to get alternate packages add-apt-repository ppa:ubuntu-toolchain-r/test apt-get update - apt-get install -y g++-$GCC_VERSION + apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 - + update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50 # Cleanup package manager apt-get autoclean && apt-get clean diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 
4d0940094f541..941a045649f3a 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -77,11 +77,11 @@ jobs: pytorch-linux-noble-riscv64-py3.12-gcc14 ] include: - - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11 + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13 runner: linux.arm64.m7g.4xlarge - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21 runner: linux.arm64.m7g.4xlarge - - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 # Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358 diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index e16c8be79130d..46a1966570c63 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -72,7 +72,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: linux.arm64.m7g.4xlarge build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks test-matrix: | { include: [ { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" }, diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index 2b840a39a5c21..e6690b1043006 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -33,7 +33,7 @@ jobs: with: runner_prefix: ${{ needs.get-label-type.outputs.label-type }} build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 runner: linux.arm64.m7g.4xlarge test-matrix: | { include: [ diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 40fb3b8d0c85f..758147f5fe18e 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -60,7 +60,7 @@ jobs: with: build-environment: linux-jammy-aarch64-py3.10 runner: linux.arm64.m7g.4xlarge - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 test-matrix: | { include: [ { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 1695e65cb4a1b..f1747acc31fc8 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -45,6 +45,10 @@ endif() # Disable unused-variable warnings for variables that are only used to test compilation target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable) target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable) +# Add -Wno-dangling-pointer for GCC 13 +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + target_compile_options_if_supported(test_aoti_abi_check -Wno-dangling-pointer) +endif() foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS}) foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) diff --git 
a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 8261aae3b5607..a92832a4d04c9 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -70,6 +70,13 @@ if(NOT MSVC) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12) target_compile_options_if_supported(test_api "-Wno-error=nonnull") endif() + + # Add -Wno-error=array-bounds for GCC 13+ + # See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=113239 + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + target_compile_options_if_supported(test_api "-Wno-error=array-bounds") + endif() + endif() if(INSTALL_TEST) From 85fab6c9b00bb6ba3a0d5e72596dfa4bf39fc998 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 6 Nov 2025 03:24:59 +0000 Subject: [PATCH 106/651] Fix duplicate benchmarking entries for addmm (#166652) There have been duplicate entries for addmm in dashboard. This PR fixes the duplicate entries issues Pull Request resolved: https://github.com/pytorch/pytorch/pull/166652 Approved by: https://github.com/yangw-dev --- benchmarks/operator_benchmark/pt/addmm_test.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/benchmarks/operator_benchmark/pt/addmm_test.py b/benchmarks/operator_benchmark/pt/addmm_test.py index a98628944b3e8..3e94a9cd7f3dc 100644 --- a/benchmarks/operator_benchmark/pt/addmm_test.py +++ b/benchmarks/operator_benchmark/pt/addmm_test.py @@ -53,10 +53,8 @@ def forward(self, input_one, mat1, mat2): return torch.addmm(input_one, mat1, mat2) -op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark) -op_bench.generate_pt_gradient_test( - addmm_long_configs + addmm_long_configs, AddmmBenchmark -) +op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark) +op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark) """Mircobenchmark for addbmm operator.""" @@ -107,9 +105,7 @@ def forward(self, input_one, batch1, batch2): ) op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark) -op_bench.generate_pt_gradient_test( - addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark -) +op_bench.generate_pt_gradient_test(addbmm_long_configs, AddbmmBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() From d31599f40bc580c170553ca6766163b41c427ed9 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 03:36:56 +0000 Subject: [PATCH 107/651] [7/N] Fix unused loop variables in tests (#167043) This PR continues to fix or remove unused loop variables in tests. 
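As a minimal sketch of the conventions applied here (placeholder code, not taken from the tests themselves): unused indices become `_`, unused tuple elements get a leading underscore, and loops that never use the module name iterate `modules()` instead of `named_modules()`.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

# Unused index: name it "_" so linters treat it as intentionally unused.
for _ in range(3):
    torch.randn(2, 2)

# Tuple unpacking where only one element is needed: prefix the unused part.
pairs = [("a", torch.zeros(1)), ("b", torch.ones(1))]
for key, _tensor in pairs:
    print(key)

# Module name never used: iterate modules() instead of named_modules().
for submodule in model.modules():
    print(type(submodule).__name__)
```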
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167043 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_compiler_bisector.py | 2 +- test/functorch/test_control_flow.py | 14 ++++---------- test/inductor/test_flex_attention.py | 4 ++-- .../eager/test_bias_correction_eager.py | 2 +- test/quantization/fx/test_quantize_fx.py | 6 +++--- test/test_nn.py | 2 +- test/test_transformers.py | 2 +- 7 files changed, 13 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index 8810a30aaf3b7..8ebf35f3f0d3f 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -283,7 +283,7 @@ def test_fn(): ) def test_bisect_pre_grad_graph(self): def f(x): - for i in range(5): + for _ in range(5): x = x + 1 return x.relu() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 5034661fa3e05..f83f059663149 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -942,9 +942,7 @@ def false_fn(x): b = torch.randn(4, requires_grad=True) c = torch.randn(4, requires_grad=True) - for pred, fn in zip( - [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] - ): + for pred in [torch.tensor(False), torch.tensor(True)]: with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", @@ -3066,13 +3064,9 @@ def run_test_and_get_grads_loss(model, initial_hs, inputs): ).to(DEVICE) # Test 3 models: RNNScanList, RNNScanTensor, RNNLoop - models = [ - ("ScanList", RNNScanList), - ("ScanTensor", RNNScanTensor), - ("Loop", RNNLoop), - ] + models = [RNNScanList, RNNScanTensor, RNNLoop] - for model_name, model_class in models: + for model_class in models: # Create uncompiled model model_uc = model_class().to(DEVICE) uncompiled_grads, uncompiled_loss = run_test_and_get_grads_loss( @@ -7538,7 +7532,7 @@ def foo(x): inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3)) for inp in inps: - gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4)) + gm = make_fx(foo, tracing_mode="symbolic")(inp) self.assertExpectedInline( gm.code.strip(), """\ diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a1e5aa3cebc45..816d3b93ecfef 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5807,11 +5807,11 @@ def causal_mask(b, h, q_idx, kv_idx): from torch.utils._pytree import GetAttrKey - for key, tensor in tensors_with_keys: + for key, _tensor in tensors_with_keys: self.assertIsInstance(key, GetAttrKey) self.assertIsNotNone(key) - for key, value in context_with_keys: + for key, _value in context_with_keys: self.assertIsInstance(key, GetAttrKey) self.assertIsNotNone(key) diff --git a/test/quantization/eager/test_bias_correction_eager.py b/test/quantization/eager/test_bias_correction_eager.py index 5f0c475f934dd..071ea6e2a768f 100644 --- a/test/quantization/eager/test_bias_correction_eager.py +++ b/test/quantization/eager/test_bias_correction_eager.py @@ -39,7 +39,7 @@ def correct_artificial_bias_quantize(self, float_model, img_data): torch.ao.quantization.convert(artificial_model, inplace=True) # manually changing bias - for name, submodule in artificial_model.named_modules(): + for submodule in artificial_model.modules(): if type(submodule) in _supported_modules: x = get_param(submodule, "bias") weight = get_param(submodule, "weight") 
diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index faba2f5edc6a7..b33afc7a80363 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -9663,10 +9663,10 @@ def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, .set_global(get_default_qat_qconfig(qengine)) \ .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig) - train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] - eval_output = [[torch.randint(0, 10, (12, 1))]] + train_indices = [[torch.randint(0, 10, (12, 12), device=device), torch.randn((12, 1), device=device)] for _ in range(2)] + eval_output = [[torch.randint(0, 10, (12, 1), device=device)]] - model = EmbeddingBagLinear().train() + model = EmbeddingBagLinear().to(device).train() prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, diff --git a/test/test_nn.py b/test/test_nn.py index 034cf51d49ff0..bedb4b22a01bd 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -13516,7 +13516,7 @@ def compare_scaling(grads): # Should warning when parameters generator exhausted params = l.parameters() - for p in params: + for _p in params: pass with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/test/test_transformers.py b/test/test_transformers.py index cc82cbff2a46f..ad7ae56307eb1 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2857,7 +2857,7 @@ def test_cudnn_attention_broken_166211(self): # https://github.com/pytorch/pytorch/issues/166211#issue-3551350377 shape = (20, 4, 4, 32) scale = 10 - for i in range(100): + for _ in range(100): q = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale k = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale v = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale From 981dd718939ae2413c217c071e364715dbdbf8d6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Sat, 1 Nov 2025 15:49:12 -0700 Subject: [PATCH 108/651] Refactor: extract OperatorArgsKwargsView from parseIValuesToPyArgsKwargs (#166368) Intended to make it easier to reuse this logic for processing operator arguments as IValues in following PR(s). 
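The splitting rule the new view encapsulates can be sketched in Python against a toy schema (illustrative names only, not the real `c10::FunctionSchema` API): scan from the right to find where kwarg-only arguments begin, walk left past trailing positional arguments that still equal their defaults, and emit only the kwarg-only arguments whose values differ from their defaults.

```python
from dataclasses import dataclass
from typing import Any

# Toy stand-in for a schema argument; field names are illustrative only.
@dataclass
class Arg:
    name: str
    kwarg_only: bool = False
    has_default: bool = False
    default: Any = None

def split_args_kwargs(schema: list, values: list):
    # Where do kwarg-only arguments start? Scan from the right.
    kwarg_only_start = len(schema)
    while kwarg_only_start > 0 and schema[kwarg_only_start - 1].kwarg_only:
        kwarg_only_start -= 1
    # Drop trailing positional arguments that still hold their default value.
    pos_end = kwarg_only_start
    while (pos_end > 0 and schema[pos_end - 1].has_default
           and values[pos_end - 1] == schema[pos_end - 1].default):
        pos_end -= 1
    args = tuple(values[:pos_end])
    # Keep only kwarg-only arguments that differ from their default.
    kwargs = {
        schema[i].name: values[i]
        for i in range(kwarg_only_start, len(values))
        if not (schema[i].has_default and values[i] == schema[i].default)
    }
    return args, kwargs

# f(int x, int y = 0, *, int z = 0) called as f(1, 0, z=2)
schema = [
    Arg("x"),
    Arg("y", has_default=True, default=0),
    Arg("z", kwarg_only=True, has_default=True, default=0),
]
print(split_args_kwargs(schema, [1, 0, 2]))  # ((1,), {'z': 2})
```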
Testing: python test/test_python_dispatch.py (broke during development, seems to work now) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166368 Approved by: https://github.com/albanD --- torch/csrc/autograd/python_variable.cpp | 193 ++++++++++++++++++------ 1 file changed, 151 insertions(+), 42 deletions(-) diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 946a8d5f1d367..837ba93d1cc28 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -51,14 +51,101 @@ using namespace at; using namespace torch; using namespace torch::autograd; -std::pair parseIValuesToPyArgsKwargs( - const c10::OperatorHandle& op, - const std::vector& arguments) { - TORCH_CHECK( - PyGILState_Check(), - "GIL must be held before you call parseIValuesToPyArgsKwargs"); - const auto& schema = op.schema(); - py::dict kwargs; +namespace { +class OperatorArgsKwargsView { + public: + OperatorArgsKwargsView( + const c10::OperatorHandle& op, + const std::vector& arguments); + using args_iterator = const c10::IValue*; + + args_iterator args_begin() const { + return arguments_.data(); + } + + args_iterator args_end() const { + return arguments_.data() + positional_default_start_; + } + + auto num_positional_args() const { + return positional_default_start_; + } + + auto kwarg_start_index() const { + return first_non_default_kwarg_; + } + + struct kwargs_iterator { + kwargs_iterator() = default; + kwargs_iterator(const OperatorArgsKwargsView* parent, size_t current) + : parent_(parent), current_(current) {} + + kwargs_iterator(const kwargs_iterator&) = default; + kwargs_iterator& operator=(const kwargs_iterator&) = default; + + kwargs_iterator& operator++() { + do { + current_++; + } while (current_ < parent_->arguments_.size() && + parent_->is_default(current_)); + return *this; + } + + kwargs_iterator operator++(int) { + auto copy = *this; + ++(*this); + return copy; + } + + const c10::IValue& operator*() const { + return parent_->arguments_[current_]; + } + + const c10::IValue* operator->() const { + return &operator*(); + } + + int64_t underlying_index() const { + return current_; + } + + bool operator==(const kwargs_iterator& rhs) const { + return parent_ == rhs.parent_ && current_ == rhs.current_; + } + + bool operator!=(const kwargs_iterator& rhs) { + return !(*this == rhs); + } + + private: + const OperatorArgsKwargsView* parent_ = nullptr; + size_t current_ = 0; + }; + + kwargs_iterator kwargs_begin() const { + return kwargs_iterator(this, first_non_default_kwarg_); + } + + kwargs_iterator kwargs_end() const { + return kwargs_iterator(this, arguments_.size()); + } + + private: + bool is_default(size_t idx) const { + const auto& arg = op_.schema().arguments()[idx]; + if (!arg.default_value().has_value()) { + return false; + } + const auto& default_ivalue = *arg.default_value(); + const auto& ivalue = arguments_[idx]; + if (default_ivalue != ivalue) { + return false; + } + return true; + } + + const c10::OperatorHandle& op_; + c10::ArrayRef arguments_; // About all the pointers: // // f(int x, int y = 0, *, int z = 0) @@ -66,45 +153,63 @@ std::pair parseIValuesToPyArgsKwargs( // ^- kwarg_only_start // ^- positional_default_start // ^- 0 + int64_t positional_default_start_; + int64_t first_non_default_kwarg_; +}; +OperatorArgsKwargsView::OperatorArgsKwargsView( + const c10::OperatorHandle& op, + const std::vector& arguments) + : op_(op), arguments_(arguments) { // Find the split point between kwarg-only and 
regular. Since most functions // don't have kwarg-only arguments, it is more efficient to scan from the // right (but ideally, this would just be precomputed in FunctionSchema // itself). (NB: minus one in the loop is because we're testing if the // *next* argument is kwarg-only before we advance the starting index) - int64_t kwarg_only_start = static_cast(arguments.size()); + const int64_t signed_arguments_size = static_cast(arguments.size()); + int64_t kwarg_only_start = signed_arguments_size; for (; kwarg_only_start > 0; kwarg_only_start--) { - const auto& arg = schema.arguments()[kwarg_only_start - 1]; + const auto& arg = op.schema().arguments()[kwarg_only_start - 1]; if (!arg.kwarg_only()) { break; } } // Find the first positional argument that isn't defaulted - auto is_default = [&](size_t idx) -> bool { - const auto& arg = schema.arguments()[idx]; - if (!arg.default_value().has_value()) { - return false; - } - const auto& default_ivalue = *arg.default_value(); - const auto& ivalue = arguments[idx]; - if (default_ivalue != ivalue) { - return false; + positional_default_start_ = kwarg_only_start; + for (; positional_default_start_ > 0; positional_default_start_--) { + if (!is_default(positional_default_start_ - 1)) { + break; } - return true; - }; + } - int64_t positional_default_start = kwarg_only_start; - for (; positional_default_start > 0; positional_default_start--) { - if (!is_default(positional_default_start - 1)) { + // kwargs_iterator will skip default kwargs when incremented, but we + // need to skip any initial run of default kwargs ourselves. + first_non_default_kwarg_ = kwarg_only_start; + for (; first_non_default_kwarg_ < signed_arguments_size; + ++first_non_default_kwarg_) { + if (!is_default(first_non_default_kwarg_)) { break; } } +} +} // namespace - auto args = - py::reinterpret_steal(PyTuple_New(positional_default_start)); +std::pair parseIValuesToPyArgsKwargs( + const c10::OperatorHandle& op, + const std::vector& arguments) { + TORCH_CHECK( + PyGILState_Check(), + "GIL must be held before you call parseIValuesToPyArgsKwargs"); + const auto& schema = op.schema(); + py::dict kwargs; - auto schemaAwareToPyObject = [&](size_t idx) -> py::object { + OperatorArgsKwargsView args_kwargs(op, arguments); + auto args = py::reinterpret_steal( + PyTuple_New(args_kwargs.num_positional_args())); + + auto schemaAwareToPyObject = + [&schema](size_t idx, const c10::IValue& argument) -> py::object { const auto& arg = schema.arguments()[idx]; auto match = [&](c10::TypeKind kind) { const auto& t = arg.real_type(); @@ -116,38 +221,42 @@ std::pair parseIValuesToPyArgsKwargs( } return false; }; - if (arguments[idx].isNone()) { + if (argument.isNone()) { return py::none(); } else if (match(c10::ScalarTypeType::Kind)) { - auto* obj = - getTHPDtype(static_cast(arguments[idx].toInt())); + auto* obj = getTHPDtype(static_cast(argument.toInt())); return py::reinterpret_borrow( reinterpret_cast(obj)); } else if (match(c10::LayoutType::Kind)) { - auto* obj = - getTHPLayout(static_cast(arguments[idx].toInt())); + auto* obj = getTHPLayout(static_cast(argument.toInt())); return py::reinterpret_borrow( reinterpret_cast(obj)); } else if (match(c10::MemoryFormatType::Kind)) { - return py::cast(static_cast(arguments[idx].toInt())); + return py::cast(static_cast(argument.toInt())); } else { - return torch::jit::toPyObject(arguments[idx]); + return torch::jit::toPyObject(argument); } }; // Populate positional arguments - for (const auto idx : c10::irange(positional_default_start)) { + size_t idx = 0; + for 
(auto argument_it = args_kwargs.args_begin(); + argument_it != args_kwargs.args_end(); + ++argument_it) { PyTuple_SET_ITEM( - args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr()); + args.ptr(), + idx, + schemaAwareToPyObject(idx, *argument_it).release().ptr()); + idx++; } // Populate keyword arguments - for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) { - // But don't populate default keyword arguments - if (is_default(idx)) - continue; - const auto& arg = schema.arguments()[idx]; - kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx); + for (auto argument_it = args_kwargs.kwargs_begin(); + argument_it != args_kwargs.kwargs_end(); + ++argument_it) { + const auto& arg = schema.arguments()[argument_it.underlying_index()]; + kwargs[py::cast(arg.name())] = + schemaAwareToPyObject(argument_it.underlying_index(), *argument_it); } return std::make_pair(std::move(args), std::move(kwargs)); } From f72772b184ffbe82bba2412787955587a66233de Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 5 Nov 2025 11:41:31 -0800 Subject: [PATCH 109/651] [PP] make runtime dbg log print custom actions (#167113) Previously the log only printed if the default implementation for an action was used, now it prints before dispatching to custom registered actions. Tested by running on autoparallel graph runner and observing forward pass action logged Pull Request resolved: https://github.com/pytorch/pytorch/pull/167113 Approved by: https://github.com/sanketpurandare, https://github.com/Skylion007 --- torch/distributed/pipelining/schedules.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index e60ae3b93ba63..44569427f8db2 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -2033,12 +2033,6 @@ def _perform_action(action: _Action) -> None: is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage - logger.debug( - "_PipelineScheduleRuntime running time_step %d, action %s", - time_step, - action, - ) - # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be # safe to use instead. @@ -2191,6 +2185,11 @@ def _perform_action(action: _Action) -> None: # count either full_backward or backward_weight together, to determine when to sync DP grads self.backward_counter.clear() for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) try: with record_function(_get_profiler_function_name(action)): if action.computation_type in self._comp_type_to_function_map: From c3c36534187d49da7ca2c680a641430eb9cfc404 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 04:32:14 +0000 Subject: [PATCH 110/651] [1/N] Add return types of Python functions (#167162) This PR adds return types of some Python functions. Most of them return `None`. The types were added automatically by ruff `ANN` rules. 
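A minimal sketch of the resulting style (placeholder class, not code from this diff): procedures and `__init__` are annotated `-> None`, predicates `-> bool`. Assuming a standard ruff setup, the corresponding flake8-annotations rules can be selected with `select = ["ANN"]` under `[tool.ruff.lint]`.

```python
class Flags:
    def __init__(self, flag_to_value: dict) -> None:   # __init__ always returns None
        self._flag_to_value = flag_to_value

    def __contains__(self, key: str) -> bool:           # predicate: annotate -> bool
        return key in self._flag_to_value

    def clear(self) -> None:                            # in-place mutator: annotate -> None
        self._flag_to_value.clear()

flags = Flags({"WRITEABLE": True})
print("WRITEABLE" in flags)  # True
flags.clear()
print("WRITEABLE" in flags)  # False
```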
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167162 Approved by: https://github.com/Lucaskabela --- torch/nn/attention/__init__.py | 2 +- torch/nn/attention/_utils.py | 2 +- torch/nn/attention/bias.py | 4 +-- .../experimental/_paged_attention.py | 2 +- torch/nn/attention/flex_attention.py | 12 +++---- torch/nn/backends/thnn.py | 2 +- torch/nn/cpp.py | 12 +++---- torch/nn/modules/module.py | 2 +- torch/nn/parallel/data_parallel.py | 2 +- torch/nn/parameter.py | 10 +++--- torch/nn/parameter.pyi | 4 +-- .../expanded_weights_impl.py | 14 ++++---- .../expanded_weights_utils.py | 2 +- torch/nn/utils/parametrizations.py | 2 +- torch/nn/utils/parametrize.py | 2 +- torch/nn/utils/prune.py | 32 ++++++++++--------- torch/optim/_adafactor.py | 10 +++--- torch/optim/_functional.py | 2 +- torch/optim/_muon.py | 4 +-- torch/optim/adadelta.py | 8 ++--- torch/optim/adagrad.py | 10 +++--- torch/optim/adam.py | 8 ++--- torch/optim/adamax.py | 8 ++--- torch/optim/adamw.py | 4 +-- torch/optim/asgd.py | 8 ++--- torch/optim/lbfgs.py | 6 ++-- torch/optim/lr_scheduler.py | 28 ++++++++-------- torch/optim/nadam.py | 8 ++--- torch/optim/optimizer.py | 2 +- torch/optim/radam.py | 8 ++--- torch/optim/rmsprop.py | 8 ++--- torch/optim/rprop.py | 8 ++--- torch/optim/sgd.py | 8 ++--- torch/optim/sparse_adam.py | 2 +- torch/optim/swa_utils.py | 16 ++++++---- 35 files changed, 134 insertions(+), 128 deletions(-) diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 5e6e0fa5fae3b..a115d32c6e2c8 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -90,7 +90,7 @@ def _cur_sdpa_kernel_backends(with_priority: bool = False): return backends -def _sdpa_kernel(backends: Iterable, set_priority: bool = False): +def _sdpa_kernel(backends: Iterable, set_priority: bool = False) -> None: for name, val in _backend_names.items(): enabled = getattr(SDPBackend, val) in backends getattr(torch._C, f"_set_sdp_use_{name}")(enabled) diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index a91045b92c13e..86f7c29f5313a 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -40,7 +40,7 @@ def _validate_sdpa_input( dropout_p=0.0, is_causal=False, scale=None, -): +) -> None: if query.dtype != key.dtype or query.dtype != value.dtype: raise ValueError( f"Expected query, key, and value to have the same dtype, " diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 551a57e6963e0..0cb256ad36f7f 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -117,7 +117,7 @@ class CausalBias(torch.Tensor): .. warning:: This class is a prototype and subject to change. """ - def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int): + def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None: """ Initializes the CausalBias instance with a specified variant and sequence lengths. 
@@ -296,7 +296,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return cls._dispatch(*args, **kwargs) return super().__torch_function__(func, types, args, kwargs) - def __repr__(self): # type:ignore[override] + def __repr__(self) -> str: # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/nn/attention/experimental/_paged_attention.py b/torch/nn/attention/experimental/_paged_attention.py index 70eadcdadfaa0..2e0ded6063aef 100644 --- a/torch/nn/attention/experimental/_paged_attention.py +++ b/torch/nn/attention/experimental/_paged_attention.py @@ -40,7 +40,7 @@ def __init__( page_size: int, max_batch_size: int, device: str = "cuda", - ): + ) -> None: # number of pages self.n_pages = n_pages diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index b79b86a29afb6..be49549e5740e 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -550,7 +550,7 @@ def __init__( full_q_indices: Optional[Tensor], BLOCK_SIZE: tuple[int, int], mask_mod: _mask_mod_signature, - ): + ) -> None: if kv_indices.dim() < 2: raise RuntimeError("BlockMask must have at least 2 dimensions") assert kv_num_blocks is not None, "kv_num_blocks must be provided" @@ -682,7 +682,7 @@ def shape(self): *batch_dims, _, _ = self.kv_indices.shape return tuple(batch_dims) + self.seq_lengths - def __str__(self): + def __str__(self) -> str: s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n" mask_str = self.to_string().strip() s += mask_str @@ -760,7 +760,7 @@ def causal_mask(b, h, q_idx, kv_idx): compute_q_blocks=self.q_indices is not None, ) - def __repr__(self): + def __repr__(self) -> str: def shape_or_none(x: Optional[torch.Tensor]): return x.shape if x is not None else None @@ -864,7 +864,7 @@ def create_block_vis(*batch_idx): vis = ", ".join(reversed(descriptors)) + "\n" - def summarize_section(section): + def summarize_section(section) -> str: percentage = section.float().mean().item() if percentage == 1: return "█" @@ -1289,7 +1289,7 @@ def _apply_kernel_options( return kernel_options -def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): +def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor) -> None: if query.size(-1) != key.size(-1): raise ValueError( f"Expect query and key/value to have the same embedding dimension " @@ -1297,7 +1297,7 @@ def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): ) -def _validate_device(query: Tensor, key: Tensor, value: Tensor): +def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None: """TODO: Remove once non cuda/cpu devices support is added We only need to check query since we have already that q,k,v are on the same device """ diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py index 8564153ece233..c56e923a84383 100644 --- a/torch/nn/backends/thnn.py +++ b/torch/nn/backends/thnn.py @@ -2,5 +2,5 @@ # this is for historical pickle deserialization, it is not used otherwise -def _get_thnn_function_backend(): +def _get_thnn_function_backend() -> None: pass diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index e447284ad82ba..b4ffd188cd39a 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -14,7 +14,7 @@ class OrderedDictWrapper: so using properties does not work. 
""" - def __init__(self, cpp_module, attr): + def __init__(self, cpp_module, attr) -> None: self.cpp_module = cpp_module self.attr = attr @@ -37,10 +37,10 @@ def values(self): def __iter__(self): return self.cpp_dict.__iter__() - def __len__(self): + def __len__(self) -> int: return self.cpp_dict.__len__() - def __contains__(self, key): + def __contains__(self, key) -> bool: return self.cpp_dict.__contains__(key) def __getitem__(self, key): @@ -50,7 +50,7 @@ def __getitem__(self, key): class ModuleWrapper(nn.Module): """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.""" - def __init__(self, cpp_module): + def __init__(self, cpp_module) -> None: # Assign before the super class constructor so ``self.training`` can be # assigned to in the super class constructor. self.cpp_module = cpp_module @@ -83,8 +83,8 @@ def training(self): return self.cpp_module.training @training.setter - def training(self, mode): + def training(self, mode) -> None: self.cpp_module.train(mode) - def __repr__(self): + def __repr__(self) -> str: return self.cpp_module.__repr__() diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f7e3d2f262def..10a240e3a9cf7 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -3040,7 +3040,7 @@ def _replicate_for_data_parallel(self): return replica - def compile(self, *args, **kwargs): + def compile(self, *args, **kwargs) -> None: """ Compile this Module's forward using :func:`torch.compile`. diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 9a0f4973d31b2..9aaa9b4a92e6d 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -30,7 +30,7 @@ def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None: device_ids = [_get_device_index(x, True) for x in device_ids] dev_props = _get_devices_properties(device_ids) - def warn_imbalance(get_prop): + def warn_imbalance(get_prop) -> bool: values = [get_prop(props) for props in dev_props] min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c03c85f48fc35..64e9d8c2d80f2 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -18,7 +18,7 @@ # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. - def __instancecheck__(self, instance): + def __instancecheck__(self, instance) -> bool: if self is Parameter: if isinstance(instance, torch.Tensor) and getattr( instance, "_is_param", False @@ -82,7 +82,7 @@ def __deepcopy__(self, memo): return result # pyrefly: ignore [bad-override] - def __repr__(self): + def __repr__(self) -> str: return "Parameter containing:\n" + super().__repr__() def __reduce_ex__(self, proto): @@ -125,7 +125,7 @@ class UninitializedTensorMixin: torch._has_compatible_shallow_copy_type, ] - def materialize(self, shape, device=None, dtype=None): + def materialize(self, shape, device=None, dtype=None) -> None: r"""Create a Parameter or Tensor with the same properties of the uninitialized one. Given a shape, it materializes a parameter in the same device @@ -163,7 +163,7 @@ def share_memory_(self): "`module.share_memory()`." 
) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" def __reduce_ex__(self, proto): @@ -235,7 +235,7 @@ def __deepcopy__(self, memo): # Metaclass to combine _TensorMeta and the instance check override for Buffer. class _BufferMeta(torch._C._TensorMeta): # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. - def __instancecheck__(self, instance): + def __instancecheck__(self, instance) -> bool: if self is Buffer: if isinstance(instance, torch.Tensor) and getattr( instance, "_is_buffer", False diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index a17821c2b16c1..3d1cddb7e8b8b 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -25,7 +25,7 @@ class Buffer(Tensor): data: Tensor = ..., requires_grad: bool = ..., persistent: bool = ..., - ): ... + ) -> None: ... class UninitializedBuffer(Tensor): persistent: bool @@ -34,7 +34,7 @@ class UninitializedBuffer(Tensor): data: Tensor = ..., requires_grad: bool = ..., persistent: bool = ..., - ): ... + ) -> None: ... def materialize( self, shape: tuple[int, ...], diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index cfb1d99ac30ec..58ef67e06148a 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -37,10 +37,10 @@ # all of the RNN decomps run linear with the batch dimension second, even if batch_first was set @contextmanager def batch_second(args, kwargs): - def set_batch_second(ew): + def set_batch_second(ew) -> None: ew.set_batch_first(False) - def reset_batch_first(ew): + def reset_batch_first(ew) -> None: ew.set_batch_first(True) tree_map_only(ExpandedWeight, set_batch_second, args) @@ -55,10 +55,10 @@ def reset_batch_first(ew): # to support packed sequences, we need to allow for smaller batches. 
Expanded weights represents the largest batch @contextmanager def allow_smaller_batches(args, kwargs): - def allow(ew): + def allow(ew) -> None: ew.set_allow_smaller_batches(True) - def reset(ew): + def reset(ew) -> None: ew.set_allow_smaller_batches(False) tree_map_only(ExpandedWeight, allow, args) @@ -102,7 +102,7 @@ def decorator(autograd_func): # # Needs to be a tensor subclass to allow reparameterization class ExpandedWeight(torch.Tensor): - def __init__(self, orig_weight, batch_size, loss_reduction): + def __init__(self, orig_weight, batch_size, loss_reduction) -> None: self.batch_size = batch_size self.batch_first = True self.allow_smaller_batches = False @@ -179,8 +179,8 @@ def data_ptr(self): def get_device(self): return self.orig_weight.get_device() - def set_allow_smaller_batches(self, is_allow_smaller_batches): + def set_allow_smaller_batches(self, is_allow_smaller_batches) -> None: self.allow_smaller_batches = is_allow_smaller_batches - def set_batch_first(self, is_batch_first=True): + def set_batch_first(self, is_batch_first=True) -> None: self.batch_first = is_batch_first diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index ec6d55305fb46..eacd717873ec2 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -123,7 +123,7 @@ def maybe_scale_by_batch_size(grad_sample, expanded_weight): return grad_sample -def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn) -> None: unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) if isinstance(maybe_expanded_weight, ExpandedWeight): grad_sample_contribution = maybe_scale_by_batch_size( diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 7706be61e39f1..59044b72b96cd 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -388,7 +388,7 @@ def _weight_norm_compat_hook( missing_keys, unexpected_keys, error_msgs, - ): + ) -> None: g_key = f"{prefix}{name}_g" v_key = f"{prefix}{name}_v" if g_key in state_dict and v_key in state_dict: diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 88eeb3aaf50c3..b9a1140e43f71 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -72,7 +72,7 @@ def cached(): _cache = {} -def _register_parameter_or_buffer(module, name, X): +def _register_parameter_or_buffer(module, name, X) -> None: if isinstance(X, Parameter): module.register_parameter(name, X) else: diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 3c1a800085951..827bf19ed4bea 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -231,7 +231,7 @@ def prune(self, t, default_mask=None, importance_scores=None): default_mask = default_mask if default_mask is not None else torch.ones_like(t) return t * self.compute_mask(importance_scores, default_mask=default_mask) - def remove(self, module): + def remove(self, module) -> None: r"""Remove the pruning reparameterization from a module. The pruned parameter named ``name`` remains permanently pruned, @@ -269,7 +269,7 @@ class PruningContainer(BasePruningMethod): them. """ - def __init__(self, *args): + def __init__(self, *args) -> None: self._pruning_methods: tuple[BasePruningMethod, ...] 
= () if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name @@ -284,7 +284,7 @@ def __init__(self, *args): for method in args: self.add_pruning_method(method) - def add_pruning_method(self, method): + def add_pruning_method(self, method) -> None: r"""Add a child pruning ``method`` to the container. Args: @@ -303,7 +303,7 @@ def add_pruning_method(self, method): # if all checks passed, add to _pruning_methods tuple self._pruning_methods += (method,) # type: ignore[operator] - def __len__(self): + def __len__(self) -> int: return len(self._pruning_methods) def __iter__(self): @@ -449,7 +449,7 @@ class RandomUnstructured(BasePruningMethod): PRUNING_TYPE = "unstructured" - def __init__(self, amount): + def __init__(self, amount) -> None: # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount @@ -506,7 +506,7 @@ class L1Unstructured(BasePruningMethod): PRUNING_TYPE = "unstructured" - def __init__(self, amount): + def __init__(self, amount) -> None: # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount @@ -574,7 +574,7 @@ class RandomStructured(BasePruningMethod): PRUNING_TYPE = "structured" - def __init__(self, amount, dim=-1): + def __init__(self, amount, dim=-1) -> None: # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount @@ -682,7 +682,7 @@ class LnStructured(BasePruningMethod): PRUNING_TYPE = "structured" - def __init__(self, amount, n, dim=-1): + def __init__(self, amount, n, dim=-1) -> None: # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount @@ -799,7 +799,7 @@ def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: i class CustomFromMask(BasePruningMethod): PRUNING_TYPE = "global" - def __init__(self, mask): + def __init__(self, mask) -> None: self.mask = mask def compute_mask(self, t, default_mask): @@ -1025,7 +1025,9 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None): return module -def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): +def global_unstructured( + parameters, pruning_method, importance_scores=None, **kwargs +) -> None: r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. @@ -1212,7 +1214,7 @@ def remove(module, name): ) -def is_pruned(module): +def is_pruned(module) -> bool: r"""Check if a module is pruned by looking for pruning pre-hooks. Check whether ``module`` is pruned by looking for @@ -1241,7 +1243,7 @@ def is_pruned(module): return False -def _validate_pruning_amount_init(amount): +def _validate_pruning_amount_init(amount) -> None: r"""Validate helper to check the range of amount at init. Args: @@ -1271,7 +1273,7 @@ def _validate_pruning_amount_init(amount): ) -def _validate_pruning_amount(amount, tensor_size): +def _validate_pruning_amount(amount, tensor_size) -> None: r"""Validate that the pruning amount is meaningful wrt to the size of the data. Validation helper to check that the amount of parameters to prune @@ -1295,7 +1297,7 @@ def _validate_pruning_amount(amount, tensor_size): ) -def _validate_structured_pruning(t): +def _validate_structured_pruning(t) -> None: r"""Validate that the tensor to be pruned is at least 2-Dimensional. 
Validation helper to check that the tensor to be pruned is multi- @@ -1342,7 +1344,7 @@ def _compute_nparams_toprune(amount, tensor_size): return round(amount * tensor_size) -def _validate_pruning_dim(t, dim): +def _validate_pruning_dim(t, dim) -> None: r"""Validate that the pruning dimension is within the bounds of the tensor dimension. Args: diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index 4def193daf190..c417b354429b5 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -32,7 +32,7 @@ def __init__( *, foreach: Optional[bool] = None, maximize: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -77,7 +77,7 @@ def _init_group( col_vars, variances, state_steps, - ): + ) -> bool: for p in group["params"]: if p.grad is None: continue @@ -349,7 +349,7 @@ def _single_tensor_adafactor( eps2: float, maximize: bool, has_complex: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Grad scaling should occur outside of optimizer.step()") @@ -473,7 +473,7 @@ def _multi_tensor_adafactor( eps2: float, maximize: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -624,7 +624,7 @@ def adafactor( eps1: float, eps2: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adafactor algorithm computation. See :class:`~torch.optim.Adafactor` for details. diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index 9b2c76700b356..ba97bc9979378 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -33,7 +33,7 @@ def sparse_adam( beta2: float, lr: float, maximize: bool, -): +) -> None: r"""Functional API that performs Sparse Adam algorithm computation. See :class:`~torch.optim.SparseAdam` for details. diff --git a/torch/optim/_muon.py b/torch/optim/_muon.py index 7b7167a40fc1c..5b7b9892daf3a 100644 --- a/torch/optim/_muon.py +++ b/torch/optim/_muon.py @@ -141,7 +141,7 @@ def _init_group( params_with_grad: list[Tensor], grads: list[Tensor], muon_momentum_bufs: list[Tensor], - ): + ) -> bool: for p in group["params"]: if p.grad is None: continue @@ -337,7 +337,7 @@ def muon( eps: float, adjust_lr_fn: Optional[str], has_complex: bool, -): +) -> None: r"""Functional API that performs Muon algorithm computation. See :class:`~torch.optim.Muon` for details. diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 4a893026451ae..75ac77790e309 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -38,7 +38,7 @@ def __init__( capturable: bool = False, maximize: bool = False, differentiable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -257,7 +257,7 @@ def _single_tensor_adadelta( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( @@ -317,7 +317,7 @@ def _multi_tensor_adadelta( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") @@ -427,7 +427,7 @@ def adadelta( eps: float, weight_decay: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adadelta algorithm computation. 
See :class:`~torch.optim.Adadelta` for details. diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 4d2523b2a16af..519900ab5da63 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -38,7 +38,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, fused: Optional[bool] = None, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -116,7 +116,7 @@ def __setstate__(self, state): float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused) ) - def share_memory(self): + def share_memory(self) -> None: """Calls tensor.share_memory_() on the state sum tensors.""" for group in self.param_groups: for p in group["params"]: @@ -261,7 +261,7 @@ def adagrad( lr_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adagrad algorithm computation. See :class:`~torch.optim.Adagrad` for details. @@ -336,7 +336,7 @@ def _single_tensor_adagrad( maximize: bool, differentiable: bool, has_complex: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -404,7 +404,7 @@ def _multi_tensor_adagrad( maximize: bool, differentiable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") if grad_scale is not None or found_inf is not None: diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 5ceadccce86a5..6b8fd5b7e70f6 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -47,7 +47,7 @@ def __init__( differentiable: bool = False, fused: Optional[bool] = None, decoupled_weight_decay: bool = False, - ): + ) -> None: if isinstance(lr, Tensor): if foreach and not capturable: raise ValueError( @@ -365,7 +365,7 @@ def _single_tensor_adam( capturable: bool, differentiable: bool, decoupled_weight_decay: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -572,7 +572,7 @@ def _multi_tensor_adam( capturable: bool, differentiable: bool, decoupled_weight_decay: bool, -): +) -> None: if len(params) == 0: return @@ -925,7 +925,7 @@ def adam( weight_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adam algorithm computation. See :class:`~torch.optim.Adam` for details. diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 76d784d6ea764..264451dbb4091 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -39,7 +39,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, capturable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -239,7 +239,7 @@ def _single_tensor_adamax( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -319,7 +319,7 @@ def _multi_tensor_adamax( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") @@ -441,7 +441,7 @@ def adamax( beta2: float, lr: float, weight_decay: float, -): +) -> None: r"""Functional API that performs adamax algorithm computation. See :class:`~torch.optim.Adamax` for details. 
diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 0558cbddd883b..2c968fabb698c 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -33,7 +33,7 @@ def __init__( capturable: bool = False, differentiable: bool = False, fused: Optional[bool] = None, - ): + ) -> None: super().__init__( params, lr, @@ -152,7 +152,7 @@ def adamw( weight_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs AdamW algorithm computation. See :class:`~torch.optim.AdamW` for details. diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 0008694bda18b..0af7f9b4e6f6d 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -39,7 +39,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, capturable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -211,7 +211,7 @@ def _single_tensor_asgd( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -292,7 +292,7 @@ def _multi_tensor_asgd( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -442,7 +442,7 @@ def asgd( t0: float, alpha: float, weight_decay: float, -): +) -> None: r"""Functional API that performs asgd algorithm computation. See :class:`~torch.optim.ASGD` for details. diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index ae4b286ffa225..3d138f6a43f76 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -254,7 +254,7 @@ def __init__( tolerance_change: float = 1e-9, history_size: int = 100, line_search_fn: Optional[str] = None, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -304,7 +304,7 @@ def _gather_flat_grad(self): views.append(view) return torch.cat(views, 0) - def _add_grad(self, step_size, update): + def _add_grad(self, step_size, update) -> None: offset = 0 for p in self._params: if torch.is_complex(p): @@ -319,7 +319,7 @@ def _add_grad(self, step_size, update): def _clone_param(self): return [p.clone(memory_format=torch.contiguous_format) for p in self._params] - def _set_param(self, params_data): + def _set_param(self, params_data) -> None: for p, pdata in zip(self._params, params_data, strict=True): p.copy_(pdata) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 71dcb6129a8ec..6426283e6542c 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -89,7 +89,9 @@ def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]: ] -def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor): +def _update_param_group_val( + param_group: dict[str, Any], key: str, val: float | Tensor +) -> None: """Set param_group[key] to val without aliasing or assignment when they're both tensors. Raises a KeyError if param_group[key] does not exist. """ @@ -196,7 +198,7 @@ def state_dict(self) -> dict[str, Any]: key: value for key, value in self.__dict__.items() if key != "optimizer" } - def load_state_dict(self, state_dict: dict[str, Any]): + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load the scheduler's state. 
Args: @@ -288,7 +290,7 @@ def step(self, epoch: Optional[int] = None) -> None: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2) self._update_lr(epoch) - def _update_lr(self, epoch: Optional[int] = None): + def _update_lr(self, epoch: Optional[int] = None) -> None: with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 @@ -339,7 +341,7 @@ def __exit__(self, type, value, traceback) -> None: class _initial_mode: - def __init__(self, o: LRScheduler): + def __init__(self, o: LRScheduler) -> None: self.o = o def __enter__(self): @@ -1180,7 +1182,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def recursive_undo(self, sched=None): + def recursive_undo(self, sched=None) -> None: """ Recursively undo any step performed by the initialisation of schedulers. @@ -1659,7 +1661,7 @@ def __init__( cooldown: int = 0, min_lr: Union[list[float], float] = 0, eps: float = 1e-8, - ): # noqa: D107 + ) -> None: # noqa: D107 if factor >= 1.0: raise ValueError("Factor should be < 1.0.") self.factor = factor @@ -1691,7 +1693,7 @@ def __init__( ) self._reset() - def _reset(self): + def _reset(self) -> None: """Reset num_bad_epochs counter and cooldown counter.""" self.best = self.mode_worse self.cooldown_counter = 0 @@ -1724,7 +1726,7 @@ def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[over self._last_lr = _param_groups_val_list(self.optimizer, "lr") - def _reduce_lr(self, epoch): + def _reduce_lr(self, epoch) -> None: if len(self.optimizer.param_groups) != len(self.min_lrs): if self.default_min_lr is None: raise RuntimeError( @@ -1765,7 +1767,7 @@ def _is_better(self, a, best): # noqa: D102 else: # mode == 'max' and epsilon_mode == 'abs': return a > best + self.threshold - def _init_is_better(self, mode, threshold, threshold_mode): + def _init_is_better(self, mode, threshold, threshold_mode) -> None: if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if threshold_mode not in {"rel", "abs"}: @@ -1904,7 +1906,7 @@ def __init__( base_momentum: float = 0.8, max_momentum: float = 0.9, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -1970,7 +1972,7 @@ def __init__( super().__init__(optimizer, last_epoch) self.base_lrs = base_lrs - def _init_scale_fn(self): + def _init_scale_fn(self) -> None: if self._scale_fn_custom is not None: return if self.mode == "triangular": @@ -2155,7 +2157,7 @@ def __init__( T_mult: int = 1, eta_min: float = 0.0, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 if T_0 <= 0 or not isinstance(T_0, int): raise ValueError(f"Expected positive integer T_0, but got {T_0}") if T_mult < 1 or not isinstance(T_mult, int): @@ -2407,7 +2409,7 @@ def __init__( final_div_factor: float = 1e4, three_phase: bool = False, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 # Validate optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 508648a65c14a..f83cd4b85d02f 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -44,7 +44,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -297,7 +297,7 @@ 
def _single_tensor_nadam( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -397,7 +397,7 @@ def _multi_tensor_nadam( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -624,7 +624,7 @@ def nadam( weight_decay: float, momentum_decay: float, eps: float, -): +) -> None: r"""Functional API that performs NAdam algorithm computation. See :class:`~torch.optim.NAdam` for details. diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 6a336fa5bab70..c42ea3cfb02d5 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -204,7 +204,7 @@ def _device_dtype_check_for_fused( ) -def _view_as_real(params, *state_and_grads): +def _view_as_real(params, *state_and_grads) -> None: for i, p in enumerate(params): if torch.is_complex(p): params[i] = torch.view_as_real(params[i]) diff --git a/torch/optim/radam.py b/torch/optim/radam.py index e13e6806e43a7..db69bbb01a042 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -42,7 +42,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -270,7 +270,7 @@ def _single_tensor_radam( maximize: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -377,7 +377,7 @@ def _multi_tensor_radam( maximize: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -586,7 +586,7 @@ def radam( lr: float, weight_decay: float, eps: float, -): +) -> None: r"""Functional API that performs RAdam algorithm computation. See :class:`~torch.optim.RAdam` for details. diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 04981d517d1ef..364068ecc9ab3 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -41,7 +41,7 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -280,7 +280,7 @@ def _single_tensor_rmsprop( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -357,7 +357,7 @@ def _multi_tensor_rmsprop( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -495,7 +495,7 @@ def rmsprop( weight_decay: float, momentum: float, centered: bool, -): +) -> None: r"""Functional API that performs rmsprop algorithm computation. See :class:`~torch.optim.RMSProp` for details. 
diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 8ad7faf130e39..c9e1d5eabaeee 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -39,7 +39,7 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -235,7 +235,7 @@ def _single_tensor_rprop( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: for i, param in enumerate(params): grad = grads[i] grad = grad if not maximize else -grad @@ -306,7 +306,7 @@ def _multi_tensor_rprop( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -428,7 +428,7 @@ def rprop( step_size_max: float, etaminus: float, etaplus: float, -): +) -> None: r"""Functional API that performs rprop algorithm computation. See :class:`~torch.optim.Rprop` for details. diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 9c2c5a0eab3d0..63c80d645cd08 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -39,7 +39,7 @@ def __init__( foreach: Optional[bool] = None, differentiable: bool = False, fused: Optional[bool] = None, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if lr < 0.0: @@ -267,7 +267,7 @@ def sgd( dampening: float, nesterov: bool, maximize: bool, -): +) -> None: r"""Functional API that performs SGD algorithm computation. See :class:`~torch.optim.SGD` for details. @@ -333,7 +333,7 @@ def _single_tensor_sgd( nesterov: bool, maximize: bool, has_sparse_grad: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -394,7 +394,7 @@ def _multi_tensor_sgd( nesterov: bool, maximize: bool, has_sparse_grad: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index ca87e87ce8674..ed58c93181ae2 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -19,7 +19,7 @@ def __init__( betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, maximize: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 < lr: diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 254560d8751ce..ebe3e07025957 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -43,7 +43,9 @@ def get_ema_multi_avg_fn(decay=0.999): ) @torch.no_grad() - def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): + def ema_update( + ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _ + ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( ema_param_list[0] @@ -64,7 +66,7 @@ def swa_update( averaged_param_list: PARAM_LIST, current_param_list: PARAM_LIST, num_averaged: Union[Tensor, int], - ): + ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( averaged_param_list[0] @@ -227,7 +229,7 @@ def __init__( Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] ] = None, use_buffers=False, - ): # noqa: D107 + ) -> None: # noqa: D107 super().__init__() if avg_fn is 
not None and multi_avg_fn is not None: raise AssertionError( @@ -247,7 +249,7 @@ def forward(self, *args, **kwargs): """Forward pass.""" return self.module(*args, **kwargs) - def update_parameters(self, model: Module): + def update_parameters(self, model: Module) -> None: """Update model parameters.""" self_param = ( # pyrefly: ignore [bad-argument-type] @@ -329,7 +331,7 @@ def update_bn( loader: Iterable[Any], model: Module, device: Optional[Union[int, torch.device]] = None, -): +) -> None: r"""Update BatchNorm running_mean, running_var buffers in the model. It performs one pass over data in `loader` to estimate the activation @@ -434,7 +436,7 @@ def __init__( anneal_epochs=10, anneal_strategy: Literal["cos", "linear"] = "cos", last_epoch=-1, - ): # noqa: D107 + ) -> None: # noqa: D107 swa_lrs = _format_param("swa_lr", optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True): group["swa_lr"] = swa_lr @@ -516,7 +518,7 @@ def get_lr(self): for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True) ] - def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]): + def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]) -> None: self._anneal_strategy = anneal_strategy if anneal_strategy == "cos": self.anneal_func = self._cosine_anneal From 3feea296a59c2dfc1d2f4b7e0e5d3f61fd4bf7ea Mon Sep 17 00:00:00 2001 From: Mark Barnes Date: Thu, 6 Nov 2025 04:33:05 +0000 Subject: [PATCH 111/651] torch.fx: add debug-level logging to Interpreter.run_node (#117351) (#166622) ### Summary Adds a debug-level logging statement to torch.fx.Interpreter.run_node, as proposed in [#117351](https://github.com/pytorch/pytorch/issues/117351), to make FX graph execution traceable when debugging or instrumenting model transformations. When debug logging is enabled, each executed node emits a single structured log line formatted via `LazyString(lambda: n.format_node())`, deferring string construction unless logging is active. ### Example Output With `logging.DEBUG` enabled: ``` run_node x = x() run_node add = _operator.add(x, 1) run_node clamp = torch.clamp(add, min=0.0, max=5.0) run_node output = output(clamp) ``` With `logging.DEBUG` disabled no additional output is produced (unchanged default behavior). ### Test Plan Verified locally with Python 3.11 on macOS using a PyTorch build from source. - With `logging.DEBUG` enabled: each node emits a debug log via LazyString. - With `logging.DEBUG` disabled: no additional output. - Confirmed all `Interpreter` tests pass locally: `pytest test/test_fx.py -k "Interpreter"` Updated the example output to reflect the new `_format_fx_node` helper and inclusion of `kwargs`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166622 Approved by: https://github.com/aorenste --- torch/fx/interpreter.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 5ad1424c4e489..5b40e8a66147f 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs import inspect +import logging from contextlib import contextmanager from typing import Any, Optional, TYPE_CHECKING, Union import torch import torch.fx.traceback as fx_traceback -from torch._logging import trace_structured +from torch._logging import LazyString, trace_structured from torch.hub import tqdm from . 
import config @@ -21,10 +22,35 @@ if TYPE_CHECKING: from collections.abc import Iterator +log = logging.getLogger(__name__) __all__ = ["Interpreter", "Transformer"] +def _format_fx_node(n): + """ + Format a torch.fx.Node into a human-readable string for debug logging. + + Args: + n (torch.fx.Node): The FX node being executed. + + Returns: + str: A formatted string describing the node operation, including its + name, target, positional arguments, and keyword arguments. + """ + module_prefix = getattr(n.target, "__module__", "") + module_prefix = f"{module_prefix}." if module_prefix else "" + + # Handle positional and keyword arguments + args = ", ".join(map(str, n.args)) + kwargs = ", ".join(f"{k}={v}" for k, v in n.kwargs.items()) + joined = ", ".join(filter(None, [args, kwargs])) + + return ( + f"{n.name} = {module_prefix}{getattr(n.target, '__name__', n.target)}({joined})" + ) + + @compatibility(is_backward_compatible=True) class Interpreter: """ @@ -261,6 +287,7 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ + log.debug("run_node %s", LazyString(lambda: _format_fx_node(n))) with self._set_current_node(n): args, kwargs = self.fetch_args_kwargs_from_env(n) assert isinstance(args, tuple) From eea951758fcb71ed544bee9f83e67913dec26aaf Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 4 Nov 2025 16:49:31 -0800 Subject: [PATCH 112/651] [dynamo, 3.14] disable dynamo cpython tests in 3.14 (again) (#167000) The previous PR was not enough to prevent errors caused by cpython dynamo tests in 3.14 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167000 Approved by: https://github.com/mlazos, https://github.com/guilhermeleobas --- test/run_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/run_test.py b/test/run_test.py index aa6a6d04cde3e..764b20dc9adc2 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1687,7 +1687,7 @@ def get_selected_tests(options) -> list[str]: ] ) - if sys.version_info[:2] < (3, 13): + if sys.version_info[:2] < (3, 13) or sys.version_info[:2] >= (3, 14): # Skip tests for older Python versions as they may use syntax or features # not supported in those versions options.exclude.extend( From 91337ae3ffd5e3a5204c9e47aeaa2d093710a46c Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 6 Nov 2025 04:57:01 +0000 Subject: [PATCH 113/651] [audio hash update] update the pinned audio hash (#167031) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167031 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 966f6bcfc0d94..14144f3c11e2d 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2 +ad5816f0eee1c873df1b7d371c69f1f811a89387 From f7b7f40a6fed52a7190301b8dfebc528b349c8d4 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 17:09:12 -0800 Subject: [PATCH 114/651] [user-streams] Enable stream ops to work in eager (#167141) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167141 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_streams.py | 27 ++++++++++++++++++++------- torch/_dynamo/variables/streams.py | 30 ++++++++++++++---------------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index b9a3855f6ddbb..6b7ad5ce0ce96 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -7,6 +7,10 @@ import torch import torch._dynamo.test_case import torch._dynamo.testing +from torch._dynamo.graph_bytecode_inputs import ( + reset_user_object_tracking, + store_user_object_weakrefs, +) from torch._dynamo.testing import extract_graph, remove_trailing_space from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda @@ -441,13 +445,22 @@ def test_run_opcheck(self): from torch._dynamo.variables.streams import fork_stream, join_stream from torch.library import opcheck - sample_inputs = [ - (0, torch.device("cuda:0"), 1, torch.device("cuda:1")), - (2, torch.device("cuda:2"), 3, torch.device("cuda:1")), - ] - for args in sample_inputs: - opcheck(fork_stream, args) - opcheck(join_stream, args) + original_stream = torch.accelerator.current_stream() + try: + s0 = torch.Stream() + s1 = torch.Stream() + store_user_object_weakrefs(s0, s1) + + sample_inputs = [ + (0, 1), + (1, 0), + ] + for args in sample_inputs: + opcheck(fork_stream, args) + opcheck(join_stream, args) + finally: + torch.accelerator.set_stream(original_stream) + reset_user_object_tracking() if __name__ == "__main__": diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fb5dd775bd636..bb9552186da6d 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -10,6 +10,7 @@ from .. 
import graph_break_hints from ..bytecode_transformation import create_call_function from ..exc import TYPE_CHECKING, unimplemented_v2 +from ..graph_bytecode_inputs import get_external_object_by_index from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import FxTracebackAnnotateVariable @@ -29,40 +30,37 @@ @custom_op("streams::fork", mutates_args=()) def fork_stream( - from_index: int, - from_device: torch.device, + from_index: int, # kept to make stream transitions clearer to_index: int, - to_device: torch.device, ) -> None: - pass + stream = get_external_object_by_index(to_index) + assert isinstance(stream, torch.Stream), ( + f"fork_stream expects a stream object at index {to_index}" + ) + torch.accelerator.set_stream(stream) @fork_stream.register_fake def _( - from_index: int, - from_device: torch.device, + from_index: int, # kept to make stream transitions clearer to_index: int, - to_device: torch.device, ) -> None: pass @custom_op("streams::join", mutates_args=()) -def join_stream( - from_index: int, - from_device: torch.device, - to_index: int, - to_device: torch.device, -) -> None: - pass +def join_stream(from_index: int, to_index: int) -> None: + stream = get_external_object_by_index(to_index) + assert isinstance(stream, torch.Stream), ( + f"join_stream expects a stream object at index {to_index}" + ) + torch.accelerator.set_stream(stream) @join_stream.register_fake def _( from_index: int, - from_device: torch.device, to_index: int, - to_device: torch.device, ) -> None: pass From 46b3f913b351ccf3696932afa7f31c1b1b8bfee7 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 17:09:13 -0800 Subject: [PATCH 115/651] [user-streams] Add record/wait ops (#167151) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167151 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167141 --- test/dynamo/test_streams.py | 26 +++++++++++++- torch/_dynamo/variables/streams.py | 58 ++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 6b7ad5ce0ce96..c21ab934e5b45 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -441,7 +441,7 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): ) @requires_cuda - def test_run_opcheck(self): + def test_run_opcheck_fork_join(self): from torch._dynamo.variables.streams import fork_stream, join_stream from torch.library import opcheck @@ -462,6 +462,30 @@ def test_run_opcheck(self): torch.accelerator.set_stream(original_stream) reset_user_object_tracking() + @requires_cuda + def test_run_opcheck_wait_record(self): + from torch._dynamo.variables.streams import record_event, wait_event + from torch.library import opcheck + + original_stream = torch.accelerator.current_stream() + try: + s0 = torch.Stream() + s1 = torch.Stream() + e0 = torch.Event() + e1 = torch.Event() + store_user_object_weakrefs(s0, s1, e0, e1) + + sample_inputs = [ + (2, 0), + (3, 1), + ] + for args in sample_inputs: + opcheck(wait_event, args) + opcheck(record_event, args) + finally: + torch.accelerator.set_stream(original_stream) + reset_user_object_tracking() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index bb9552186da6d..98084cce28b27 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -28,16 +28,28 @@ Tensor = 
torch.Tensor +def _get_stream_by_index(index: int) -> torch.Stream: + stream = get_external_object_by_index(index) + assert isinstance(stream, torch.Stream), ( + f"Fork/join stream expected a stream object at index {index}" + ) + return stream + + +def _get_event_by_index(index: int) -> torch.Event: + event = get_external_object_by_index(index) + assert isinstance(event, torch.Event), ( + f"Record/wait event expected an event object at index {index}" + ) + return event + + @custom_op("streams::fork", mutates_args=()) def fork_stream( from_index: int, # kept to make stream transitions clearer to_index: int, ) -> None: - stream = get_external_object_by_index(to_index) - assert isinstance(stream, torch.Stream), ( - f"fork_stream expects a stream object at index {to_index}" - ) - torch.accelerator.set_stream(stream) + torch.accelerator.set_stream(_get_stream_by_index(to_index)) @fork_stream.register_fake @@ -50,11 +62,7 @@ def _( @custom_op("streams::join", mutates_args=()) def join_stream(from_index: int, to_index: int) -> None: - stream = get_external_object_by_index(to_index) - assert isinstance(stream, torch.Stream), ( - f"join_stream expects a stream object at index {to_index}" - ) - torch.accelerator.set_stream(stream) + torch.accelerator.set_stream(_get_stream_by_index(to_index)) @join_stream.register_fake @@ -65,6 +73,36 @@ def _( pass +@custom_op("streams::record_event", mutates_args=()) +def record_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.record_event(event) + + +@record_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +@custom_op("streams::wait_event", mutates_args=()) +def wait_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.wait_event(event) + + +@wait_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + class SymbolicStreamState: """Track the currently entered stream if any""" From 7b423c2d217452d7f65788dc3a9cb786f0b45769 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 17:09:13 -0800 Subject: [PATCH 116/651] [user-streams] Mark stream ops as side effectful (#167152) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167152 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167141, #167151 --- test/dynamo/test_streams.py | 16 ++++++++++++++++ torch/_dynamo/variables/streams.py | 14 +++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index c21ab934e5b45..1b81597977d77 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -486,6 +486,22 @@ def test_run_opcheck_wait_record(self): torch.accelerator.set_stream(original_stream) reset_user_object_tracking() + def test_is_marked_side_effectful(self): + self.assertIn( + torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions + ) + self.assertIn( + torch.ops.streams.join.default, torch.fx.node._side_effectful_functions + ) + self.assertIn( + torch.ops.streams.wait_event.default, + torch.fx.node._side_effectful_functions, + ) + self.assertIn( + torch.ops.streams.record_event.default, + torch.fx.node._side_effectful_functions, + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py 
index 98084cce28b27..65b4add4232f6 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -5,7 +5,7 @@ import torch from torch._dynamo.variables.dicts import ConstDictVariable from torch._dynamo.variables.lists import TupleVariable -from torch.fx import Proxy +from torch.fx import has_side_effect, Proxy from .. import graph_break_hints from ..bytecode_transformation import create_call_function @@ -60,6 +60,9 @@ def _( pass +has_side_effect(torch.ops.streams.fork.default) + + @custom_op("streams::join", mutates_args=()) def join_stream(from_index: int, to_index: int) -> None: torch.accelerator.set_stream(_get_stream_by_index(to_index)) @@ -73,6 +76,9 @@ def _( pass +has_side_effect(torch.ops.streams.join.default) + + @custom_op("streams::record_event", mutates_args=()) def record_event(event_index: int, stream_index: int) -> None: event = _get_event_by_index(event_index) @@ -88,6 +94,9 @@ def _( pass +has_side_effect(torch.ops.streams.record_event.default) + + @custom_op("streams::wait_event", mutates_args=()) def wait_event(event_index: int, stream_index: int) -> None: event = _get_event_by_index(event_index) @@ -103,6 +112,9 @@ def _( pass +has_side_effect(torch.ops.streams.wait_event.default) + + class SymbolicStreamState: """Track the currently entered stream if any""" From 8b2365094dbb531f9122b05fdf89553f6ccee03b Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Thu, 6 Nov 2025 05:59:05 +0000 Subject: [PATCH 117/651] Expose torch.compiler.config.force_disable_caches as a public API (#166699) Exposing this flag as some upstream frameworks (like vLLM) could benefit from knowing whether torch.compile caches are enabled or not to adjust their own caching behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166699 Approved by: https://github.com/oulgen, https://github.com/mlazos --- torch/compiler/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/compiler/config.py b/torch/compiler/config.py index e7578a57f2c0b..e507ddc18052e 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -35,6 +35,7 @@ "enable_cpp_symbolic_shape_guards", "wrap_top_frame", "reorderable_logging_functions", + "force_disable_caches", ] From 09d8953fb47de9a9209e409f6e72c7c8fa0ac0aa Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 5 Nov 2025 11:19:35 -0800 Subject: [PATCH 118/651] Update `tensorpipe` submodule (#167108) To pick a single change https://github.com/pytorch/tensorpipe/commit/2b4cd91092d335a697416b2a3cb398283246849d that should fix compilation errors with clang-21 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167108 Approved by: https://github.com/Skylion007 --- third_party/tensorpipe | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/tensorpipe b/third_party/tensorpipe index af0118d13e52f..2b4cd91092d33 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit af0118d13e52f5a08841464a768e01a0bf3e3075 +Subproject commit 2b4cd91092d335a697416b2a3cb398283246849d From 9eebda944df1bac3c1668e0bf041b85473c80aaa Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 5 Nov 2025 12:04:49 -0800 Subject: [PATCH 119/651] make narrow_tensor_symint DDE-free (#166379) https://github.com/pytorch/pytorch/issues/158081 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379 Approved by: https://github.com/Lucaskabela ghstack dependencies: #166361 --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- test/functorch/test_aotdispatch.py | 2 +- 
test/test_dynamic_shapes.py | 13 +++++++++++++ test/test_proxy_tensor.py | 1 - 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index daa8a86da253b..0079a530b3d0e 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1784,8 +1784,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - int64_t st = start.item(); - return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); + c10::SymInt st = start.item().toSymInt(); + return at::narrow_symint(self, dim, std::move(st), std::move(length)); } std:: diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index b0dd1ff8fa75d..6cae42d8929da 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8126,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - xfail("narrow"), + skip("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b63e0427c26c3..d3f9e415ff944 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4452,6 +4452,19 @@ def test_narrow_unbacked_start_cpp_wrapper(self): """Test narrow with unbacked start with cpp_wrapper""" self.test_narrow_unbacked_start() + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_narrow_with_tensor_start(self): + @torch.compile(backend="inductor", fullgraph=True) + def f(x, start, end): + return torch.narrow(x, 0, start, end) + + x = torch.tensor( + [False], device="cuda:0" if torch.cuda.is_available() else "cpu" + ) + start = torch.tensor(0) + res = f(x, start, 0) + self.assertEqual(res.shape, torch.Size([0])) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index b76895a0a91f3..0487995a2d1c5 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,7 +1987,6 @@ def f(t): } only_fake_tensor_failures = { - xfail('narrow'), xfail('tensor_split'), } From ed4aa449b60f0b595e376575362efe739eae00a1 Mon Sep 17 00:00:00 2001 From: tianrengao Date: Thu, 6 Nov 2025 06:59:06 +0000 Subject: [PATCH 120/651] CustomOp Inline Fusion (#165952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Inline Fusion Support for Custom Op Autotuning -------------------------------------------------- This PR extends PyTorch Inductor's custom op autotuning with inline fusion capabilities, enabling the winning decomposition to be inlined directly into the computation graph for fusion with surrounding operations. ### Usage ```python def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: """Matrix multiply with k-way decomposition.""" ... 
@torch.library.custom_op("my_lib::matmul_relu", mutates_args={}) def custom_matmul_relu_dk( a: torch.Tensor, b: torch.Tensor, k_splits: int ) -> torch.Tensor: return torch.relu(decompose_k_implementation(a, b, k_splits)) register_custom_op_autotuning( custom_op=custom_matmul_relu_dk, configs=[ CustomOpConfig(k_splits=2), CustomOpConfig(k_splits=4), CustomOpConfig(k_splits=8), CustomOpConfig(k_splits=32), CustomOpConfig(k_splits=64), ], name="decompose_k_autotuned", input_gen_fns={ "a": lambda fake: torch.randn_like(fake, device='cuda'), "b": lambda fake: torch.randn_like(fake, device='cuda'), } ) ``` ### How It Works Enable optimizations from Inductor by inlining the best decomposition, allowing fusion with surrounding elementwise operations and other graph-level optimizations. This provide potentially better performance and memory efficiency. During customop autotuning phase, we still benchmarks all CustomOpConfigs to find the fastest implementation. Then during inline fusion, inductor inline the decompositions into the main graph, converting the winning choice to individual ComputedBuffer IR nodes (fusable). At the end, Inductor automatically fuses inlined operations with surrounding elementwise ops (e.g., bias add, ReLU, scaling). Note that the winning choice must be a SubgraphChoiceCaller (decomposition-based) rather than an ExternKernelChoice for inlining to work. If the ExternKernelChoice is returned, no inline happens. Performance Results Benchmarked on matmul+relu workload with decompose-k fusion (H100 GPU, 15 test shapes): Screenshot 2025-11-04 at 12 43 11 AM Metric | Result -- | -- Average Speedup vs ATen | 1.28x Max Speedup vs ATen | 1.41x
The performance comparisons are detailed in the plots below: on most of the tested shapes, inline fusion performs better than both the ATen baseline and the current torch.compile. [Per-shape comparison plot omitted.] **Test**: `test_decompose_k_with_fusion` demonstrates decompose-k with inline fusion enabled.
--------------
### Integration into mm.py decomposeK via a config flag `enable_inline_subgraph_fusion=True` (deprecated to avoid breaking async compilation; already removed from this PR)
FP32 and FP16 results: [benchmark screenshots omitted]. The TCF column represents torch.compile fusion, which is close to the custom op decompose-k; the difference might be due to different candidate k values.
#### Usage
Note: this only takes effect when `benchmark_epilogue_fusion` is disabled, i.e., when not using `multi_template_buffer`. ```python # Define the matmul+relu function def matmul_relu(x, y): return torch.nn.functional.relu(torch.matmul(x, y)) # Compile with inline subgraph fusion enabled @torch.compile def compiled_matmul_relu(x, y): return matmul_relu(x, y) # Reset dynamo to ensure clean compilation torch._dynamo.reset() with config.patch( { "max_autotune": True, # CRITICAL: These two flags enable inline subgraph fusion "benchmark_epilogue_fusion": False, # Must be False for inline fusion! "enable_inline_subgraph_fusion": True, # Enable inline fusion } ): # Compile and run result = compiled_matmul_relu(a, b) torch.cuda.synchronize() ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165952 Approved by: https://github.com/PaulZhang12, https://github.com/eellison --- test/inductor/test_custom_op_autotune.py | 178 +++++++----------------- torch/_inductor/codegen/subgraph.py | 25 +++- torch/_inductor/kernel/custom_op.py | 43 ++++-- torch/_inductor/lowering.py | 45 ++++-- torch/_inductor/select_algorithm.py | 11 ++ 5 files changed, 151 insertions(+), 151 deletions(-) diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index adc46a0f390a4..c148c69468902 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -216,115 +216,6 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8): test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}" ) - @skipIfXpu - def test_mlp_custom_op_autotune(self): - """Test MLP autotuning with method parameter controlling different decomposition variants. - Validates parametric tuning where the same decomposition function uses different - algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights).
- """ - test_op_name = f"test_lib::mlp_{id(self)}" - - def mlp_variants( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ) -> torch.Tensor: - """MLP implementation with different computational approaches controlled by method parameter.""" - - if method == 0: - gate_proj = torch.matmul(input_tensor, gate_weight) - up_proj = torch.matmul(input_tensor, up_weight) - gated = torch.relu(gate_proj) * up_proj - return torch.matmul(gated, down_weight) - - elif method == 1: - batch_shape = input_tensor.shape[:-1] - hidden_dim = input_tensor.shape[-1] - output_dim = down_weight.shape[-1] - - input_2d = input_tensor.view(-1, hidden_dim) - - gate_proj = torch.mm(input_2d, gate_weight) - up_proj = torch.mm(input_2d, up_weight) - - gated = torch.relu(gate_proj) * up_proj - output_2d = torch.mm(gated, down_weight) - - return output_2d.view(*batch_shape, output_dim) - - @torch.library.custom_op(test_op_name, mutates_args=()) - def test_mlp_op( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ) -> torch.Tensor: - return mlp_variants( - input_tensor, gate_weight, up_weight, down_weight, method=method - ) - - @test_mlp_op.register_fake - def _( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ): - return torch.empty( - input_tensor.shape[:-1] + (down_weight.shape[-1],), - device=input_tensor.device, - dtype=input_tensor.dtype, - ) - - # Use explicit config with method parameter as tuning knob - register_custom_op_autotuning( - test_mlp_op, - configs=[ - CustomOpConfig(method=0), - CustomOpConfig(method=1), - ], - name="test_mlp_autotuned", - input_gen_fns={ - "input_tensor": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.1, - "gate_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - "up_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - "down_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - }, - ) - - # Create test inputs - input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs() - - # Test that all method variants produce numerically equivalent results - expected = mlp_variants( - input_tensor, gate_weight, up_weight, down_weight, method=0 - ) - - # Test autotuning - self._run_autotune_test( - test_mlp_op, - (input_tensor, gate_weight, up_weight, down_weight), - expected, - "MLP", - ) - def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values.""" # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] @@ -335,12 +226,12 @@ def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): @skipIfXpu def test_decompose_k_custom_op_autotune(self): - """Test decompose_k autotuning with parametric tuning for k_splits values. + """Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale). - Validates numerical parameter sweep where k_splits controls how the K dimension - is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]). + Validates that the custom op encapsulates the entire fused operation with parametric + tuning for k_splits values controlling how the K dimension is decomposed. 
""" - test_op_name = f"test_lib::decompose_k_{id(self)}" + test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}" def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 @@ -363,19 +254,23 @@ def decompose_k_implementation( return torch.sum(result, dim=0) # [m, n] @torch.library.custom_op(test_op_name, mutates_args=()) - def test_decompose_k_op( - a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 + def matmul_relu_epilogue_op( + a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: - """Matrix multiply with k-way decomposition - custom op using the decomposition.""" - return decompose_k_implementation(a, b, k_splits) - - @test_decompose_k_op.register_fake - def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): + """Matmul with decompose_k + bias + relu + scale (complete epilogue fusion).""" + matmul_result = decompose_k_implementation(a, b, k_splits) + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled + + @matmul_relu_epilogue_op.register_fake + def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype) - # Register autotuning with different k_splits values using decomposition function + # Register autotuning with different k_splits values register_custom_op_autotuning( - test_decompose_k_op, + matmul_relu_epilogue_op, configs=[ CustomOpConfig(k_splits=2), CustomOpConfig(k_splits=4), @@ -385,7 +280,7 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): CustomOpConfig(k_splits=64), CustomOpConfig(k_splits=128), ], - name="test_decompose_k_autotuned", + name="matmul_relu_epilogue_autotuned", input_gen_fns={ "a": lambda fake_tensor: torch.randn_like( fake_tensor, device=self.device @@ -395,12 +290,45 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): fake_tensor, device=self.device ) * 0.1, + "bias": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, }, ) + # Create test inputs a, b = self._create_decompose_k_inputs() - expected = a @ b - self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK") + bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1 + + # Compile the model using the custom op + @torch.compile + def test_model(a, b, bias): + return matmul_relu_epilogue_op(a, b, bias) + + torch._dynamo.reset() + + with config.patch( + max_autotune=True, + benchmark_fusion=True, + ): + compiled_result = test_model(a, b, bias) + + def reference_model(a, b, bias): + matmul_result = a @ b + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled + + expected = reference_model(a, b, bias) + + torch.testing.assert_close( + compiled_result, + expected, + rtol=2e-1, + atol=5e-1, + ) @skipIfXpu def test_multi_parameter_tuning(self): diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 4cc3f0ef282a8..1c1f0f1c9cd2c 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -24,6 +24,22 @@ log = logging.getLogger(__name__) +def inline_subgraph_to_ir_nodes( + gm: torch.fx.GraphModule, inputs: list[Any], name: str +) -> Any: + """Inline a subgraph by converting its FX operations to individual IR nodes. + + This converts a subgraph to multiple ComputedBuffer nodes (fusable), + enabling epilogue fusion with subsequent operations. 
+ + Returns: + TensorBox containing the final operation result as individual IR nodes + """ + from torch._inductor.lowering import process_subgraph_nodes + + return process_subgraph_nodes(gm, inputs) + + class SubgraphChoiceCaller(ir.ChoiceCaller): """ Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary @@ -261,7 +277,14 @@ def make_fx_graph( # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs from torch.fx.experimental.proxy_tensor import make_fx - return make_fx(functools.partial(decomp, **decomp_kwargs))(*args) + from ..decomposition import select_decomp_table + + decomposition_table = select_decomp_table() + + return make_fx( + functools.partial(decomp, **decomp_kwargs), + decomposition_table=decomposition_table, + )(*args) # Generate descriptive name for this variant variant_name = self._generate_variant_name(decomp, decomp_kwargs) diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index d35309c01d07c..23878f757cc5e 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union import torch +from torch._inductor import config from torch._inductor.codegen.subgraph import SubgraphTemplate from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox from torch._inductor.lowering import lowerings, validate_ir @@ -158,7 +159,6 @@ def _adapt_user_input_gen_fns( Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes. """ - from torch._inductor import config name_to_index = {name: i for i, name in enumerate(arg_names)} index_based_fns = {} @@ -238,6 +238,7 @@ def autotune_custom_op( This function generates multiple implementation choices for a custom operation and uses Inductor's autotuning system to select the best performing variant at runtime. + After selecting the best choice, applies inline fusion if the winning choice has a graph. 
Args: name: Unique identifier for the autotuning operation @@ -320,14 +321,34 @@ def autotune_custom_op( ) input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) - return autotune_select_algorithm( + # Run autotuning and get both result and winning choice + selected_result, winning_choice = autotune_select_algorithm( name=name, choices=choices, input_nodes=list(inputs), layout=choices[0].layout, input_gen_fns=input_gen_fns, + return_choice=True, ) + # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) + if winning_choice.gm is not None: + log.debug( + "Inlining winning choice: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes + + return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name) + + log.debug( + "Winning choice does not support inlining: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + return selected_result + def register_custom_op_autotuning( custom_op: torch._library.custom_ops.CustomOpDef, @@ -360,7 +381,7 @@ def my_attention(query, key, value, head_dim=32): "query": lambda fake: torch.randn_like(fake, device='cuda'), "key": lambda fake: torch.randn_like(fake, device='cuda'), "value": lambda fake: torch.randn_like(fake, device='cuda'), - } + }, ) """ from torch._library.custom_ops import CustomOpDef @@ -378,12 +399,12 @@ def my_attention(query, key, value, head_dim=32): raise TypeError(f"configs must be a list or tuple, got {type(configs)}") processed_configs = [] - for config in configs: - if isinstance(config, CustomOpConfig): - processed_configs.append(config) + for cfg in configs: + if isinstance(cfg, CustomOpConfig): + processed_configs.append(cfg) else: raise TypeError( - f"Each config must be a CustomOpConfig object, got {type(config)}" + f"Each config must be a CustomOpConfig object, got {type(cfg)}" ) if not processed_configs: @@ -402,14 +423,12 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: decompositions = [] non_tensor_args = [] - for config in processed_configs: - decomp = config.get_decomposition(default_impl=default_impl) + for cfg in processed_configs: + decomp = cfg.get_decomposition(default_impl=default_impl) decompositions.append(decomp) # Merge config params with runtime kwargs (runtime takes precedence) - merged_kwargs = _merge_config_and_runtime_kwargs( - config.params, runtime_kwargs - ) + merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs) non_tensor_args.append(merged_kwargs) result = autotune_custom_op( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index cc13f79909014..f6ad1028ca12d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -7307,6 +7307,35 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): return list(map(TensorBox.create, result)) # type: ignore[call-overload] +def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): + """Process nodes from a FX graph by executing them through V.graph. 
+ + This is a common pattern for executing a subgraph's nodes: + - Placeholder nodes are mapped to the provided args + - Output nodes return their result + - Other nodes are executed via V.graph.run_node + + """ + output = None + + for i, node in enumerate(graph_module.graph.nodes): + if node.op == "placeholder": + assert node not in V.graph.env + V.graph.env[node] = args[i] + continue + elif node.op == "output": + output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs) + else: + assert node not in V.graph.env + V.graph.env[node] = V.graph.run_node(node) + + if output is None: + raise RuntimeError("No output node found in graph") + + return output + + # Import the control_deps_op HOP for lowering from torch._inductor.fx_passes.control_dependencies import control_deps @@ -7334,21 +7363,11 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args): arg_offset = 2 # first two args (additional_deps, subgraph) assert len(args) + arg_offset == len(original_args) - output = None - operation_len = len(V.graph.operations) assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args) - for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): - if node.op == "placeholder": - assert node not in V.graph.env - V.graph.env[node] = args[i] - continue - elif node.op == "output": - args, kwargs = V.graph.fetch_args_kwargs_from_env(node) - output = torch.fx.Interpreter.output(V.graph, node, args, kwargs) - else: - assert node not in V.graph.env - V.graph.env[node] = V.graph.run_node(node) + + # Process subgraph nodes using the shared helper + output = process_subgraph_nodes(subgraph_fn.graph_module, list(args)) assert output is not None and additional_deps diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 41021b0fc8ed1..e1d36d54e844a 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2145,6 +2145,8 @@ def __init__( # There is no src hash for ExternKernelChoice in the traditional sense # so we indicate this by returning None self.src_hash = None + # By default GraphModule is None for extern kernels if not set + self.gm = None def to_callable(self): return getattr(extern_kernels, self.name) @@ -2317,6 +2319,7 @@ def __init__( self.choice = choice self.kwargs = kwargs or {} self.has_out_variant = has_out_variant + self.gm = choice.gm def __str__(self) -> str: return f"ExternKernelCaller({self.choice.call_name()})" @@ -2700,6 +2703,7 @@ def __call__( precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, best_config_future=None, + return_choice=False, # TODO: return_choice is temporary and will be refactored soon ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2973,18 +2977,25 @@ def get_timings(hint_override: Optional[int] = None): "Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s", node, ) + if return_choice: + return node, choice return node node = choices[0].output_node() + choice = choices[0] log.debug( "Autotuning returned empty timings, falling back to first choice: %s", node, ) + if return_choice: + return node, choice return node # if we got any timings at all, pick the best of those choice = min(timings, key=timings.__getitem__) node = choice.output_node() log.debug("Autotuning selected choice: %s", node) + if return_choice: + return node, choice return node def make_precompile_fn( From a51208c656fb3e9a8b091a4d181f9a9cda783c04 
Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 6 Nov 2025 08:02:53 +0000 Subject: [PATCH 121/651] Check cluster_dims attribute exists before access (#167187) Error in Helion CI's AMD job: https://github.com/pytorch/helion/actions/runs/19118581048/job/54633730633 ``` > (binary.metadata.num_ctas, *binary.metadata.cluster_dims) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ if hasattr(binary, "metadata") else () ) ), "function": get_first_attr(binary, "function", "cu_function"), "runner": get_first_attr(binary, "run", "c_wrapper"), "math": math_lib, "torch": torch_lib, "triton": triton_lib, } E torch._inductor.exc.InductorError: AttributeError: 'KernelMetadata' object has no attribute 'cluster_dims' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167187 Approved by: https://github.com/oulgen --- torch/_inductor/runtime/triton_heuristics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index cdecd50927024..2e0a0dba9092e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1851,6 +1851,8 @@ def make_launcher(self) -> LauncherType: else ( (binary.metadata.num_ctas, *binary.metadata.cluster_dims) if hasattr(binary, "metadata") + and hasattr(binary.metadata, "num_ctas") + and hasattr(binary.metadata, "cluster_dims") else () ) ), From c724f0097ddcf2a1dffb928ad18eafed6005595e Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 12:13:47 +0000 Subject: [PATCH 122/651] [2/N] Use `key in dict` for existence checks (#167174) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167174 Approved by: https://github.com/mlazos --- test/cpp/api/init_baseline.py | 2 +- test/cpp/api/optim_baseline.py | 2 +- .../checkpoint/test_hf_safetensor_e2e.py | 4 ++-- .../fsdp/test_fsdp_mixed_precision.py | 2 +- test/distributed/test_c10d_common.py | 6 ++---- test/distributed/test_local_tensor.py | 20 +++++++++---------- test/dynamo/test_subclasses.py | 2 +- test/functorch/xfail_suggester.py | 2 +- test/inductor/test_compiled_optimizers.py | 2 +- test/profiler/test_profiler.py | 8 ++------ .../core/test_quantized_module.py | 4 ++-- .../quantization/core/test_workflow_module.py | 2 +- .../pt2e/test_x86inductor_quantizer.py | 2 +- test/test_fx.py | 2 +- test/test_testing.py | 2 +- torch/ao/ns/fx/pattern_utils.py | 2 +- .../ao/pruning/sparsifier/base_sparsifier.py | 2 +- torch/ao/quantization/_equalize.py | 2 +- torch/ao/quantization/fx/_equalize.py | 2 +- .../quantization/fx/_model_report/detector.py | 2 +- torch/ao/quantization/fx/convert.py | 2 +- torch/ao/quantization/fx/prepare.py | 4 ++-- .../quantization/fx/qconfig_mapping_utils.py | 4 ++-- torch/ao/quantization/fx/utils.py | 2 +- .../quantization/pt2e/port_metadata_pass.py | 2 +- torch/ao/quantization/pt2e/prepare.py | 2 +- torch/ao/quantization/qconfig_mapping.py | 2 +- torch/ao/quantization/quantize_jit.py | 4 ++-- .../quantizer/x86_inductor_quantizer.py | 4 +--- torch/fx/experimental/unify_refinements.py | 6 +++--- torch/fx/graph.py | 2 +- torch/fx/passes/runtime_assert.py | 6 ++---- torch/fx/passes/splitter_base.py | 4 ++-- torch/fx/passes/utils/source_matcher_utils.py | 4 ++-- torch/jit/_recursive.py | 2 +- torch/jit/_script.py | 2 +- torch/nn/modules/module.py | 2 +- .../_internal/exporter/_dynamic_shapes.py | 2 +- 
torch/profiler/_memory_profiler.py | 4 ++-- torch/profiler/_utils.py | 6 +++--- torch/utils/_config_module.py | 2 +- torch/utils/collect_env.py | 4 ++-- torch/utils/data/datapipes/iter/callable.py | 2 +- torch/utils/data/datapipes/iter/grouping.py | 2 +- torch/utils/tensorboard/summary.py | 2 +- torch/utils/tensorboard/writer.py | 2 +- torchgen/gen_backend_stubs.py | 3 +-- 47 files changed, 72 insertions(+), 83 deletions(-) diff --git a/test/cpp/api/init_baseline.py b/test/cpp/api/init_baseline.py index 47b202e86311d..4042657b4d5c3 100644 --- a/test/cpp/api/init_baseline.py +++ b/test/cpp/api/init_baseline.py @@ -64,7 +64,7 @@ def run(initializer): def main(): initializer_parameter_map = {} - for initializer in INITIALIZERS.keys(): + for initializer in INITIALIZERS: sys.stderr.write(f"Evaluating {initializer} ...\n") initializer_parameter_map[initializer] = run(initializer) diff --git a/test/cpp/api/optim_baseline.py b/test/cpp/api/optim_baseline.py index 7e278d4e42086..e1a3c91b7128f 100644 --- a/test/cpp/api/optim_baseline.py +++ b/test/cpp/api/optim_baseline.py @@ -130,7 +130,7 @@ def main(): options = parser.parse_args() optimizer_parameter_map = {} - for optimizer in OPTIMIZERS.keys(): + for optimizer in OPTIMIZERS: sys.stderr.write(f"Evaluating {optimizer} ...\n") optimizer_parameter_map[optimizer] = run( optimizer, options.iterations, options.sample_every diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index f0316fde9f2c5..1aaaf645c58df 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -208,7 +208,7 @@ def test_quantized_checkpoint_loading(self) -> None: # Create model.safetensors.index.json with weight mapping weight_map = {} - for key in quantized_checkpoint.keys(): + for key in quantized_checkpoint: weight_map[key] = "model.safetensors" index_data = { @@ -245,7 +245,7 @@ def test_quantized_checkpoint_loading(self) -> None: sorted(original_tensors.keys()), sorted(state_dict_to_load.keys()) ) - for tensor_name in original_tensors.keys(): + for tensor_name in original_tensors: original = original_tensors[tensor_name] loaded = state_dict_to_load[tensor_name] diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index dee38d0403467..b4532a86e3052 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -498,7 +498,7 @@ def _run_test_mixed_precision_e2e( for name, tensor in state_dict.items(): # Parameters and buffers are checkpointed in their # original dtypes, which may be different. 
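(Editorial note: the hunks in this patch are all the same mechanical rewrite described in the commit message; `key in d` returns the same result as `key in d.keys()` for a plain `dict` while skipping the extra `dict_keys` view construction. A minimal illustrative sketch, not part of the patch:)

```
d = {"weight": 1, "bias": 2}

# Both forms test key membership and give identical results.
assert ("weight" in d) == ("weight" in d.keys())

# Iteration is likewise equivalent: iterating a dict yields its keys directly,
# which is why `for k in d.keys()` can become `for k in d`.
assert list(d) == list(d.keys())
```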
- if name in named_buffers.keys(): + if name in named_buffers: self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE) else: self.assertEqual( diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 985e2d5f151a2..2a1cb2b5580cb 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1189,9 +1189,7 @@ def _test_sequence_num_incremented(self, process_group, ranks): self.assertEqual(len(set(rank_to_seq_num.values())), 2) self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2]) expected_same = { - rank_to_seq_num[i] - for i in rank_to_seq_num.keys() - if i not in [0, 2] + rank_to_seq_num[i] for i in rank_to_seq_num if i not in [0, 2] } self.assertEqual(len(expected_same), 1) self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1]) @@ -1558,7 +1556,7 @@ def test_debug_level(self): } invalid_debug_modes = ["foo", 0, 1, -1] - for mode in mapping.keys(): + for mode in mapping: os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) dist.set_debug_level_from_env() set_debug_mode = dist.get_debug_level() diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index c58ddf0f82ba7..fa081243c2816 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -128,14 +128,14 @@ def test_basic_arithmetic_operations(self): self.assertEqual(len(result_add._local_tensors), 2) # Verify the operation was applied to each local tensor - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] + identical_local_tensors[rank] self.assertEqual(result_add._local_tensors[rank], expected) # Test multiplication result_mul = lt1 * 2.0 self.assertIsInstance(result_mul, LocalTensor) - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] * 2.0 self.assertEqual(result_mul._local_tensors[rank], expected) @@ -163,7 +163,7 @@ def test_mixed_operations_with_regular_tensors(self): result = lt + regular_tensor self.assertIsInstance(result, LocalTensor) - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] + regular_tensor self.assertEqual(result._local_tensors[rank], expected) @@ -212,14 +212,14 @@ def test_collectives_within_local_tensor_mode(self): dist.all_reduce(lt_sum, group=fake_pg) expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]]) - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) # Test broadcast within mode lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.broadcast(lt_broadcast, src=0, group=fake_pg) - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0]) # Test that regular operations still work @@ -293,21 +293,21 @@ def test_collective_reduction_operations(self): lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg) expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) # Test MAX reduction lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg) expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) 
# Max across all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_max._local_tensors[rank], expected_max) # Test MIN reduction lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg) expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_min._local_tensors[rank], expected_min) def test_all_reduce_collective(self): @@ -328,7 +328,7 @@ def test_all_reduce_collective(self): # Verify all ranks have the sum of all tensors (after adding 1 to each) expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]]) - for rank in different_tensors.keys(): + for rank in different_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) def test_broadcast_collective(self): @@ -348,7 +348,7 @@ def test_broadcast_collective(self): # Verify all ranks have rank 1's original tensor expected_broadcast = different_tensors[1] - for rank in different_tensors.keys(): + for rank in different_tensors: self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast) def test_all_gather_collective(self): diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 39a0dc628baec..5d31fa28880a6 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -4036,7 +4036,7 @@ def backend(gm, args): @parametrize( "nt_view_name", - [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"], + [k for k in VIEW_TEST_CASES if k != "subclass_dense_subclass_dense"], ) def test_inputs_to_compiled_fn_are_views(self, nt_view_name): self._input_view_test(nt_view_name) diff --git a/test/functorch/xfail_suggester.py b/test/functorch/xfail_suggester.py index cab6b018d5782..8efd8dfe398f2 100644 --- a/test/functorch/xfail_suggester.py +++ b/test/functorch/xfail_suggester.py @@ -73,7 +73,7 @@ def parse_namespace(base): "sparse_": "sparse", "special_": "special", } - for heading in mappings.keys(): + for heading in mappings: if base.startswith(heading): return mappings[heading], base[len(heading) :] return None, base diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index df93e7e1e4d61..ebee5149476b8 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -320,7 +320,7 @@ def build_opt_kwarg_db(): continue if has_tensor_lr: - for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys(): + for scheduler_cls in LR_SCHEDULER_TO_KWARGS: name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}" compiled_opt_db.append( ( diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 25fb60674e59e..fc128ba61907a 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -916,8 +916,7 @@ def judge(expected_event_count, prof): ) for key, count in expected_event_count.items(): self.assertTrue( - (key in actual_event_count.keys()) - and (count == actual_event_count[key]) + (key in actual_event_count) and (count == actual_event_count[key]) ) with _profile(use_kineto=kineto_available()) as prof: @@ -1406,10 +1405,7 @@ def test_profiler_fwd_bwd_link(self): s_ts_2 = flow_s_to_ts[2] f_ts_2 = flow_f_to_ts[2] self.assertTrue( - all( - ts in ts_to_name.keys() - for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2] - ) + all(ts in ts_to_name for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]) ) self.assertTrue( ts_to_name[s_ts_1] 
== "aten::binary_cross_entropy_with_logits" diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index b2b2b402327ad..f2cdbfd2d6316 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -1840,7 +1840,7 @@ def test_cell_api(self, dtype): 'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic, 'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic} - for rnn_type in cell_dict.keys(): + for rnn_type in cell_dict: if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")): # fp16 dynamic quant is not supported for qnnpack or onednn kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype} @@ -1903,7 +1903,7 @@ def test_rnn_cell(self): 'RNNTanh': nnqr.RNNCell, 'RNNReLU': nnqr.RNNCell} - for rnn_type in cell_dict.keys(): + for rnn_type in cell_dict: kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias} if rnn_type == 'RNNReLU': kwargs['nonlinearity'] = "relu" diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 9ea8d38828a63..93993fe33a49c 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -650,7 +650,7 @@ def test_record_observer(self): observer_dict = {} _get_observer_dict(model, observer_dict) - self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(), + self.assertTrue('fc1.module.activation_post_process' in observer_dict, 'observer is not recorded in the dict') self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()), 2 * len(self.calib_data)) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index dfd591cb9419c..5b9aa34158b5e 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2016,7 +2016,7 @@ def test_qat_conv2d_unary(self): } with override_quantized_engine("x86"): - for unary_op in unary_map.keys(): + for unary_op in unary_map: m = TestHelperModules.Conv2dUnaryModule( unary_map[unary_op][0], with_bn=True ) diff --git a/test/test_fx.py b/test/test_fx.py index f728187fd85f5..3ad21e64c8ce2 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4746,7 +4746,7 @@ def check_symbols_have_bc_designation(m, seen): check_symbols_have_bc_designation(torch.fx.passes, set()) non_back_compat_strs = [ - torch.typename(obj) for obj in non_back_compat_objects.keys() + torch.typename(obj) for obj in non_back_compat_objects ] # Only want objects in torch.fx non_back_compat_strs = [ diff --git a/test/test_testing.py b/test/test_testing.py index c660eb83b8042..09887be17c47a 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -510,7 +510,7 @@ def test_trivial_passing_test(self, device): # Test without setting env var should run everything. 
env = dict(os.environ) for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]: - if k in env.keys(): + if k in env: del env[k] _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii')) diff --git a/torch/ao/ns/fx/pattern_utils.py b/torch/ao/ns/fx/pattern_utils.py index c4d231e713b20..d10fdd39da908 100644 --- a/torch/ao/ns/fx/pattern_utils.py +++ b/torch/ao/ns/fx/pattern_utils.py @@ -72,7 +72,7 @@ def get_reversed_fusions() -> list[tuple[NSFusionType, int]]: all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) default_base_op_idx = 0 - for quant_pattern in all_quant_patterns.keys(): + for quant_pattern in all_quant_patterns: # TODO: this is a temporary hack to flatten the patterns from quantization so # that it works with the ns matcher function, maybe we should use `_is_match` # in torch.ao.quantization.fx.match_utils to match the patterns diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 14764c77cc604..59f6a46fe1350 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -196,7 +196,7 @@ def prepare(self, model, config): # check that whatever was put into local_args agrees with what was obtained # from tensor_fqn - for key in info_from_tensor_fqn.keys(): + for key in info_from_tensor_fqn: if key in local_args: if not ( info_from_tensor_fqn[key] == local_args[key] diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index a78dd307fc6d6..e4ff327f285aa 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -270,7 +270,7 @@ def converged(curr_modules, prev_modules, threshold=1e-4): summed_norms = torch.tensor(0.0) if None in prev_modules.values(): return False - for name in curr_modules.keys(): + for name in curr_modules: curr_weight = get_module_weight(curr_modules[name]) prev_weight = get_module_weight(prev_modules[name]) diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index b8809c1c60871..6c8c32b992ed4 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -350,7 +350,7 @@ def get_op_node_and_weight_eq_obs( # Find the op node that comes directly after the input equalization observer op_node = None - for user in input_eq_obs_node.users.keys(): + for user in input_eq_obs_node.users: if node_supports_equalization(user, modules): op_node = user break diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 993a6c41f176f..0a48bbbaaee90 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -743,7 +743,7 @@ def generate_detector_report( # Populates the string based report with the information from module_dynamic_static_info # Compiles the complete report by appending relevant formatted strings - for module_fqn in module_dynamic_static_info.keys(): + for module_fqn in module_dynamic_static_info: # there is at least 1 module for suggestion modules_added = True module_info = module_dynamic_static_info[module_fqn] diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 08ae102f69f41..06936e5327bce 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -683,7 
+683,7 @@ def _maybe_get_observer_for_node( If the node is observed, return the observer instance. Otherwise, return None. """ - for maybe_obs_node in node.users.keys(): + for maybe_obs_node in node.users: if maybe_obs_node.op == "call_module": maybe_obs = modules[str(maybe_obs_node.target)] if _is_activation_post_process(maybe_obs): diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 0c05e6499901d..8351dbedd07d7 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -950,7 +950,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # we should remove this # removing this means we insert one observer for each use, even if they # have the same dtype, we can have an extra pass that removes the extra observers - for maybe_obs_node in arg.users.keys(): + for maybe_obs_node in arg.users: if maybe_obs_node.op == "call_module": maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] if ( @@ -1440,7 +1440,7 @@ def _maybe_make_input_output_share_observers( setattr(named_modules[parent_name], name, obs_mod_to_use) # set the output observer node to use that module - for output_obs_node in node.users.keys(): + for output_obs_node in node.users: if not _is_activation_post_process_node(output_obs_node, named_modules): raise AssertionError( "output_obs_node must be an activation post process node" diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 74f90505ea2af..951ca66703f47 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -206,7 +206,7 @@ def _check_is_valid_config_dict( `config_dict`: dictionary whose keys we want to check """ - for k in config_dict.keys(): + for k in config_dict: if k not in allowed_keys: raise ValueError( "Expected " @@ -250,7 +250,7 @@ def _compare_prepare_convert_qconfig_mappings( _MODULE_NAME_REGEX_DICT_KEY, ] for i in range(len(prepare_dicts)): - for name in prepare_dicts[i].keys(): + for name in prepare_dicts[i]: if name not in convert_dicts[i]: raise AssertionError( f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare" diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 3e2afaaa1d9f3..9f76f2a328df1 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -442,7 +442,7 @@ def maybe_get_next_module( target_functional_type: Functional type that we want to check """ - for user in node.users.keys(): + for user in node.users: if ( user.op == "call_module" and target_module_type is not None diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index aab4c435c872f..8e768592826e4 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -66,7 +66,7 @@ def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: continue if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: return n - for k in n.users.keys(): + for k in n.users: queue.append(k) return None diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 9f7767101aba6..c15e7878eb2b7 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -391,7 +391,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # instead of inserting new observers we will have: 
# conv1 -> obs1 -> existing_obs -> conv2 # \ -> conv3 - for maybe_obs_node in arg.users.keys(): + for maybe_obs_node in arg.users: if not _is_activation_post_process_node(maybe_obs_node, named_modules): continue maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 10111d4ab8a2a..2bfce5d858cc4 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -187,7 +187,7 @@ def _get_default_qconfig_mapping_with_default_qconfig( else: qconfig_mapping = get_default_qconfig_mapping(backend) qconfig_mapping.set_global(default_qconfig) - for pattern in qconfig_mapping.object_type_qconfigs.keys(): + for pattern in qconfig_mapping.object_type_qconfigs: if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: qconfig_mapping.set_object_type(pattern, default_qconfig) return qconfig_mapping diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 79f8db1a792fc..ec4caab1edcd0 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -68,7 +68,7 @@ def fuse_conv_bn_jit(model, inplace=False): def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): _check_is_script_module(model) _check_forward_method(model) - if not all(isinstance(x, str) for x in qconfig_dict.keys()): + if not all(isinstance(x, str) for x in qconfig_dict): raise ValueError("qconfig_dict should only contain names(str) as keys.") scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) model = fuse_conv_bn_jit(model, inplace) @@ -90,7 +90,7 @@ def _prepare_ondevice_jit( quant_type=QuantType.STATIC, ): _check_is_script_module(model) - if not all(isinstance(x, str) for x in qconfig_dict.keys()): + if not all(isinstance(x, str) for x in qconfig_dict): raise ValueError("qconfig_dict should only contain names(str) as keys.") scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) method_graph = model._c._get_method(method_name).graph diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index b10163d4b1e50..816f48fd6267a 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -1361,9 +1361,7 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): elif ( node.target is torch.ops.aten.flatten.using_ints and len(node.users) > 0 - and not any( - user.target in quantizable_ops for user in node.users.keys() - ) + and not any(user.target in quantizable_ops for user in node.users) ): # Recipe of flatten: check if any users of flatten node are quantizable ops or not return diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index bab662e0655a2..efafb146179a6 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -61,7 +61,7 @@ def substitute_solution_one_type(mapping, t): Apply the most general unifier to a type """ if isinstance(t, Var): - if t in mapping.keys(): + if t in mapping: return mapping[t] else: return t @@ -69,7 +69,7 @@ def substitute_solution_one_type(mapping, t): elif isinstance(t, TensorType): new_type = [] for typ in t.__args__: - if typ in mapping.keys(): + if typ in mapping: new_type.append(mapping[typ]) else: new_type.append(typ) @@ -102,7 +102,7 @@ def substitute_all_types(graph, mapping): flag = False for k 
in mapping: old_mapping_val = mapping[k] - if mapping[k] in mapping.keys(): + if mapping[k] in mapping: new_key = mapping[k] mapping[k] = mapping[new_key] if old_mapping_val != mapping[k]: diff --git a/torch/fx/graph.py b/torch/fx/graph.py index d924eac24d3c2..d8cfa42472b49 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1145,7 +1145,7 @@ def find_nodes(self, *, op: str, target: Optional["Target"] = None): return [*self.table[(op, None)].keys()] # op is call_method, get_attr, call_module - return [node for node in self.table[(op, None)].keys() if node.target == target] + return [node for node in self.table[(op, None)] if node.target == target] @compatibility(is_backward_compatible=True) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 58aa801062824..1d3b0b33e7bce 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -373,11 +373,9 @@ def has_new_untracked_symbols(): shape_env, node.meta.get("unbacked_bindings", {}) ) - assert resolved_unbacked_bindings is not None - def has_new_unbacked_bindings(): - # pyrefly: ignore [missing-attribute] - for key in resolved_unbacked_bindings.keys(): + assert resolved_unbacked_bindings is not None + for key in resolved_unbacked_bindings: if key not in expr_to_proxy: return True return False diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 6cf708a619069..8d90f9d55cfdb 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -204,7 +204,7 @@ def to_dict(self): Create dict dump on all events. """ ret: dict[str, list[str]] = {} - for name in self.node_events.keys(): + for name in self.node_events: ret[name] = [] for idx in self.node_events.get(name, []): event = self.events[idx] @@ -218,7 +218,7 @@ def print_all(self, writer=None): """ if not writer: writer = self.writer - for name in self.node_events.keys(): + for name in self.node_events: writer(f"Node: {name}:") self.print_node(name, recursive=False, tab=" ", writer=writer) diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 043c65e6b77d2..82259b8a36ab7 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -113,7 +113,7 @@ def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: # get_attr nodes won't be output nodes continue - for user in node.users.keys(): + for user in node.users: if user not in nodes: output_nodes.add(node) @@ -157,7 +157,7 @@ def check_subgraphs_connected( """ for node in reversed(subgraph1.nodes): - for user in node.users.keys(): + for user in node.users: if user in subgraph2.nodes: return True return False diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 3a2b3ef8b6001..343871b1f94a2 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -574,7 +574,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): def init_fn(script_module): # Initialize the ScriptModule: # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. 
- for name in concrete_type.get_attributes().keys(): + for name in concrete_type.get_attributes(): orig_value = getattr(nn_module, name) orig_value = ( orig_value.value diff --git a/torch/jit/_script.py b/torch/jit/_script.py index a8bb3ba9bd8f5..46e6f47534108 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -856,7 +856,7 @@ def __setattr__(self, attr, value): self._c.setattr(attr, value) elif ( hasattr(self, "_concrete_type") - and attr in self._concrete_type.get_constants().keys() + and attr in self._concrete_type.get_constants() ): # TODO: we don't have _concrete_type set after load(), and in general we lose constant information. # We should encode constants as class type attributes (or something) so it persists across save/load. diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 10a240e3a9cf7..33bf35a1d852a 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -2521,7 +2521,7 @@ def _load_from_state_dict( unexpected_keys.append(extra_state_key) if strict: - for key in state_dict.keys(): + for key in state_dict: if key.startswith(prefix) and key != extra_state_key: input_name = key[len(prefix) :].split(".", 1) # Must be Module if it have attributes diff --git a/torch/onnx/_internal/exporter/_dynamic_shapes.py b/torch/onnx/_internal/exporter/_dynamic_shapes.py index e128ecf74e9e4..888db138736fb 100644 --- a/torch/onnx/_internal/exporter/_dynamic_shapes.py +++ b/torch/onnx/_internal/exporter/_dynamic_shapes.py @@ -67,7 +67,7 @@ def from_dynamic_axes_to_dynamic_shapes( # output names are not needed for dynamic_shapes continue if isinstance(axes, dict): - if any(not isinstance(k, int) for k in axes.keys()): + if any(not isinstance(k, int) for k in axes): raise ValueError( "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]." 
) diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 3f21ce81171d7..dfa83f7467cd6 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -711,7 +711,7 @@ def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]: events: list[tuple[int, Action, TensorAndID]] = [ (-1, Action.PREEXISTING, (key, version)) - for key, version in snapshot.keys() + for key, version in snapshot if (key, True) not in allocation_times and version == 0 ] @@ -938,7 +938,7 @@ def _set_parameters_using_data_flow(self) -> None: parameter_keys = {key.id for key, _ in candidate_parameters} parameter_keys &= self._any_version_depends_on_gradient() - for key, _ in snapshot.keys(): + for key, _ in snapshot: if key.id in parameter_keys: self._categories.set_by_id(key, Category.PARAMETER) diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 47df87ce1678d..2c575b06509e5 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -103,7 +103,7 @@ def __init__(self, prof: profile) -> None: self.metrics: dict[EventKey, EventMetrics] = {} self.compute_self_time() self.event_keys = sorted( - (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns + self.metrics.keys(), key=lambda x: x.event.start_time_ns ) self.events = [e.event for e in self.event_keys] self.cuda_events: list[_KinetoEvent] = [] @@ -265,7 +265,7 @@ def compute_idle_time(self) -> None: idle_intervals.append(Interval(idle_start, data_point.start)) idle = False - event_list = [e.event for e in self.metrics.keys()] + event_list = [e.event for e in self.metrics] for event in event_list: self.metrics[EventKey(event)].idle_time_ns = EventKey( event @@ -316,7 +316,7 @@ def rank_events(self, length): # Filter out events that are not in the decrease interval event_list = [ event - for event in self.metrics.keys() + for event in self.metrics if event.intervals_overlap(decrease_interval) ] if event_list: diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 12ba497efd79c..f302a10b8338e 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -692,7 +692,7 @@ def __enter__(self) -> None: raise AssertionError( "prior should be empty when entering ConfigPatch" ) - for key in self.changes.keys(): + for key in self.changes: # KeyError on invalid entry prior[key] = config.__getattr__(key) for k, v in self.changes.items(): diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 3b8b62cfde6d4..a643314f3b9cd 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -803,14 +803,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): def replace_nones(dct, replacement="Could not collect"): - for key in dct.keys(): + for key in dct: if dct[key] is not None: continue dct[key] = replacement return dct def replace_bools(dct, true="Yes", false="No"): - for key in dct.keys(): + for key in dct: if dct[key] is True: dct[key] = true elif dct[key] is False: diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 1ce1c9c07196c..2e3bb18c80bb7 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -149,7 +149,7 @@ def _collate_helper(conversion, item): tuple_names: list = [] tuple_values: list = [] - for name in conversion.keys(): + for name in conversion: if name not in columns_name: raise RuntimeError("Conversion keys mismatch") diff --git 
a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 865feb9953e35..a289bdb5e0949 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -234,7 +234,7 @@ def _remove_biggest_key(self): biggest_key = None biggest_size = 0 result_to_yield = None - for findkey in self.buffer_elements.keys(): + for findkey in self.buffer_elements: if len(self.buffer_elements[findkey]) > biggest_size: biggest_size = len(self.buffer_elements[findkey]) biggest_key = findkey diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index f36382cb42e16..1b6a2bb9bb66f 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -334,7 +334,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): # pyrefly: ignore [missing-attribute] ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) - mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()] + mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict] exp = Experiment(hparam_infos=hps, metric_infos=mts) diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 4fab33dc7ff09..0f533ae5b0f57 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -424,7 +424,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag if self.all_writers is None: raise AssertionError("self.all_writers is None") - if fw_tag in self.all_writers.keys(): + if fw_tag in self.all_writers: fw = self.all_writers[fw_tag] else: fw = FileWriter( diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 07097010f8f28..c9f1b660f02c5 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -287,8 +287,7 @@ def error_on_missing_kernels( expected_backend_native_funcs: list[NativeFunction] = [ f for f in native_functions - if f.func.name in expected_backend_op_names.keys() - and f.func.name not in full_codegen + if f.func.name in expected_backend_op_names and f.func.name not in full_codegen ] expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict( list From 80ec2ab78e43a2f637bf5ceae753061c315eaaa5 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 12:19:56 +0000 Subject: [PATCH 123/651] [8/N] Fix unused loop variables in tests (#166921) This PR continues to fix or remove unused loop variables in tests. 
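A minimal sketch of the pattern being cleaned up (illustrative only, not taken from the diff): either the unused loop variable is dropped from the iteration, or it is renamed with a leading underscore so linters stop flagging it.

```
values = {"int": [10, 3], "float": [12.34, 2.78]}

# Before: `i` from enumerate() is never used inside the loop body.
for i, (x, y) in enumerate(zip(values["int"], values["float"])):
    print(x + y)

# After: drop enumerate() entirely (or rename an unavoidable unused
# variable to `_i`) so the loop only binds what it actually uses.
for x, y in zip(values["int"], values["float"]):
    print(x + y)
```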
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166921 Approved by: https://github.com/mlazos --- test/quantization/core/test_workflow_ops.py | 4 ++-- test/quantization/fx/test_quantize_fx.py | 2 +- test/test_datapipe.py | 2 +- test/test_jit_fuser_te.py | 15 ++++----------- torch/_export/serde/serialize.py | 2 +- 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index f69852760e8a0..78e7799c864b1 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -368,8 +368,8 @@ def _test_forward_per_tensor_cachemask_impl(self, device): float_types = (torch.float32, torch.float16, torch.float64, torch.bfloat16) torch_types = (torch.qint8, torch.quint8) Xs = (torch.randn(4, 8, device=device), torch.randn(4, 16, device=device)[:, ::2]) - tensor_qparam = (True, False) - for float_type, torch_type, X, tensor_qparams in itertools.product(float_types, torch_types, Xs, tensor_qparam): + tensor_qparams = (True, False) + for float_type, torch_type, X, tensor_qparam in itertools.product(float_types, torch_types, Xs, tensor_qparams): # pick the scale + zp so that some values get clipped X = X.to(float_type) obs = torch.ao.quantization.MinMaxObserver(torch_type) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index b33afc7a80363..9c0526fde6987 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -8807,7 +8807,7 @@ def forward(self, indices, offsets): # check it works in None and static qconfig for qconfig in [None, default_qconfig]: - qconfig_dict = {"": default_qconfig} + qconfig_dict = {"": qconfig} m = M().eval() m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 5a535e7e00663..cab86e42734f1 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -1136,7 +1136,7 @@ def test_fork_iterdatapipe(self): ) break with warnings.catch_warnings(record=True) as wa: - for i, (n1, n2) in enumerate(zip(dp1, dp2)): + for n1, n2 in zip(dp1, dp2): output1.append(n1) output2.append(n2) self.assertEqual(len(wa), 1) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3018be817d9b..8622d428cb4fe 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1682,11 +1682,8 @@ def apply(fn): ] dtypes = ["int", "float", "bool"] values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} - devices = self.devices - for dtype_x, dtype_y, op, device in product( - dtypes, dtypes, binary_ops, devices - ): - code = ir_template.format(**locals()) + for dtype_x, dtype_y, op in product(dtypes, dtypes, binary_ops): + code = ir_template.format(dtype_x=dtype_x, dtype_y=dtype_y, op=op) # Interpret the graph try: @@ -1701,9 +1698,7 @@ def apply(fn): try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: - raise RuntimeError( - " ".join(["Compilation failed:", device, str(code)]) - ) from e + raise RuntimeError(" ".join(["Compilation failed:", str(code)])) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): @@ -1713,9 +1708,7 @@ def apply(fn): self.assertEqual(ref, res) except Exception as e: raise RuntimeError( - " ".join( - ["Failed at runtime:", device, str(x), str(y), str(code)] - ) + " ".join(["Failed at runtime:", str(x), str(y), str(code)]) ) 
from e def test_matmul(self): diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 9c4629f13337d..e328422ec5e66 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -617,7 +617,7 @@ def get_triton_kernel_and_cache_entry(node: torch.fx.Node): return actual_kernel, matching_entries[0][1] if is_autotuner: - for sig_key, cache_entry in matching_entries: + for _sig_key, cache_entry in matching_entries: entry_metadata = cache_entry.metadata # pyrefly: ignore [missing-attribute] for config in kernel.configs: From b2d72a4008fa13612adc34c246e8e24c2185300e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Nov 2025 13:26:04 +0000 Subject: [PATCH 124/651] Revert "Don't hardcode double argument for reduction base (#166951)" This reverts commit a74fe75c450277eb88a95c764e8b0a664a550a86. Reverted https://github.com/pytorch/pytorch/pull/166951 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/166951#issuecomment-3497253260)) --- aten/src/ATen/native/cpu/Reduce.h | 4 ++-- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 22 +++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index ab9051ca8d2a2..6c9efbb0f6e7f 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { }); } -template -void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast(0)) { +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { using traits = binary_function_traits; static_assert( all_same< diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 053db7b4eda00..3bad49a32d98c 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) { } } +template +struct MinValuesOps: public at::native::MinOps { + using arg_t = typename MinOps::arg_t; + static scalar_t project(arg_t arg) { + return arg.first; + } +}; + void min_values_kernel_impl(TensorIterator& iter) { + // This case is special because of Vectorized does not + // handle upper_bound(). 
+ // See: https://github.com/pytorch/pytorch/issues/43254 + if (iter.dtype() == kLong || iter.dtype() == kUInt64) { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { + binary_kernel_reduce( + iter, + MinValuesOps{}, + std::pair(upper_bound(), -1)); + }), kLong, kUInt64); + return; + } AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, - upper_bound()); + static_cast(upper_bound())); }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } From 2005b5f54842427839edb02a6782ea92a696560a Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 5 Nov 2025 07:22:07 -0800 Subject: [PATCH 125/651] [inductor] Use runtime estimations in iterative reorder collectives pass (#167080) Split of https://github.com/pytorch/pytorch/pull/162469 to be under 2K reorder iterative part Pull Request resolved: https://github.com/pytorch/pytorch/pull/167080 Approved by: https://github.com/eellison --- test/distributed/test_inductor_collectives.py | 6 +- torch/_inductor/comms.py | 1123 ++++++++++++----- torch/_inductor/config.py | 25 +- torch/_inductor/config_comms.py | 47 + torch/_inductor/utils.py | 7 +- 5 files changed, 851 insertions(+), 357 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index ac3103e09341d..daa9bf2e309ff 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1985,6 +1985,7 @@ def _reorder_communication_preserving_peak_memory( "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ + _reorder_communication_preserving_peak_memory, sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], @@ -2046,11 +2047,6 @@ def _reorder_communication_preserving_peak_memory( assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) self.assertEqual(len(node_stats), 4) - it = iter(node_stats.values()) - node_stat0 = next(it) - self.assertTrue(node_stat0.limiting_factor == "None") - node_stat1 = next(it) - self.assertTrue("collective ordering" in node_stat1.limiting_factor) @skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index 6c7c9a8bd7dab..a4a4cac8e3ec2 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -18,7 +18,7 @@ from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet -from . import config, ir +from . 
import config, config_comms, ir from .dependencies import WeakDep @@ -155,12 +155,15 @@ class ReorderInfo: Debug info describing how an individual snode was reordered """ - initial_exposed: float = -1 - final_exposed: float = -1 limiting_factor: str = "None" moves: int = 0 grouped: int = 0 grouped_info: str = "" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" @property def improvement(self): @@ -193,7 +196,7 @@ def contains_gemm_like(snode: BaseSchedulerNode) -> bool: return is_gemm_like(snode.node) -def _temp_group_visit_leaves(snode, fn): +def _temp_group_visit_leaves(snode: BaseSchedulerNode, fn): from torch._inductor.scheduler import GroupedSchedulerNode if isinstance(snode, GroupedSchedulerNode) and snode.temp_grouping: @@ -203,6 +206,126 @@ def _temp_group_visit_leaves(snode, fn): fn(snode) +def wait_exposed_communication_time( + snodes_to_wait: list[BaseSchedulerNode], runtimes: dict[BaseSchedulerNode, float] +) -> tuple[float, float, str]: + """ + Calculate exposed communication time for a wait operation by finding its corresponding + collective and accumulating overlapping compute time between them. + + The Wait node must be the last in snodes_to_wait. + Compute time between corresponding Collective and Wait is accumulated. + If there is another pair of Collective and Wait inside, + Only compute before first such Wait' is considered as overlapping. + + Multiple process groups are not modeled so far. + """ + wait_snode = snodes_to_wait[-1] + assert is_wait(wait_snode.node) + assert len(snodes_to_wait) > 1 + idx = len(snodes_to_wait) - 2 + comm_time = 0.0 + comp_time = 0.0 + overlap_info = "" + waits_found = [] + for i in range(idx, -1, -1): + c = snodes_to_wait[i] + if contains_wait(c): + waits_found.append(c) + if contains_collective(c): + if is_corresponding_collective_wait(c, wait_snode): + comm_time = runtimes[c] + overlap_info += f"->C[{c.get_name()}]" + break + + if not contains_async_collective(c): + # Sync Collective + comp_time = 0.0 + continue + else: + for w in waits_found: + if is_corresponding_collective_wait(c, w): + # Similar to Sync Collective + # If after our Collective exist another Collective-Wait, + # All compute after it will not be overlapping + comp_time = 0.0 + continue + + comp_time_before = comp_time + + def accumulate_time(_snode: BaseSchedulerNode) -> None: + nonlocal comp_time + comp_time += runtimes[_snode] + + _temp_group_visit_leaves(c, accumulate_time) + comp_time_after = comp_time + overlap_info += f"+{c.get_name()}[{comp_time_after - comp_time_before}]" + + return comm_time, comp_time, overlap_info + + +def coll_exposed_communication_time( + snodes: list[BaseSchedulerNode], + runtimes: dict[BaseSchedulerNode, float], +) -> tuple[float, float, str]: + """ + Calculate exposed communication time for a collective operation by finding its corresponding + wait and accumulating compute time that can overlap with communication. + + The Collective node must be the first in snodes. + Compute time between corresponding Collective and Wait is accumulated. + If there is another pair of Collective and Wait inside, + Only compute before first such Wait' is considered as overlapping. + + Multiple process groups are not modeled so far. 
+ """ + collective_snode = snodes[0] + comm_time = runtimes[collective_snode] + comp_time = 0.0 + collective_outs: OrderedSet[str] = OrderedSet( + o.get_name() for o in collective_snode.get_outputs() + ) + overlap_info = "" + collectives_found: list[BaseSchedulerNode] = [] + for snode in snodes[1:]: + # We may have some ops without Wait, + # e.g. DTensor torch.ops._dtensor.shard_dim_alltoall + unmet_deps = OrderedSet( + d.name for d in snode.unmet_dependencies if not _is_fake_dep(d) + ) + + if unmet_deps & collective_outs: + overlap_info += f"->W[{snode.get_name()}]" + break + + if contains_collective(snode): + if not contains_async_collective(snode): + break + else: + collectives_found.append(snode) + continue + if contains_wait(snode): + has_wait_for_collectives_found = False + for coll in collectives_found: + if is_corresponding_collective_wait(collective_snode, snode): + has_wait_for_collectives_found = True + break + if has_wait_for_collectives_found: + # Any compute after not overlapping original Collective + break + + comp_time_before = comp_time + + def accumulate_time(_snode: BaseSchedulerNode) -> None: + nonlocal comp_time + comp_time += runtimes[_snode] + + _temp_group_visit_leaves(snode, accumulate_time) + comp_time_after = comp_time + overlap_info += f"+{snode.get_name()}[{comp_time_after - comp_time_before}]" + return comm_time, comp_time, overlap_info + + def _group_name(snode, with_bufs=False) -> str: ret = "" for n in snode.snodes: @@ -258,369 +381,361 @@ def _initialize_double_linked_list( return _prev, _next, _head -def _reorder_communication_preserving_peak_memory_internal( - snodes: list[BaseSchedulerNode], -) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: +def is_corresponding_collective_wait( + collective_snode: BaseSchedulerNode, wait_snode: BaseSchedulerNode +) -> bool: """ - Internal testing helper that also returns debug info. - Returns: - - reordered snodes list - - dict {snode: ReorderInfo} + Check if a wait node corresponds to a given collective node by verifying if the wait + depends on outputs from the collective. """ - has_collectives = False - for snode in snodes: - if contains_collective(snode): - has_collectives = True - break - if not has_collectives: - return snodes, {} + collective_outs = OrderedSet(o.get_name() for o in collective_snode.get_outputs()) + unmet_deps = OrderedSet(d.name for d in wait_snode.unmet_dependencies) + return bool(unmet_deps & collective_outs) - from torch._inductor.scheduler import GroupedSchedulerNode - original_snodes_num = len(snodes) - # heuristic to avoid degenerating to quadratic time - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - ( - peak_memory, - _curr_memory, - snodes_allocfree, - buf_to_snode_last_use, - name_to_freeable_input_buf, - ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) - runtimes: dict[BaseSchedulerNode, float] = { - snode: estimate_op_runtime(snode) for snode in snodes - } - # debug stats - stats: dict[BaseSchedulerNode, ReorderInfo] = {} +def _op_runtime_estimate_mult(snode): + # Apply multipliers for faster experimentation. + # TODO(ivankobzarev): Remove after confirmation that runtime estimations are correct. 
+ if contains_collective(snode): + return config_comms.reorder_sink_runtime_estimations_comm_mult - def exposed_communication_time( - collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode] - ) -> float: - # assumes a linear schedule and computes the overlap of the collective with the remaining nodes - comm_time = estimate_op_runtime(collective_snode) - compute_time = 0.0 - for snode in remaining_snodes: - if contains_collective(snode): - continue - if contains_wait(snode): - # TODO - if the wait is for a collective that started before this collective or on another stream, - # we can ignore it. Otherwise, it's the end of the road for overlap opportunities - break + return config_comms.reorder_sink_runtime_estimations_non_comm_mult - def accumulate_time(_snode: BaseSchedulerNode) -> None: - nonlocal compute_time - compute_time += runtimes[_snode] - _temp_group_visit_leaves(snode, accumulate_time) - return max(0, comm_time - compute_time) +def is_async_collective(snode): + """ + Filtering out ops that contain Collective and Wait inside and considered as Collectives. + See contains_collective function. + If the op contains Wait inside - consider as Synchronous compute. + """ + if python_kernel_name := getattr(snode.node, "python_kernel_name", None): + if "torch.ops._dtensor.shard_dim_alltoall.default" in python_kernel_name: + return False - total_moves = 0 + return True - _prev, _next, _head = _initialize_double_linked_list(snodes) - def _group_nodes( - head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode] - ) -> list[BaseSchedulerNode]: - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] # type: ignore[index] - return ret +def contains_async_collective(snode): + return contains_collective(snode, is_async_collective) - def _perform_double_linked_list_swap(candidate, group_head, group_tail): - # swap (candidate, group_head...group_tail) - # Before: - # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next - # After: - # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next - # 0 - candidate_prev = _prev[candidate] - if candidate_prev: - _next[candidate_prev] = group_head - _prev[group_head] = candidate_prev - - # 2 - group_tail_next = _next[group_tail] - if group_tail_next: - _prev[group_tail_next] = candidate - _next[candidate] = group_tail_next - - # 1 - _prev[candidate] = group_tail - _next[group_tail] = candidate - nonlocal _head - if _head == candidate: - _head = group_head +def _group_nodes_from_linked_list( + head: Optional[BaseSchedulerNode], + tail: Optional[BaseSchedulerNode], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], +) -> list[BaseSchedulerNode]: + """ + Traverse doubly-linked list from head to tail and return nodes as a list. - def _calculate_potential_peak_memory( - candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate - ): - # Caching calculations of memory for group nodes and candidate, - # to apply without recalculation after swap. 
- _post_alloc_update: dict[BaseSchedulerNode, int] = {} - potential_peak: int = 0 - if not group_n_to_bufs_after_swap_dealloc_by_candidate: - # Not accounting for buffers last use change - potential_peak = max( - group_peak_memory - candidate_delta_mem, - _curr_memory[group_tail][1] - - candidate_delta_mem - + candidate_allocfree.size_alloc, - ) - return potential_peak, _post_alloc_update + Args: + head: Starting node of the segment + tail: Ending node of the segment (inclusive) + next_dict: Dictionary mapping each node to its next node - # If candidate will be after group, the starting memory level of group nodes - # changes to the -(candidate.size_alloc - candidate.size_free) - mem_after_reorder_delta: int = -candidate_delta_mem - for gn in gns: - gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta - _post_alloc_update[gn] = gn_post_alloc_mem - potential_peak = max(potential_peak, gn_post_alloc_mem) + Returns: + List of nodes from head to tail (inclusive) + """ + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = next_dict[n] # type: ignore[index] + return ret - bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None) - if bufs is not None: - for buf in bufs: - # Candidate will deallocate those buffers - mem_after_reorder_delta += buf.mpi_buffer.size_free - candidate_mem_post_alloc = ( - _curr_memory[group_tail][1] - + mem_after_reorder_delta - + candidate_allocfree.size_alloc +def _perform_double_linked_list_swap( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list. + + Transforms: + candidate_prev -> candidate -> group_head...group_tail -> group_tail_next + Into: + candidate_prev -> group_head...group_tail -> candidate -> group_tail_next + + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list + + Returns: + New head of the linked list (may change if candidate was the head) + """ + # 0: Update candidate's previous node + candidate_prev = prev_dict[candidate] + if candidate_prev: + next_dict[candidate_prev] = group_head + prev_dict[group_head] = candidate_prev + + # 2: Update group_tail's next node + group_tail_next = next_dict[group_tail] + if group_tail_next: + prev_dict[group_tail_next] = candidate + next_dict[candidate] = group_tail_next + + # 1: Link group_tail to candidate + prev_dict[candidate] = group_tail + next_dict[group_tail] = candidate + + # Update head if candidate was the head + if head == candidate: + return group_head + return head + + +def _calculate_potential_peak_memory_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + curr_memory: dict, +) -> tuple[int, dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (reorder version). 
+ + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation memory values for each node. + + Args: + candidate: Node being moved + gns: Group nodes + group_tail: Last node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + curr_memory: Current memory state dict + + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict) + """ + # Caching calculations of memory for group nodes and candidate, + # to apply without recalculation after swap. + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + potential_peak: int = 0 + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + # Not accounting for buffers last use change + potential_peak = max( + group_peak_memory - candidate_delta_mem, + curr_memory[group_tail][1] + - candidate_delta_mem + + candidate_allocfree.size_alloc, ) - _post_alloc_update[candidate] = candidate_mem_post_alloc - potential_peak = max(potential_peak, candidate_mem_post_alloc) return potential_peak, _post_alloc_update - def _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_by_candidate, - _post_alloc_update, - ): - if not group_n_to_bufs_after_swap_dealloc_by_candidate: - for gn in gns: - cm = _curr_memory[gn] - _curr_memory[gn] = ( - cm[0] - candidate_delta_mem, - cm[1] - candidate_delta_mem, - ) - _candidate_post_alloc_mem = ( - _curr_memory[group_tail][1] + candidate_allocfree.size_alloc - ) - _candidate_post_free_mem = ( - _candidate_post_alloc_mem - candidate_allocfree.size_free - ) - _curr_memory[candidate] = ( - _candidate_post_alloc_mem, - _candidate_post_free_mem, - ) - return + # If candidate will be after group, the starting memory level of group nodes + # changes to the -(candidate.size_alloc - candidate.size_free) + mem_after_reorder_delta: int = -candidate_delta_mem + for gn in gns: + gn_post_alloc_mem = curr_memory[gn][0] + mem_after_reorder_delta + _post_alloc_update[gn] = gn_post_alloc_mem + potential_peak = max(potential_peak, gn_post_alloc_mem) - # Candidate becomes last use of some bufs - for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn) + if bufs is not None: for buf in bufs: - buf_to_snode_last_use[buf] = candidate - - size_free_to_move_to_candidate_sum: int = 0 - for n in gns: - _gn_post_alloc_mem: int = _post_alloc_update[n] - size_free_to_move_to_candidate: int = sum( - buf.mpi_buffer.size_free - for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] - ) - size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate - # group node does not deallocate this after swap - snodes_allocfree[n].size_free -= size_free_to_move_to_candidate - gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free - _curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) - _candidate_post_alloc_mem = _post_alloc_update[candidate] - snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum - candidate_post_free_mem = ( - _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free - ) - _curr_memory[candidate] = ( - _candidate_post_alloc_mem, - candidate_post_free_mem, - ) + # Candidate will deallocate those buffers + mem_after_reorder_delta += buf.mpi_buffer.size_free - 
debug_num_collectives_to_reorder: Optional[int] = ( - config.reorder_iterative_debug_limit_to_reorder + candidate_mem_post_alloc = ( + curr_memory[group_tail][1] + + mem_after_reorder_delta + + candidate_allocfree.size_alloc ) + _post_alloc_update[candidate] = candidate_mem_post_alloc + potential_peak = max(potential_peak, candidate_mem_post_alloc) + return potential_peak, _post_alloc_update + + +def _update_memory_tracking_after_swap_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + buf_to_snode_last_use: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (reorder version). - num_processed_collectives: int = 0 - curr = _head - debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute - iterative_recompute_error = False - - while _next[curr] is not None: - if iterative_recompute_error: - break - # pyrefly: ignore [bad-argument-type] - if contains_collective(curr): - if debug_num_collectives_to_reorder is not None and ( - num_processed_collectives >= debug_num_collectives_to_reorder - ): - break - num_processed_collectives += 1 + Updates curr_memory, buf_to_snode_last_use, and snodes_allocfree dictionaries + to reflect the new memory state after swapping candidate with group. - info = stats[curr] = ReorderInfo() - info.initial_exposed = info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) + Args: + candidate: Node that was moved + gns: Group nodes + group_tail: Last node of group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + post_alloc_update: Cached post-allocation memory values + curr_memory: Current memory state dict (mutated) + buf_to_snode_last_use: Buffer to last-use node mapping (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + for gn in gns: + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] - candidate_delta_mem, + cm[1] - candidate_delta_mem, ) + _candidate_post_alloc_mem = ( + curr_memory[group_tail][1] + candidate_allocfree.size_alloc + ) + _candidate_post_free_mem = ( + _candidate_post_alloc_mem - candidate_allocfree.size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + _candidate_post_free_mem, + ) + return - candidate = _prev[curr] - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr][0] # post_alloc memory - while candidate is not None: - if contains_collective(candidate): - info.limiting_factor = "collective ordering" - break - - gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail) - group = GroupedSchedulerNode( - curr.scheduler, - gns, - temp_grouping=True, - ) - - # We can have multiple deps with the same name. - # As we ignore WeakDep(is_fake=True) => - # filter them out first to avoid overwriting of real dep. 
- data_deps = { - d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) - } - - candidate_outs = candidate.get_outputs() - data_dep = None - for o in candidate_outs: - if d := data_deps.get(o.get_name(), None): - data_dep = d - break + # Candidate becomes last use of some bufs + for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + for buf in bufs: + buf_to_snode_last_use[buf] = candidate + + size_free_to_move_to_candidate_sum: int = 0 + for n in gns: + _gn_post_alloc_mem: int = post_alloc_update[n] + size_free_to_move_to_candidate: int = sum( + buf.mpi_buffer.size_free + for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] + ) + size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate + # group node does not deallocate this after swap + snodes_allocfree[n].size_free -= size_free_to_move_to_candidate + gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free + curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) + _candidate_post_alloc_mem = post_alloc_update[candidate] + snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum + candidate_post_free_mem = ( + _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + candidate_post_free_mem, + ) - if data_dep is not None: - def is_groupable( - candidate: BaseSchedulerNode, - ) -> tuple[bool, Optional[str]]: - # preserve ordering - if contains_collective(candidate): - return False, "contains_collective" +def _find_buffers_with_changed_last_use( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping candidate with group. - if contains_gemm_like(candidate): - return False, "contains_gemm_like" - return True, None + When we swap [candidate [group]] to [[group] candidate], some buffers that + were last used by a group node will now be last used by candidate instead. + This affects memory deallocation timing. - is_groupable_result, grouping_reason = is_groupable(candidate) - if is_groupable_result: - group_head = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate][0] - ) - info.grouped += 1 - info.grouped_info = _group_names(gns) - candidate = _prev[candidate] - continue - else: - msg = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" - f"dep on {_group_names(gns)}" - f"\n non_group_reason:{grouping_reason}" - ) - info.limiting_factor = msg - break + Args: + candidate: The node being moved + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes - candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] - candidate_delta_mem: int = ( - candidate_allocfree.size_alloc - candidate_allocfree.size_free - ) - # candidate and one of group nodes are successors of the same buffer - # and last use of the buffer happen in group nodes. - # This last use deallocates it. - # If we swap [candidate [group]] to [[group] candidate], - # candidate becomes the last use - # and deallocated this buffer instead of group node. - # we need to update size_free accordingly to group_node and candidate, - # and recalculate post_alloc, post_free for them. 
- # - # Buf that changes its last use snode, - # after swap will be deallocated only by candidate, - # while before it was deallocated by group node. - group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ - BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] - ] = defaultdict(list) - for ( - buf, - snode_last_use, - ) in buf_to_snode_last_use.items(): - succ_nodes = buf.mpi_buffer.succ_nodes - if candidate not in succ_nodes: - continue + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if candidate not in succ_nodes: + continue - if not any(gn == snode_last_use for gn in gns): - continue + if not any(gn == snode_last_use for gn in gns): + continue - group_n_to_bufs_after_swap_dealloc_by_candidate[ - snode_last_use - ].append(buf) + group_n_to_bufs_after_swap_dealloc_by_candidate[snode_last_use].append(buf) - potential_peak, _post_alloc_update = _calculate_potential_peak_memory( - candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate - ) + return group_n_to_bufs_after_swap_dealloc_by_candidate - if potential_peak > peak_memory: - info.limiting_factor = ( - f"peak memory new:{potential_peak} vs base:{peak_memory}" - ) - break - info.moves += 1 - total_moves += 1 - _perform_double_linked_list_swap(candidate, group_head, group_tail) +def _is_node_groupable_for_reorder( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped with collective during reordering. - info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) - ) + This pass processes collectives left to right, so we avoid grouping with + already-processed collectives based on configuration. - _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_by_candidate, - _post_alloc_update, - ) + Args: + candidate: Node to check for groupability - if debug_iterative_memory_recompute: - # Compare iteratively recomputed memory data - # with full run of estimate_peak_memory + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # This pass processes collectives left to right, + # Do not group with processed collectives. + # Leaving config for experimentation in 2D + if not config_comms.reorder_iterative_group_with_collectives: + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_collective {candidate.get_name()}", + ) + if not config_comms.reorder_iterative_use_runtime_estimations: + if contains_gemm_like(candidate): + return False, "contains_gemm_like" + return True, None + + +def _format_and_log_reordering_stats( + stats: dict[BaseSchedulerNode, ReorderInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format reordering statistics, log them, and return final node list. - from .comms_debug import _debug_iterative_memory_recompute + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. 
- iterative_recompute_error = _debug_iterative_memory_recompute( - candidate, - gns, - _group_names(gns), - _group_nodes(_head, None), - name_to_freeable_input_buf, - graph_outputs, - peak_memory, - _curr_memory, - snodes_allocfree, - "reorder_communication_preserving_peak_memory", - group_n_to_bufs_after_swap_dealloc_by_candidate, - ) - if iterative_recompute_error: - break - candidate = _prev[group_head] - curr = _next[curr] # type: ignore[assignment] + Args: + stats: Per-node reordering statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + Returns: + Final reordered list of scheduler nodes + """ node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} total_improvement = sum([improvement[snode] for snode in improvement]) @@ -632,28 +747,35 @@ def is_groupable( ) headers = [ "Collective node", - "initial exposed", - "final exposed", - "improvement", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", "limiting factor", "moves", "grouped", "grouped_info", + "overlap_info", ] rows = [ [ node_summary(snode), - node_info.initial_exposed, - node_info.final_exposed, - node_info.improvement, + node_info.comm_time / 1e3, + node_info.comp_time / 1e3, + node_info.initial_exposed / 1e3, + node_info.final_exposed / 1e3, + node_info.improvement / 1e3, node_info.limiting_factor, node_info.moves, node_info.grouped, node_info.grouped_info, + node_info.overlap_info, ] for snode, node_info in node_stats.items() ] if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] from tabulate import tabulate reorder_log_str += tabulate( @@ -667,7 +789,7 @@ def is_groupable( reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - new_snodes = _group_nodes(_head, None) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -685,6 +807,334 @@ def is_groupable( payload_fn=lambda: reorder_log_str, ) + return new_snodes + + +def _reorder_communication_preserving_peak_memory_internal( + snodes: list[BaseSchedulerNode], +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: + """ + Internal testing helper that also returns debug info. 
+ Returns: + - reordered snodes list + - dict {snode: ReorderInfo} + """ + has_collectives = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + # heuristic to avoid degenerating to quadratic time + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + # debug stats + stats: dict[BaseSchedulerNode, ReorderInfo] = {} + + total_moves = 0 + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + debug_num_collectives_to_reorder: Optional[int] = ( + config_comms.reorder_iterative_debug_limit_to_reorder + ) + + num_processed_collectives: int = 0 + curr: Optional[BaseSchedulerNode] = _head + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + iterative_recompute_error = False + + while curr is not None and _next[curr] is not None: + _next_curr = _next[curr] + if iterative_recompute_error: + break + # pyrefly: ignore [bad-argument-type] + if not contains_async_collective(curr): + curr = _next_curr + continue + + if debug_num_collectives_to_reorder is not None and ( + num_processed_collectives >= debug_num_collectives_to_reorder + ): + break + num_processed_collectives += 1 + + info = stats[curr] = ReorderInfo() + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_waits = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] # post_alloc memory + + while candidate is not None: + if config_comms.reorder_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.reorder_iterative_extra_comm_comp_overlap + * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + if ( + not config_comms.reorder_iterative_unsafe_collectives_reorder + and contains_collective(candidate) + ): + info.limiting_factor = "collective ordering" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + curr.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. 
+ data_deps = { + d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) + } + + candidate_outs = candidate.get_outputs() + data_dep = None + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + + if data_dep is not None: + is_groupable_result, grouping_reason = _is_node_groupable_for_reorder( + candidate + ) + if is_groupable_result: + group_head = candidate + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + if contains_wait(candidate): + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), + runtimes, + ) + group_waits[candidate] = comm_time, comp_time + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _prev[candidate] + continue + else: + msg = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(gns)}" + f"\n non_group_reason:{grouping_reason}" + ) + info.limiting_factor = msg + break + + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_waits) > 0: + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, info.comm_time - info.comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max( + 0, info.comm_time - info.comp_time - c_runtime + ) + exposed_delta = exposed_after - exposed_before + for gw_comm_time, gw_comp_time in group_waits.values(): + gw_exposed_before = max(0, gw_comm_time - gw_comp_time) + gw_exposed_after = max( + 0, gw_comm_time - gw_comp_time + c_runtime + ) + + exposed_delta += gw_exposed_after - gw_exposed_before + + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}," + f" group contains waits, total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gw, ( + gw_comm_time, + gw_comp_time, + ) in group_waits.items(): + group_waits[gw] = ( + gw_comm_time, + gw_comp_time - c_runtime, + ) + else: + # Candidate is async_collective + + # Unsafe collectives reordering + # Cj -> [...group_runtime..., Ci] -> Wj + # Checking that we are not increasing exposed time of Cj + if group_runtime > 0: + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + exposed_delta = exposed_after - exposed_before + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate {candidate.get_name()} is collective," + f" group_runtime:{group_runtime}," + f" exposed_delta:{exposed_delta} c_comm_time:{comm_time} c_comp_time:{comp_time}" + ) + break + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem: int = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # candidate and one of group nodes are successors of the same buffer + # and last use of the buffer happen in group nodes. 
+ # This last use deallocates it. + # If we swap [candidate [group]] to [[group] candidate], + # candidate becomes the last use + # and deallocated this buffer instead of group node. + # we need to update size_free accordingly to group_node and candidate, + # and recalculate post_alloc, post_free for them. + # + # Buf that changes its last use snode, + # after swap will be deallocated only by candidate, + # while before it was deallocated by group node. + group_n_to_bufs_after_swap_dealloc_by_candidate = ( + _find_buffers_with_changed_last_use( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update = ( + _calculate_potential_peak_memory_reorder( + candidate, + gns, + group_tail, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _curr_memory, + ) + ) + + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.reorder_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + info.moves += 1 + total_moves += 1 + + _head = _perform_double_linked_list_swap( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + info.final_exposed = comm_time - comp_time + + _update_memory_tracking_after_swap_reorder( + candidate, + gns, + group_tail, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _post_alloc_update, + _curr_memory, + buf_to_snode_last_use, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + # Compare iteratively recomputed memory data + # with full run of estimate_peak_memory + + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "reorder_communication_preserving_peak_memory", + group_n_to_bufs_after_swap_dealloc_by_candidate, + ) + if iterative_recompute_error: + break + candidate = _prev[group_head] + curr = _next_curr + + new_snodes = _format_and_log_reordering_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + return new_snodes, stats @@ -1012,9 +1462,11 @@ def _update_memory_tracking_after_swap( curr = snodes[-1] processed_waits = OrderedSet() # type: ignore[var-annotated] - debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) debug_num_sink_waits_to_reorder: Optional[int] = ( - config.sink_waits_iterative_debug_limit_to_sink + config_comms.sink_waits_iterative_debug_limit_to_sink ) iterative_recompute_error = False @@ -1213,6 +1665,7 @@ def is_groupable(snode): ] log_str = "" if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] from tabulate import tabulate log_str += tabulate( @@ -1224,7 +1677,7 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - new_snodes = _group_nodes(_head, None) + new_snodes = 
_group_nodes_from_linked_list(_head, None, _next) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -1267,7 +1720,7 @@ def node_summary(snode): if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" - detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}({ins_str})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index aaf7fbd2f7f54..2d9e180db54f5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -379,6 +379,15 @@ def prologue_fusion_enabled() -> bool: # for built-in passes, use string name; for user-defined passes, pass in the function handle # WARNING: Inductor scheduler IR is at prototype stage and subject to change, # hence custom IR passes built on top of it might break in the future. +# +# See aten_distributed_optimizations, it is recommended way for distributed optimizations. +# +# Recommended configuration for reorder_for_compute_comm_overlap_passes: +# [ +# "reorder_communication_preserving_peak_memory", +# "sink_waits_iterative", +# "reorder_communication_preserving_peak_memory", +# ] reorder_for_compute_comm_overlap_passes: list[ Union[ str, @@ -387,11 +396,7 @@ def prologue_fusion_enabled() -> bool: list["torch._inductor.scheduler.BaseSchedulerNode"], ], ] -] = [ - "reorder_compute_for_overlap", - "sink_waits", - "raise_comms", -] +] = [] # Maximum number of positions to advance a given collective, unlimited by default reorder_prefetch_limit: Optional[int] = None @@ -407,16 +412,6 @@ def prologue_fusion_enabled() -> bool: # is zero, which turns off this optimization. size_threshold_for_succ_based_strategy: int = 0 -reorder_iterative_debug_memory_recompute: bool = False -reorder_iterative_debug_limit_to_reorder: Optional[int] = ( - None - if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None - else int(env_str) -) -sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( - # pyrefly: ignore [unbound-name] - None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) -) bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none" # By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used diff --git a/torch/_inductor/config_comms.py b/torch/_inductor/config_comms.py index b5dbf424f35b4..51242c7f2cf5b 100644 --- a/torch/_inductor/config_comms.py +++ b/torch/_inductor/config_comms.py @@ -1,4 +1,6 @@ +import os import sys +from typing import Optional from torch.utils._config_module import install_config_module @@ -11,5 +13,50 @@ # decisions on different distributed ranks. 
 runtime_estimations_align_across_all_distributed_ranks: bool = False
 
+reorder_iterative_debug_memory_recompute: bool = False
+reorder_iterative_debug_limit_to_reorder: Optional[int] = (
+    None
+    # pyrefly: ignore[unbound-name]
+    if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
+    else int(env_str)
+)
+sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
+    # pyrefly: ignore [unbound-name]
+    None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
+)
+
+
+# Should be used with config.runtime_estimations_mms_benchmark = True
+reorder_iterative_use_runtime_estimations: bool = False
+sink_iterative_use_runtime_estimations: bool = False
+
+# Broadcast runtime estimations via a real collective operation across all ranks.
+# If non-deterministic runtime estimations are used, this must be enabled so that
+# all ranks make identical decisions and collectives are not reordered differently
+# across ranks (which would result in NCCL hangs).
+reorder_for_compute_comm_overlap_broadcast_runtime_estimations: bool = False
+
+# Block of ratios to work around imperfections of the current runtime estimations
+# for collectives and compute in different scenarios.
+# Multiplier for estimated collective durations
+reorder_sink_runtime_estimations_comm_mult: float = 2.0
+# Multiplier for estimated compute durations
+reorder_sink_runtime_estimations_non_comm_mult: float = 1.0
+# Reordering of a collective stops once
+# overlap_comp >= (1 + extra_overlap_ratio) * comm_time.
+# Allows configuring more aggressive overlap.
+reorder_iterative_extra_comm_comp_overlap: float = 0.5
+
+# Allow the iterative reorder pass to increase peak memory
+# up to peak_memory_before_pass * (1 + budget)
+reorder_iterative_peak_memory_budget: float = 0.2
+
+# Experimental unsafe configuration that allows changing the relative order of collectives.
+# Must be used with runtime_estimations_align_across_all_distributed_ranks = True
+reorder_iterative_unsafe_collectives_reorder: bool = True
+
+# Allow grouping and moving other collectives during reordering
+reorder_iterative_group_with_collectives: bool = False
+
 # adds patch, save_config, etc
 install_config_module(sys.modules[__name__])
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 3f8652882af79..9579dbb3536e3 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -2813,13 +2813,16 @@ def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool:
     return type(node) is ir._WaitKernel
 
 
-def contains_collective(snode: BaseSchedulerNode) -> bool:
+def contains_collective(
+    snode: BaseSchedulerNode,
+    filter_fn: Optional[Callable[[BaseSchedulerNode], bool]] = None,
+) -> bool:
     from torch._inductor.scheduler import GroupedSchedulerNode
 
     if isinstance(snode, GroupedSchedulerNode):
         return any(contains_collective(x) for x in snode.snodes)
 
-    return is_collective(snode.node)
+    return is_collective(snode.node) and (filter_fn is None or filter_fn(snode))
 
 
 def contains_wait(snode: BaseSchedulerNode) -> bool:

From da2eb31b824820666445e3e232007f26eb825e28 Mon Sep 17 00:00:00 2001
From: Jessica Vandebon
Date: Thu, 6 Nov 2025 15:43:45 +0000
Subject: [PATCH 126/651] [MTIA][PyTorch] Add mtia as native device for PyTorch tests (#167089)

Summary: Add MTIA as a native device type in PyTorch.
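For illustration (not part of this change): with "mtia" listed in NATIVE_DEVICES,
device-generic tests gated by onlyNativeDeviceTypes are no longer skipped for MTIA.
A minimal sketch, assuming an MTIA-enabled build that registers an MTIA device-type
test base; the test class and test names below are hypothetical:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
)

class ExampleTest(TestCase):
    @onlyNativeDeviceTypes
    def test_add(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)

# Generates ExampleTestCPU, ExampleTestCUDA, ... for the device types
# available in the current build.
instantiate_device_type_tests(ExampleTest, globals())

if __name__ == "__main__":
    run_tests()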
Test Plan: CI Reviewed By: PatriceVignola Differential Revision: D80111801 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167089 Approved by: https://github.com/andyanwang, https://github.com/nautsimon, https://github.com/albanD --- torch/testing/_internal/common_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 0c26738c2f52f..00572f9691380 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -333,7 +333,7 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) -NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name()) +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', 'mtia', torch._C._get_privateuse1_backend_name()) # used for managing devices testing for torch profiler UTs # for now cpu, cuda and xpu are added for testing torch profiler UTs From 7b055a0103008b84292dba154448547af424c739 Mon Sep 17 00:00:00 2001 From: Lakshay Garg Date: Thu, 6 Nov 2025 16:10:16 +0000 Subject: [PATCH 127/651] Add per_process_memory_fraction to PYTORCH_CUDA_ALLOC_CONF (#161035) torch.cuda.memory.set_per_process_memory_fraction allows setting an upper bound on how much device memory is allocated. This PR exposes this setting to an environment variable. For example, PYTORCH_CUDA_ALLOC_CONF="per_process_memory_fraction:0.5" will limit the device memory to half of the available memory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161035 Approved by: https://github.com/ngimel, https://github.com/eqy --- c10/cuda/CUDAAllocatorConfig.cpp | 15 ++++++ c10/cuda/CUDAAllocatorConfig.h | 11 ++++- c10/cuda/CUDACachingAllocator.cpp | 67 +++++++++++++-------------- c10/cuda/CUDACachingAllocator.h | 1 + c10/cuda/CUDAMallocAsyncAllocator.cpp | 1 - docs/source/notes/cuda.rst | 4 ++ test/test_cuda.py | 46 ++++++++++++++++++ 7 files changed, 108 insertions(+), 37 deletions(-) diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 3046259b48a3e..5414d838cd8c4 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -106,6 +106,9 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } else if (key == "graph_capture_record_stream_reuse") { i = parseGraphCaptureRecordStreamReuse(tokenizer, i); used_native_specific_option = true; + } else if (key == "per_process_memory_fraction") { + i = parsePerProcessMemoryFraction(tokenizer, i); + used_native_specific_option = true; } else { const auto& keys = c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); @@ -146,6 +149,18 @@ size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse( return i; } +double CUDAAllocatorConfig::parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i) { + tokenizer.checkToken(++i, ":"); + double val_env = tokenizer.toDouble(++i); + TORCH_CHECK_VALUE( + val_env >= 0.0 && val_env <= 1.0, + "per_process_memory_fraction is invalid, set it in [0.0, 1.0]"); + m_per_process_memory_fraction = val_env; + return i; +} + size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i) { diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index d61f69467a2dc..4e6097a406bc2 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ 
b/c10/cuda/CUDAAllocatorConfig.h @@ -61,6 +61,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_graph_capture_record_stream_reuse; } + static double per_process_memory_fraction() { + return instance().m_per_process_memory_fraction; + } + /** Pinned memory allocator settings */ static bool pinned_use_cuda_host_register() { return instance().m_pinned_use_cuda_host_register; @@ -152,7 +156,8 @@ class C10_CUDA_API CUDAAllocatorConfig { "pinned_use_hip_host_register", "graph_capture_record_stream_reuse", "pinned_reserve_segment_size_mb", - "pinned_num_register_threads"}; + "pinned_num_register_threads", + "per_process_memory_fraction"}; return keys; } @@ -177,6 +182,9 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t parseGraphCaptureRecordStreamReuse( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i); + double parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); std::atomic m_pinned_num_register_threads{1}; std::atomic m_pinned_reserve_segment_size_mb{0}; @@ -189,6 +197,7 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_release_lock_on_cudamalloc{false}; std::atomic m_pinned_use_cuda_host_register{false}; std::atomic m_graph_capture_record_stream_reuse{false}; + std::atomic m_per_process_memory_fraction{1.0}; }; // Keep this for backwards compatibility diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 091e580f95819..d66c3a16c0004 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1100,7 +1100,7 @@ class RingBuffer { } // anonymous namespace } // namespace Native -static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { +static std::string reportProcessMemoryInfo(const cudaDeviceProp& prop) { #ifdef PYTORCH_C10_DRIVER_API_SUPPORTED void* nvml_handle = DriverAPI::get_nvml_handle(); if (!nvml_handle) { @@ -1111,9 +1111,6 @@ static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { return true; }(); - cudaDeviceProp prop{}; - C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - // NOLINTNEXTLINE(*-c-arrays) char pci_id[80]; snprintf( @@ -1215,14 +1212,16 @@ class DeviceCachingAllocator { // record used memory. size_t total_allocated_memory = 0; - size_t allowed_memory_maximum = 0; + cudaDeviceProp device_prop; + + // maximum amount of memory that device is allowed to + // allocate. This is set iff memory fraction is less than 1 + std::optional allowed_memory_maximum{std::nullopt}; // all live expandable segments std::vector expandable_segments_; std::vector devices_with_peer_access_; - bool set_fraction = false; - bool record_history = false; std::atomic context_recorder_; @@ -1264,6 +1263,9 @@ class DeviceCachingAllocator { : device_id(id), large_blocks(/*small=*/false), small_blocks(/*small=*/true) { + C10_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, id)); + + setMemoryFraction(CUDAAllocatorConfig::per_process_memory_fraction()); stats.max_split_size = static_cast(AcceleratorAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); @@ -1399,7 +1401,7 @@ class DeviceCachingAllocator { if (!block_found) { // Do garbage collection if the flag is set. 
if (C10_UNLIKELY( - set_fraction && + allowed_memory_maximum.has_value() && AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { garbage_collect_cached_blocks(context); @@ -1456,11 +1458,12 @@ class DeviceCachingAllocator { C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); std::string allowed_info; - if (set_fraction) { - allowed_info = format_size(allowed_memory_maximum) + " allowed; "; + if (allowed_memory_maximum.has_value()) { + allowed_info = + format_size(allowed_memory_maximum.value()) + " allowed; "; } - std::string proc_info = reportProcessMemoryInfo(device_id); + std::string proc_info = reportProcessMemoryInfo(device_prop); record_trace( TraceEntry::OOM, @@ -1518,7 +1521,7 @@ class DeviceCachingAllocator { for (const auto& obs : observers_local) { obs(device_id, alloc_size, - set_fraction ? allowed_memory_maximum : device_total, + allowed_memory_maximum.value_or(device_total), device_free); } @@ -2015,25 +2018,26 @@ class DeviceCachingAllocator { /** get memory fraction limiting maximum allocated memory **/ double getMemoryFraction() { - if (!set_fraction) { + if (!allowed_memory_maximum.has_value()) { return 1.0; } - size_t device_free = 0; - size_t device_total = 0; - C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); - return static_cast(allowed_memory_maximum) / - static_cast(device_total); + return static_cast(allowed_memory_maximum.value()) / + static_cast(device_prop.totalGlobalMem); } /** set memory fraction to limit maximum allocated memory **/ void setMemoryFraction(double fraction) { - size_t device_free = 0; - size_t device_total = 0; - C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); - allowed_memory_maximum = - static_cast(fraction * static_cast(device_total)); - set_fraction = true; + TORCH_CHECK( + 0 <= fraction && fraction <= 1, + "invalid fraction:", + fraction, + ". Please set within [0, 1]."); + allowed_memory_maximum = std::nullopt; + if (fraction < 1.0) { + allowed_memory_maximum = static_cast( + fraction * static_cast(device_prop.totalGlobalMem)); + } } /** get expandable segment size for all the streams on device **/ @@ -3010,7 +3014,7 @@ class DeviceCachingAllocator { BlockPool& pool = *p.pool; if (C10_UNLIKELY( - set_fraction && + allowed_memory_maximum.has_value() && AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. 
++pool.get_free_blocks_call_count; @@ -3083,7 +3087,7 @@ class DeviceCachingAllocator { size_t gc_threshold = static_cast( AcceleratorAllocatorConfig::garbage_collection_threshold() * - static_cast(allowed_memory_maximum)); + static_cast(allowed_memory_maximum.value())); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { return; @@ -3161,8 +3165,8 @@ class DeviceCachingAllocator { bool active_pool = p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator(); - if (set_fraction && - total_allocated_memory + size > allowed_memory_maximum) { + if (allowed_memory_maximum.has_value() && + total_allocated_memory + size > allowed_memory_maximum.value()) { p.err = cudaErrorMemoryAllocation; return false; // Temporarily disable checkpointing & cudagraphs internally @@ -3859,7 +3863,6 @@ class NativeCachingAllocator : public CUDAAllocator { "Allocator not initialized for device ", device, ": did you call init?"); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); return device_allocator[device]->getMemoryFraction(); } @@ -3869,12 +3872,6 @@ class NativeCachingAllocator : public CUDAAllocator { "Allocator not initialized for device ", device, ": did you call init?"); - TORCH_CHECK( - 0 <= fraction && fraction <= 1, - "invalid fraction:", - fraction, - ". Please set within [0, 1]."); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); device_allocator[device]->setMemoryFraction(fraction); } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index fbe5dab18e0ae..8fee00dd621dc 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 93bce51f1b9d0..674eb00035c50 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -427,7 +427,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { // on the current device each later call sees. void init(int dev_count) override { static bool called = [](int dev_count) { - ; // Are there external guarantees init will be called before // any of the allocator's other functions? // std::lock_guard lk(general_mutex); diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index caabeb399c722..2c1a2e8cbb6be 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -619,6 +619,10 @@ Available options: and reallocate buffers across multiple streams, especially when the capture DAG frequently reaches joined frontiers. +* ``per_process_memory_fraction`` option limits the amount of memory that can be allocated + on all the CUDA devices to a specified fraction of the available memory. This is a value + between 0 and 1. Attempting to allocate more memory will raise an out of memory error. + .. 
note:: Some stats reported by the diff --git a/test/test_cuda.py b/test/test_cuda.py index 329261fba7d3a..dfbcdc1b40401 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -4626,6 +4626,52 @@ def check_output(script: str) -> str: rc = check_output(test_script) self.assertEqual(rc, "cudaMallocAsync") + def test_allocator_memory_fraction_setting(self): + def make_env(fraction): + env = os.environ.copy() + var = "PYTORCH_CUDA_ALLOC_CONF" + key = "per_process_memory_fraction" + value = [ + x + for x in env.get(var, "").split(",") + if len(x) > 0 and not x.startswith(f"{key}:") + ] + value.append(f"{key}:{fraction}") + env[var] = ",".join(value) + return env + + def run_test(value): + test_script = """\ +import os +import torch +device = torch._C._cuda_getDevice() +value = torch.cuda.memory.get_per_process_memory_fraction(device) +print(value, end="") + """ + return subprocess.run( + [sys.executable, "-c", test_script], + env=make_env(value), + text=True, + check=True, + capture_output=True, + ) + + self.assertEqual(run_test(0.0).stdout, "0.0") + self.assertEqual(run_test(0.5).stdout, "0.5") + self.assertEqual(run_test(1.0).stdout, "1.0") + + with self.assertRaises(subprocess.CalledProcessError) as e: + run_test(-0.1) + assert "per_process_memory_fraction is invalid" in e.exception.stderr, ( + e.exception.stderr + ) + + with self.assertRaises(subprocess.CalledProcessError) as e: + run_test(1.1) + assert "per_process_memory_fraction is invalid" in e.exception.stderr, ( + e.exception.stderr + ) + def test_cachingAllocator_raw_alloc(self): # Test that raw_alloc respects the setting that # activates/deactivates the caching allocator From cc477f600968a89a0e080ccaa6052277543bc84b Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 5 Nov 2025 07:22:13 -0800 Subject: [PATCH 128/651] [inductor] Use runtime estimations in iterative sink waits pass (#167081) Split of https://github.com/pytorch/pytorch/pull/162469 to be under 2K reorder iterative part Pull Request resolved: https://github.com/pytorch/pytorch/pull/167081 Approved by: https://github.com/eellison ghstack dependencies: #167080 --- torch/_inductor/comms.py | 932 ++++++++++++++++++++++---------- torch/_inductor/config_comms.py | 9 + 2 files changed, 644 insertions(+), 297 deletions(-) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index a4a4cac8e3ec2..29efcb4a44493 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from .ir import IRNode, Operation - from .scheduler import SchedulerBuffer from .memory import ( estimate_peak_memory, @@ -1325,341 +1324,289 @@ class SinkWaitInfo: moves: int = 0 moves_info: str = "" limiting_factor: str = "None" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" + @property + def improvement(self): + return self.initial_exposed - self.final_exposed -def _sink_waits_iterative_internal( - snodes: list[BaseSchedulerNode], -) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode - original_snodes_num = len(snodes) - if original_snodes_num == 0: - return snodes, {} - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - ( - peak_memory, - _curr_memory, - snodes_allocfree, - buf_to_snode_last_use, - name_to_freeable_input_buf, - ) = _initialize_memory_tracking(snodes, 
graph_inputs, graph_outputs) +def _is_node_groupable_for_sink_waits( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped during sink_waits pass. - _prev, _next, _head = _initialize_double_linked_list(snodes) + Sink Waits traverses waits right to left, so we don't group with + processed waits on the right or with async collectives. - stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} - - def _group_nodes( - head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode] - ) -> list[BaseSchedulerNode]: - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] # type: ignore[index] - return ret + Args: + candidate: Node to check for groupability - def _calculate_potential_peak_memory( - candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate - ): - pre_group_mem = ( - _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # Sink Waits traverse Waits right to left, + # => we do not group with processed Waits on the right. + if contains_wait(candidate): + return False, f"candidate contains wait {candidate.get_name()}" + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_async_collective {candidate.get_name()}", ) - # Stash memory tracing updates to not recompute them after swap - _post_alloc_update: dict[BaseSchedulerNode, int] = {} - _size_free_delta_update: dict[BaseSchedulerNode, int] = {} - - potential_peak = 0 - if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - # Not accounting for buffers liveliness change - potential_peak = max( - group_peak_memory + candidate_delta_mem, - pre_group_mem + candidate_allocfree.size_alloc, + + # pyrefly: ignore[unbound-name] + if not config_comms.sink_iterative_use_runtime_estimations: + # Heuristics pre-use_runtime_estimations: + # TODO(ivankobzarev): Remove them after confirming, + # that using runtime estimations always give better results. + # We do not want to group with collectives to not reorder them forward. + if contains_collective(candidate): + return ( + False, + f"candidate contains collective {candidate.get_name()}", + ) + if contains_gemm_like(candidate): + return ( + False, + f"candidate contains gemm_like {candidate.get_name()}", ) - return potential_peak, _post_alloc_update, _size_free_delta_update + return True, None + + +def _update_memory_tracking_after_swap_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + size_free_delta_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (sink_waits version). + Updates curr_memory and snodes_allocfree dictionaries to reflect the new + memory state after swapping candidate with group. 
+ + Args: + candidate: Node that was moved + gns: Group nodes + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + post_alloc_update: Cached post-allocation memory values + size_free_delta_update: Cached size-free delta values + curr_memory: Current memory state dict (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + group_head = gns[0] + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc - _post_alloc_update[candidate] = candidate_post_alloc - potential_peak = candidate_post_alloc - candidate_size_free_to_move = sum( - buf.mpi_buffer.size_free # type: ignore[attr-defined] - for buf in itertools.chain.from_iterable( - group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() - ) + curr_memory[candidate] = ( + candidate_post_alloc, + candidate_post_alloc - candidate_allocfree.size_free, ) - _size_free_delta_update[candidate] = -candidate_size_free_to_move - delta_mem = candidate_delta_mem + candidate_size_free_to_move for gn in gns: - gn_post_alloc = _curr_memory[gn][0] + delta_mem - _post_alloc_update[gn] = gn_post_alloc - potential_peak = max(potential_peak, gn_post_alloc) - gn_size_free_to_add = 0 - if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] - for buf in bufs: - gn_size_free_to_add += buf.mpi_buffer.size_free - _size_free_delta_update[gn] = gn_size_free_to_add - delta_mem -= gn_size_free_to_add - return potential_peak, _post_alloc_update, _size_free_delta_update + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] + candidate_delta_mem, + cm[1] + candidate_delta_mem, + ) + return - def _perform_double_linked_list_swap(candidate, group_head, group_tail): - # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next - # 0: - group_head_prev = _prev[group_head] - if group_head_prev: - _next[group_head_prev] = candidate - _prev[candidate] = group_head_prev - - # 2: - candidate_next = _next[candidate] - if candidate_next: - _prev[candidate_next] = group_tail - _next[group_tail] = candidate_next - - # 1: - _prev[group_head] = candidate - _next[candidate] = group_head - nonlocal _head - if group_head == _head: - _head = candidate - - def _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - _post_alloc_update, - _size_free_delta_update, - ): - group_head = gns[0] - pre_group_mem = ( - _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + for n in [candidate, *gns]: + post_alloc = post_alloc_update[n] + snodes_allocfree[n].size_free += size_free_delta_update.get(n, 0) + curr_memory[n] = ( + post_alloc, + post_alloc - snodes_allocfree[n].size_free, ) - if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc - _curr_memory[candidate] = ( - candidate_post_alloc, - candidate_post_alloc - candidate_allocfree.size_free, - ) - for gn in gns: - cm = _curr_memory[gn] - _curr_memory[gn] = ( - cm[0] + candidate_delta_mem, - cm[1] + candidate_delta_mem, - ) - return - - for n in [candidate, *gns]: - post_alloc = _post_alloc_update[n] - 
snodes_allocfree[n].size_free += _size_free_delta_update[n] - _curr_memory[n] = ( - post_alloc, - post_alloc - snodes_allocfree[n].size_free, - ) - curr = snodes[-1] - processed_waits = OrderedSet() # type: ignore[var-annotated] - debug_iterative_memory_recompute = ( - config_comms.reorder_iterative_debug_memory_recompute - ) - debug_num_sink_waits_to_reorder: Optional[int] = ( - config_comms.sink_waits_iterative_debug_limit_to_sink - ) +def _calculate_potential_peak_memory_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_head: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + curr_memory: dict, + snodes_allocfree: dict, +) -> tuple[int, dict[BaseSchedulerNode, int], dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (sink_waits version). - iterative_recompute_error = False + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation and size-free delta values. - while _prev[curr] is not None: - if iterative_recompute_error: - break - if ( - debug_num_sink_waits_to_reorder is not None - and len(processed_waits) >= debug_num_sink_waits_to_reorder - ): - break + Args: + candidate: Node being moved + gns: Group nodes + group_head: First node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + curr_memory: Current memory state dict + snodes_allocfree: Allocation/free info for all nodes - # pyrefly: ignore [bad-argument-type] - if contains_wait(curr) and curr not in processed_waits: - processed_waits.add(curr) - info = stats[curr] = SinkWaitInfo() - candidate = _next[curr] - wait_snode = curr - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr][0] - while candidate is not None: - if iterative_recompute_error: - break - gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail) - group = GroupedSchedulerNode( - wait_snode.scheduler, - gns, - temp_grouping=True, - ) + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict, size_free_delta_update_dict) + """ + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + # Stash memory tracing updates to not recompute them after swap + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + _size_free_delta_update: dict[BaseSchedulerNode, int] = {} - # We can have multiple deps with the same name. - # As we ignore WeakDep(is_fake=True) => - # filter them out first to avoid overwriting of real dep. - data_deps = { - d.name: d - for d in candidate.unmet_dependencies - if not _is_fake_dep(d) - } - - group_outs = group.get_outputs() - data_dep = None - for o in group_outs: - if d := data_deps.get(o.get_name(), None): - data_dep = d - break - # 1. If we have data_dep - we can not swap => trying to group - # 2. 
If swap candidate and current node both contain collectives => trying to group - if data_dep is not None or ( - both_contain_comms := ( - contains_collective(group) and contains_collective(candidate) - ) - ): + potential_peak = 0 + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + # Not accounting for buffers liveliness change + potential_peak = max( + group_peak_memory + candidate_delta_mem, + pre_group_mem + candidate_allocfree.size_alloc, + ) + return potential_peak, _post_alloc_update, _size_free_delta_update - def is_groupable(snode): - # We do not want to group with collectives to not reorder them forward. - if contains_collective(snode): - return ( - False, - f"candidate contains collective {snode.get_name()}", - ) - if contains_gemm_like(snode): - return ( - False, - f"candidate contains gemm_like {snode.get_name()}", - ) - return True, None + candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc + _post_alloc_update[candidate] = candidate_post_alloc + potential_peak = candidate_post_alloc + candidate_size_free_to_move = sum( + buf.mpi_buffer.size_free # type: ignore[attr-defined] + for buf in itertools.chain.from_iterable( + group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() + ) + ) + _size_free_delta_update[candidate] = -candidate_size_free_to_move + delta_mem = candidate_delta_mem + candidate_size_free_to_move + for gn in gns: + gn_post_alloc = curr_memory[gn][0] + delta_mem + _post_alloc_update[gn] = gn_post_alloc + potential_peak = max(potential_peak, gn_post_alloc) + gn_size_free_to_add = 0 + if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] + for buf in bufs: + gn_size_free_to_add += buf.mpi_buffer.size_free + _size_free_delta_update[gn] = gn_size_free_to_add + delta_mem -= gn_size_free_to_add + return potential_peak, _post_alloc_update, _size_free_delta_update - is_grp, grp_reason = is_groupable(candidate) - if is_grp: - group_tail = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate][0] - ) - info.grouped += 1 - info.grouped_info = _group_names(gns) - candidate = _next[candidate] - continue - elif (data_dep is None) and both_contain_comms: - info.limiting_factor = ( - f"collective ordering {_group_names(gns)}" - f" with candidate:{candidate.get_name()}" - ) - break - else: - info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" - f"dep on {gns}" - f"\n outs:{[o.get_name() for o in group_outs]}" - f"\n non_group_reason:{grp_reason}" - ) - break - candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] - candidate_delta_mem = ( - candidate_allocfree.size_alloc - candidate_allocfree.size_free - ) - # [group] candidate -> candidate [group] - # Check for buffers with successors in group and candidate last successor - # - # Buf that changes its last use snode, - # It was deallocated by candidate, - # but after swap it will be deallocated by group node. 
- group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ - BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]] - ] = defaultdict(list) - for ( - buf, - snode_last_use, - ) in buf_to_snode_last_use.items(): - succ_nodes = buf.mpi_buffer.succ_nodes - if snode_last_use != candidate: # noqa: E711 - continue - # candidate is last use of buf - last_succ_gn = None - for gn in gns: - if gn in succ_nodes: - last_succ_gn = gn - if last_succ_gn is None: - continue +def _perform_double_linked_list_swap_sink_waits( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list (sink_waits version). - # gn has successors of buf that after potential swap will become - # last use of buf and start deallocating buf instead of candidate - group_n_to_bufs_after_swap_dealloc_instead_of_candidate[ - last_succ_gn - ].append(buf) - - potential_peak, _post_alloc_update, _size_free_delta_update = ( - _calculate_potential_peak_memory( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - ) - ) - if potential_peak > peak_memory: - info.limiting_factor = ( - f"peak memory new:{potential_peak} vs base:{peak_memory}" - ) - break + Transforms (moves candidate to the left): + group_head_prev -> group_head...group_tail -> candidate -> candidate_next + Into: + group_head_prev -> candidate -> group_head...group_tail -> candidate_next - info.moves += 1 - info.moves_info += f"+{candidate.get_name()}" + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list - _perform_double_linked_list_swap(candidate, group_head, group_tail) + Returns: + New head of the linked list (may change if group_head was the head) + """ + # 0: Update group_head's previous node + group_head_prev = prev_dict[group_head] + if group_head_prev: + next_dict[group_head_prev] = candidate + prev_dict[candidate] = group_head_prev + + # 2: Update candidate's next node + candidate_next = next_dict[candidate] + if candidate_next: + prev_dict[candidate_next] = group_tail + next_dict[group_tail] = candidate_next + + # 1: Link candidate to group_head + prev_dict[group_head] = candidate + next_dict[candidate] = group_head + + # Update head if group_head was the head + if group_head == head: + return candidate + return head - _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - _post_alloc_update, - _size_free_delta_update, - ) - if debug_iterative_memory_recompute: - from .comms_debug import _debug_iterative_memory_recompute - - iterative_recompute_error = _debug_iterative_memory_recompute( - candidate, - gns, - _group_names(gns), - _group_nodes(_head, None), - name_to_freeable_input_buf, - graph_outputs, - peak_memory, - _curr_memory, - snodes_allocfree, - "sink_waits_iterative", - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - ) - if iterative_recompute_error: - break +def _format_and_log_sink_waits_stats( + stats: dict[BaseSchedulerNode, SinkWaitInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + 
original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format sink_waits statistics, log them, and return final node list. + + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. - candidate = _next[group_tail] - curr = _prev[curr] # type: ignore[assignment] + Args: + stats: Per-node sink_waits statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + Returns: + Final reordered list of scheduler nodes + """ headers = [ "Wait node", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", + "limiting factor", "grouped", "grouped_info", "moves", "moves_info", - "limiting factor", + "overlap_info", ] rows = [ [ node_summary(snode), + info.comm_time / 1e3, + info.comp_time / 1e3, + info.initial_exposed / 1e3, + info.final_exposed / 1e3, + info.improvement / 1e3, + info.limiting_factor, info.grouped, info.grouped_info, info.moves, info.moves_info, - info.limiting_factor, + info.overlap_info, ] for snode, info in stats.items() ] @@ -1677,7 +1624,7 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - new_snodes = _group_nodes_from_linked_list(_head, None, _next) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -1692,18 +1639,409 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - return new_snodes, stats + return new_snodes + + +def _find_buffers_with_changed_last_use_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping in sink_waits pass. + When we swap [group] candidate to candidate [group], some buffers that + were last used by candidate will now be last used by a group node instead. + This is the opposite direction from the reorder version. 
-def sink_waits_iterative( + Args: + candidate: The node being moved (currently last use) + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes + + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if snode_last_use != candidate: # noqa: E711 + continue + # candidate is last use of buf + last_succ_gn = None + for gn in gns: + if gn in succ_nodes: + last_succ_gn = gn + if last_succ_gn is None: + continue + + # gn has successors of buf that after potential swap will become + # last use of buf and start deallocating buf instead of candidate + group_n_to_bufs_after_swap_dealloc_instead_of_candidate[last_succ_gn].append( + buf + ) + + return group_n_to_bufs_after_swap_dealloc_instead_of_candidate + + +def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], -) -> list[BaseSchedulerNode]: +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + + curr: Optional[BaseSchedulerNode] = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + debug_num_sink_waits_to_reorder: Optional[int] = ( + config_comms.sink_waits_iterative_debug_limit_to_sink + ) + + iterative_recompute_error = False + while curr is not None and _prev[curr] is not None: + _prev_curr = _prev[curr] + if iterative_recompute_error: + break + if ( + debug_num_sink_waits_to_reorder is not None + and len(processed_waits) >= debug_num_sink_waits_to_reorder + ): + break + + # pyrefly: ignore [bad-argument-type] + if not (contains_wait(curr) and curr not in processed_waits): + curr = _prev_curr + continue + + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_colls = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] + + while candidate is not None: + if config_comms.sink_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.sink_iterative_extra_comm_comp_overlap * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break 
+ + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + wait_snode.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. + data_deps = { + d.name: d for d in candidate.unmet_dependencies if not _is_fake_dep(d) + } + + group_outs = group.get_outputs() + data_dep = None + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + # Conservative sink wait, limiting by space before next collective. + # The global strategy is that bucketing should create space. + # For 2D we can experiment with allowing to sink Wait beyond non current group collective. + # pyrefly: ignore[unbound-name] + if not config_comms.sink_waits_iterative_swap_with_collectives: + if contains_async_collective(candidate): + info.limiting_factor = ( + f"candidate contains_async_collective {candidate.get_name()}" + ) + break + + # 1. If we have data_dep - we can not swap => trying to group + # 2. If swap candidate and current node both contain collectives => trying to group + if data_dep is not None or ( + both_contain_comms := ( + contains_collective(group) and contains_collective(candidate) + ) + ): + _is_groupable, groupable_reason = _is_node_groupable_for_sink_waits( + candidate + ) + if _is_groupable: + group_tail = candidate + if ( + # pyrefly: ignore[unbound-name] + config_comms.sink_iterative_use_runtime_estimations + and contains_collective(candidate) + ): + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + group_colls[candidate] = (comm_time, comp_time) + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _next[candidate] + continue + elif data_dep is None: + if ( + # pyrefly: ignore[unbound-name] + not config_comms.sink_waits_iterative_unsafe_collectives_reorder + and both_contain_comms + ): + info.limiting_factor = ( + f"collective ordering {_group_names(gns)}" + f"\n with candidate:{candidate.get_name()}" + ) + break + else: + info.limiting_factor = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"\n dep on {_group_names(gns)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{groupable_reason}" + ) + break + + # pyrefly: ignore[unbound-name] + if config_comms.sink_iterative_use_runtime_estimations: + if is_wait(candidate.node): + # Corresponding collective is before the group, + # Swap can increase exposed time of corresponding collective + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), runtimes + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + # We do not know how much we can sink more after this swap, + # Just comparing advantage at the moment for now. 
+ if exposed_after > exposed_before: + info.limiting_factor = ( + "candidate is wait," + f" exposed_before:{exposed_before} vs exposed_after:{exposed_after}" + ) + break + + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + # If candidate has sync runtime, + # Waits of gorup_colls are on the right from group. + # Swap can increase their exposed time. + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_colls) > 0: + # Advantage for current Wait to do the Swap + # pyrefly: ignore[no-matching-overload] + exposed_delta = max( + 0, + info.comm_time - info.comp_time, + ) + # pyrefly: ignore[no-matching-overload] + -max(0, info.comm_time - info.comp_time - c_runtime) + for gc, (gc_comm_time, gc_comp_time) in group_colls.items(): + exposed_delta += max(0, gc_comm_time - gc_comp_time) - max( + 0, gc_comm_time - gc_comp_time + c_runtime + ) + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}, group contains collectives," + f" total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gc, ( + gc_comm_time, + gc_comp_time, + ) in group_colls.items(): + group_colls[gc] = ( + gc_comm_time, + gc_comp_time - c_runtime, + ) + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # [group] candidate -> candidate [group] + # Check for buffers with successors in group and candidate last successor + # + # Buf that changes its last use snode, + # It was deallocated by candidate, + # but after swap it will be deallocated by group node. + group_n_to_bufs_after_swap_dealloc_instead_of_candidate = ( + _find_buffers_with_changed_last_use_sink_waits( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update, _size_free_delta_update = ( + _calculate_potential_peak_memory_sink_waits( + candidate, + gns, + group_head, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _curr_memory, + snodes_allocfree, + ) + ) + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.sink_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + + info.moves += 1 + info.moves_info += f"+{candidate.get_name()}" + + _head = _perform_double_linked_list_swap_sink_waits( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + _update_memory_tracking_after_swap_sink_waits( + candidate, + gns, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _post_alloc_update, + _size_free_delta_update, + _curr_memory, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "sink_waits_iterative", + 
group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + ) + if iterative_recompute_error: + break + + candidate = _next[group_tail] + curr = _prev_curr + + new_snodes = _format_and_log_sink_waits_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + + return new_snodes, stats + + +def sink_waits_iterative(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Similarly to reorder_communication_preserving_peak_memory this pass will try to iteratively + push Wait nodes later, recomputing estimated peak memory before each swap, + and preventing peak memory regressions. + + Pass will be applied to every Wait node. If there are immediate dependencies with next node, + pass will try to group them together and on the next step to swap the group with next candidate. + + If _inductor.config_comms.sink_iterative_use_runtime_estimations is set True, + pass will stop reordering of Wait once corresponding Collective is unexposed, + based on runtime estimations. + + inductor.config_comms.sink_iterative_peak_memory_budget allows to tune how much pass + can regress initial peak memory. + E.g.: + sink_iterative_peak_memory_budget == 0.0 - No regression of initial peak memory is allowed + sink_iterative_peak_memory_budget == 0.2 - Pass can improve comm-compute overlap, sacrificing + 20% of initial peak memory value. + + inductor.config_comms.sink_iterative_extra_comm_comp_overlap config allows to more aggressively + sink waits, stopping only when overlap_compute >= (1 + extra_comm_comp_overlap) * comm_time + """ return _sink_waits_iterative_internal(snodes)[0] def estimate_op_runtime(snode: BaseSchedulerNode) -> float: """ - Returns estimated op runtime in nanoseconds (ns) + Returns estimated op runtime in milliseconds (ms) """ if config.estimate_op_runtime == "default": runtime = snode.get_estimated_runtime() diff --git a/torch/_inductor/config_comms.py b/torch/_inductor/config_comms.py index 51242c7f2cf5b..31f38b867dd5e 100644 --- a/torch/_inductor/config_comms.py +++ b/torch/_inductor/config_comms.py @@ -46,17 +46,26 @@ # when overlap_comp >= (1 + extra_overlap_ratio) * comm_time # Allows to configure more aggressive overlap reorder_iterative_extra_comm_comp_overlap: float = 0.5 +# The sink waits reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive sink waits +sink_iterative_extra_comm_comp_overlap: float = 0.5 # Allow reorder iterative pass to increase peak memory # up to peak_memory_before_pass * (1 + budget) reorder_iterative_peak_memory_budget: float = 0.2 +# Allow sink waits iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +sink_iterative_peak_memory_budget: float = 0.2 # Experimental unsafe configuration that allows changing relative collectives order. 
# Must be used with runtime_estimations_align_across_all_distributed_ranks = True reorder_iterative_unsafe_collectives_reorder: bool = True +sink_waits_iterative_unsafe_collectives_reorder: bool = True # Allow group and move other collectives during reordering reorder_iterative_group_with_collectives: bool = False +sink_waits_iterative_swap_with_collectives: bool = False # adds patch, save_config, etc install_config_module(sys.modules[__name__]) From 3fdc5dbf1d1742ed49aeebc190db28835fd6ddbf Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 6 Nov 2025 07:24:19 -0800 Subject: [PATCH 129/651] Make CUDA preload logic more straightforward (#167046) I.e. remove distinction between two cases, and always preload full set of libraries For some reason, when one uses `virtualenv` instead of `venv`, preloading `cudart` works, but it fails to find cudnn or cublasLT later on Fix it, by getting read of partial preload logic for one of the cases and always preload full set of libraries Test plan on stock Ubuntu: ``` pip install virtualenv virtualenv --symlinks -p python3.11 --prompt virtv venv-virt source venv-virt/bin/activate pip install torch python -c 'import torch' ``` Fixes https://github.com/pytorch/pytorch/issues/165812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167046 Approved by: https://github.com/atalman --- torch/__init__.py | 73 +++++++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index 05a34bdd93200..b64961a9c56f6 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -303,8 +303,8 @@ def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]: return nvidia_lib_paths + lib_paths -def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] - """Preloads cuda deps if they could not be found otherwise.""" +def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] + """Preloads cuda library if it could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed assert platform.system() == "Linux", "Should only be called on Linux" @@ -320,6 +320,39 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> ctypes.CDLL(lib_path) +def _preload_cuda_deps(err: _Optional[OSError] = None) -> None: + cuda_libs: dict[str, str] = { + "cublas": "libcublas.so.*[0-9]", + "cudnn": "libcudnn.so.*[0-9]", + "cuda_nvrtc": "libnvrtc.so.*[0-9]", + "cuda_runtime": "libcudart.so.*[0-9]", + "cuda_cupti": "libcupti.so.*[0-9]", + "cufft": "libcufft.so.*[0-9]", + "curand": "libcurand.so.*[0-9]", + "nvjitlink": "libnvJitLink.so.*[0-9]", + "cusparse": "libcusparse.so.*[0-9]", + "cusparselt": "libcusparseLt.so.*[0-9]", + "cusolver": "libcusolver.so.*[0-9]", + "nccl": "libnccl.so.*[0-9]", + "nvshmem": "libnvshmem_host.so.*[0-9]", + "cufile": "libcufile.so.*[0-9]", + } + + # If error is passed, re-raise it if it's not about one of the abovementioned + # libraries + if err is not None and [ + lib for lib in cuda_libs.values() if lib.split(".", 1)[0] in err.args[0] + ]: + raise err + + # Otherwise, try to preload dependencies from site-packages + for lib_folder, lib_name in cuda_libs.items(): + _preload_cuda_lib(lib_folder, lib_name) + + # libnvToolsExt is Optional Dependency + _preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False) + + # See Note [Global dependencies] def _load_global_deps() -> None: 
if platform.system() == "Windows": @@ -346,43 +379,15 @@ def _load_global_deps() -> None: # libtorch_global_deps.so always depends in cudart, check if its installed and loaded if "libcudart.so" not in _maps: return - # If all above-mentioned conditions are met, preload nvrtc and nvjitlink - _preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]") - _preload_cuda_deps("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]") - _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]") + # If all above-mentioned conditions are met, preload CUDA dependencies + _preload_cuda_deps() except Exception: pass except OSError as err: - # Can only happen for wheel with cuda libs as PYPI deps + # Can happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is - cuda_libs: dict[str, str] = { - "cublas": "libcublas.so.*[0-9]", - "cudnn": "libcudnn.so.*[0-9]", - "cuda_nvrtc": "libnvrtc.so.*[0-9]", - "cuda_runtime": "libcudart.so.*[0-9]", - "cuda_cupti": "libcupti.so.*[0-9]", - "cufft": "libcufft.so.*[0-9]", - "curand": "libcurand.so.*[0-9]", - "nvjitlink": "libnvJitLink.so.*[0-9]", - "cusparse": "libcusparse.so.*[0-9]", - "cusparselt": "libcusparseLt.so.*[0-9]", - "cusolver": "libcusolver.so.*[0-9]", - "nccl": "libnccl.so.*[0-9]", - "nvshmem": "libnvshmem_host.so.*[0-9]", - "cufile": "libcufile.so.*[0-9]", - } - - is_cuda_lib_err = [ - lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] - ] - if not is_cuda_lib_err: - raise err - for lib_folder, lib_name in cuda_libs.items(): - _preload_cuda_deps(lib_folder, lib_name) - - # libnvToolsExt is Optional Dependency - _preload_cuda_deps("nvtx", "libnvToolsExt.so.*[0-9]", required=False) + _preload_cuda_deps(err) ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) From bfc0ba4af97a1169d6fee5692fae34051b750a12 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Thu, 6 Nov 2025 16:50:12 +0000 Subject: [PATCH 130/651] `nn.Linear`: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (#166071) As per title. It seems safe to be able to generalize to arbitrary contiguous inputs since `at::matmul` is likely to do the flattening to avoid `baddmm`. Additionally, we guard for bias to be 1D and contiguous which is guaranteed to be fused with no copies. 
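
A rough sketch of the kind of call that now takes the flattened `addmm` path. This is illustrative only (not part of the change itself); the shapes and the COO layout are assumptions, and sparse `addmm` support can vary by build:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 5, 16)            # contiguous nD input
w = torch.randn(32, 16).to_sparse()  # sparse weight
b = torch.randn(32)                  # 1D contiguous bias

out = F.linear(x, w, b)              # flattened to 2D and dispatched to addmm
print(out.shape)                     # torch.Size([8, 5, 32])
```
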
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166071 Approved by: https://github.com/ngimel --- aten/src/ATen/native/Linear.cpp | 61 ++++++++++++++++++++--------- test/profiler/test_profiler_tree.py | 6 +-- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 1da245972f0cb..fbabba84dbb2d 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() { // `_flatten_nd_linear` flattens all but the last dimension of the input tensor // before passing it to linear operation static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) { - const auto input_sizes = input.sym_sizes(); - // can't use -1 in reshape because it errors when a dimension is 0 - c10::SymInt flattened_dim = 1; - for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) { - flattened_dim = flattened_dim * input_sizes[i]; + const auto input_sizes = input.sym_sizes(); + + const auto result_flattened = [&]() -> Tensor { + const auto input_ncols = input_sizes.back(); + const auto input_flattened_nrows = [&]() -> c10::SymInt { + // can't use -1 in reshape because it errors when a dimension is 0 + auto flattened_nrows = c10::SymInt{1}; + for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) { + flattened_nrows *= size; + } + return flattened_nrows; + }(); + + const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols}); + if (weight.layout() == c10::kStrided) { + return at::addmm(bias, input_flattened, weight.t()); + } else { + // weight is sparse, and addmm for sparse expects matmul lhs to be sparse, + // so we transpose the problem. + // NOTE: at::matmul handles (dense @ sparse) similarly. + const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1); + return at::addmm(bias_t, weight, input_flattened.t()).t(); } - auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)}); - const auto result = at::addmm(bias, inp_reshape, weight.t()); - auto new_size = input_sizes.slice(0, input_sizes.size() - 1); - c10::SymDimVector sizes_vec(new_size.begin(), new_size.end()); - sizes_vec.push_back(result.sym_size(1)); - return result.view_symint(sizes_vec); + }(); + + // Unflatten flattened row dims + auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()}; + result_sizes.back() = result_flattened.sym_size(1); + return result_flattened.view_symint(result_sizes); } @@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optionaldefined() && !input.is_xla()) { - // Also hit the fused path for contiguous 3D input, if not using xla + + const auto is_bias_likely_fusable = ( + bias->defined() && + // cuBLASLt: will fuse in the epilogue without copies + // when input/weight/bias are all strided. + // When weight is not strided, bias will not be fused, + // but we can still dispatch here to avoid at::matmul + // path which will probably use a very similar + // flattening optimization. + ((bias->dim() == 1 || bias->squeeze().dim() == 1) && bias->is_contiguous_or_false()) + ); + if (is_bias_likely_fusable && !input.is_xla()) { + // Also hit the fused path for contiguous nD input, if not using xla // backend. Reshaping/flattening has some performance implications on xla. 
- bool is_contiguous = input.is_contiguous_or_false(); - if (is_contiguous && input_dim == 3) { - return _flatten_nd_linear(input, weight, *bias); - } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { + if (input.is_contiguous_or_false()) { return _flatten_nd_linear(input, weight, *bias); - } else if (parseLinearFlatten3d() && input_dim == 3) { + } else if (parseLinearFlatten3d()) { // If user forces flattening via env var const Tensor input_cont = input.contiguous(); return _flatten_nd_linear(input_cont, weight, *bias); diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index c6316fe3cd7e3..e8d28d7eff032 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -624,8 +624,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch/nn/modules/module.py(...): __getattr__ aten::linear - aten::reshape - aten::view + aten::view aten::t aten::transpose aten::as_strided @@ -671,8 +670,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch/nn/modules/module.py(...): __getattr__ aten::linear - aten::reshape - aten::view + aten::view aten::t aten::transpose aten::as_strided From 41c9eeecec2360edd2c08239cb6c7b62b5eea123 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 6 Nov 2025 17:14:26 +0000 Subject: [PATCH 131/651] Update Sphinx dependencies (#164901) This pull request updates the PyTorch documentation build system to support newer versions of Sphinx and its related dependencies, improves coverage checking for undocumented objects, and adds configuration enhancements to the docs build. The most important changes are grouped below. **Dependency Upgrades and Compatibility:** * Upgraded `sphinx` to version 7.2.6 and updated related documentation dependencies (`breathe`, `exhale`, `docutils`, `myst-nb`, `sphinx-design`, `myst-parser`, and others) in `.ci/docker/requirements-docs.txt` to ensure compatibility with Python 3.13 and improve documentation generation. [[1]](diffhunk://#diff-b5577a8e38a2e4c5d91865096b259738cc1dbcb97921abb73045dae0255b1479L1-L12) [[2]](diffhunk://#diff-b5577a8e38a2e4c5d91865096b259738cc1dbcb97921abb73045dae0255b1479L39-R45) [[3]](diffhunk://#diff-b5577a8e38a2e4c5d91865096b259738cc1dbcb97921abb73045dae0255b1479L59-R64) * Replaced the editable install of `pytorch_sphinx_theme2` with a pinned version for stability in documentation builds. **Documentation Coverage and Build Improvements:** * Updated the coverage check logic in `.ci/pytorch/python_doc_push_script.sh` to parse the new Sphinx 7.2.6+ coverage report format, extracting the undocumented count from the statistics table for more reliable coverage validation. **Configuration and Formatting Enhancements:** * Introduced `autosummary_filename_map` in `docs/source/conf.py` to resolve duplicated autosummary output filenames for functions and classes with the same name, improving documentation clarity. **Minor Documentation Formatting:** * Removed an unused `:template:` directive from `docs/source/quantization-support.md` for cleaner autosummary output. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164901 Approved by: https://github.com/albanD --- .ci/docker/requirements-docs.txt | 34 ++++++++++------------- .ci/pytorch/python_doc_push_script.sh | 40 +++++++++++++++++++-------- docs/source/conf.py | 40 +++++++++++++++++++++++++++ docs/source/quantization-support.md | 1 - 4 files changed, 84 insertions(+), 31 deletions(-) diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 6e623b4c56949..de71919012e13 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -1,15 +1,11 @@ -sphinx==5.3.0 +sphinx==7.2.6 #Description: This is used to generate PyTorch docs -#Pinned versions: 5.3.0 +#Pinned versions: 7.2.6 -standard-imghdr==3.13.0; python_version >= "3.13" -#Description: This is needed by Sphinx, so it needs to be added here. -# The reasons are as follows: -# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr); -# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13. -# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency. +pytorch_sphinx_theme2==0.2.0 +#Description: This is needed to generate PyTorch docs +#Pinned versions: 0.2.0 --e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering # but it doesn't seem to work and hangs around idly. The initial thought that it is probably # something related to Docker setup. We can investigate this later. @@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13" #Description: This is used to generate PyTorch docs #Pinned versions: 2.13.0 -breathe==4.34.0 +breathe==4.36.0 #Description: This is used to generate PyTorch C++ docs -#Pinned versions: 4.34.0 +#Pinned versions: 4.36.0 -exhale==0.2.3 +exhale==0.3.7 #Description: This is used to generate PyTorch C++ docs -#Pinned versions: 0.2.3 +#Pinned versions: 0.3.7 -docutils==0.16 +docutils==0.20 #Description: This is used to generate PyTorch C++ docs -#Pinned versions: 0.16 +#Pinned versions: 0.20 bs4==0.0.1 #Description: This is used to generate PyTorch C++ docs @@ -56,13 +52,13 @@ IPython==8.12.0 #Description: This is used to generate PyTorch functorch docs #Pinned versions: 8.12.0 -myst-nb==0.17.2 +myst-nb==1.3.0 #Description: This is used to generate PyTorch functorch and torch.compile docs. -#Pinned versions: 0.17.2 +#Pinned versions: 1.3.0 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs python-etcd==0.4.5 sphinx-copybutton==0.5.0 -sphinx-design==0.4.0 +sphinx-design==0.6.1 sphinxcontrib-mermaid==1.0.0 -myst-parser==0.18.1 +myst-parser==4.0.1 diff --git a/.ci/pytorch/python_doc_push_script.sh b/.ci/pytorch/python_doc_push_script.sh index ec1187b3fe4c4..6bcd46c4815a6 100755 --- a/.ci/pytorch/python_doc_push_script.sh +++ b/.ci/pytorch/python_doc_push_script.sh @@ -89,23 +89,41 @@ if [ "$is_main_doc" = true ]; then make coverage # Now we have the coverage report, we need to make sure it is empty. - # Count the number of lines in the file and turn that number into a variable - # $lines. The `cut -f1 ...` is to only parse the number, not the filename - # Skip the report header by subtracting 2: the header will be output even if - # there are no undocumented items. 
+ # Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row + # showing the undocumented count in the third column. + # Example: | TOTAL | 99.83% | 2 | # # Also: see docs/source/conf.py for "coverage_ignore*" items, which should # be documented then removed from there. - lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ') - undocumented=$((lines - 2)) - if [ $undocumented -lt 0 ]; then + + # Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table + # The table format is: | Module | Coverage | Undocumented | + # Extract the third column (undocumented count) from the TOTAL row + undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ') + + if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then echo coverage output not found exit 1 - elif [ $undocumented -gt 0 ]; then - echo undocumented objects found: - cat build/coverage/python.txt + elif [ "$undocumented" -gt 0 ]; then + set +x # Disable command echoing for cleaner output + echo "" + echo "=====================" + echo "UNDOCUMENTED OBJECTS:" + echo "=====================" + echo "" + # Find the line number of the TOTAL row and print only what comes after it + total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1) + if [ -n "$total_line" ]; then + # Print only the detailed list (skip the statistics table) + tail -n +$((total_line + 2)) build/coverage/python.txt + else + # Fallback to showing entire file if TOTAL line not found + cat build/coverage/python.txt + fi + echo "" echo "Make sure you've updated relevant .rsts in docs/source!" - echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'" + echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'" + set -x # Re-enable command echoing exit 1 fi else diff --git a/docs/source/conf.py b/docs/source/conf.py index b5a04df3e090b..9a06c0e2036d2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -206,6 +206,41 @@ os.path.join(os.path.dirname(pytorch_sphinx_theme2.__file__), "templates"), ] # TODO: document these and remove them from here. 
+# Fixes the duplicated +autosummary_filename_map = { + "torch.nn.utils.prune.identity": "torch.nn.utils.prune.identity_function", + "torch.nn.utils.prune.Identity": "torch.nn.utils.prune.Identity_class", + "torch.optim.adamw.adamw": "torch.optim.adamw.adamw_function", + "torch.optim.adamw.AdamW": "torch.optim.adamw.AdamW_class", + "torch.optim.asgd.asgd": "torch.optim.asgd.asgd_function", + "torch.optim.asgd.ASGD": "torch.optim.asgd.ASGD_class", + "torch.optim.nadam.nadam": "torch.optim.nadam.nadam_function", + "torch.optim.nadam.NAdam": "torch.optim.nadam.NAdam_class", + "torch.optim.radam.radam": "torch.optim.radam.radam_function", + "torch.optim.radam.RAdam": "torch.optim.radam.RAdam_class", + "torch.optim.rmsprop.rmsprop": "torch.optim.rmsprop.rmsprop_function", + "torch.optim.rmsprop.RMSprop": "torch.optim.rmsprop.RMSprop_class", + "torch.optim.rprop.rprop": "torch.optim.rprop.rprop_function", + "torch.optim.rprop.Rprop": "torch.optim.rprop.Rprop_class", + "torch.optim.sgd.sgd": "torch.optim.sgd.sgd_function", + "torch.optim.sgd.SGD": "torch.optim.sgd.SGD_class", + "torch.optim.adadelta.adadelta": "torch.optim.adadelta.adadelta_function", + "torch.optim.adadelta.Adadelta": "torch.optim.adadelta.Adadelta_class", + "torch.optim.adagrad.adagrad": "torch.optim.adagrad.adagrad_function", + "torch.optim.adagrad.Adagrad": "torch.optim.adagrad.Adagrad_class", + "torch.optim.adam.adam": "torch.optim.adam.adam_function", + "torch.optim.adam.Adam": "torch.optim.adam.Adam_class", + "torch.optim.adamax.adamax": "torch.optim.adamax.adamax_function", + "torch.optim.adamax.Adamax": "torch.optim.adamax.Adamax_class", + "torch.mtia.stream": "torch.mtia.stream_function", + "torch.mtia.Stream": "torch.mtia.Stream_class", + "torch.cpu.stream": "torch.cpu.stream_function", + "torch.cpu.Stream": "torch.cpu.Stream_class", + "torch.cuda.stream": "torch.cuda.stream_function", + "torch.cuda.Stream": "torch.cuda.Stream_class", + "torch.xpu.stream": "torch.xpu.stream_function", + "torch.xpu.Stream": "torch.xpu.Stream_class", +} coverage_ignore_functions = [ # torch @@ -3195,6 +3230,11 @@ def linkcode_resolve(domain, info): # Enable overriding of function signatures in the first line of the docstring. autodoc_docstring_signature = True +# Exclude inherited IntEnum methods that have RST formatting issues in their docstrings +autodoc_default_options = { + "exclude-members": "from_bytes, to_bytes", +} + # -- katex javascript in header # # def setup(app): diff --git a/docs/source/quantization-support.md b/docs/source/quantization-support.md index 986b1cb257513..3bb5c45face69 100644 --- a/docs/source/quantization-support.md +++ b/docs/source/quantization-support.md @@ -253,7 +253,6 @@ regular full-precision tensor. .. autosummary:: :toctree: generated :nosignatures: - :template: classtemplate.rst view as_strided From fd7bf9ce1021a42ab5c3839152b3072a4b2bb25c Mon Sep 17 00:00:00 2001 From: karthickai Date: Wed, 5 Nov 2025 17:28:59 -0800 Subject: [PATCH 132/651] [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890) When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error. 
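
A minimal repro sketch of the failing pattern (it mirrors the `test_unbacked_float_item` test added below; the device and shapes are illustrative):

```python
import torch

@torch.compile
def fn(x, max_val):
    # .item() on a float tensor introduces an unbacked float symbol (e.g. zuf0)
    return torch.clamp(x, 0, max_val.item())

x = torch.randn(10, 20, 30, device="cuda")
fn(x, torch.tensor(5.0, device="cuda"))
```

Before this fix, the generated Triton kernel referenced that symbol without receiving it as an argument; the codegen changes below add it to the kernel signature (and to the C++ wrapper argument parsing).
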
Fixes: #166888 #163674 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890 Approved by: https://github.com/eellison, https://github.com/mlazos --- test/inductor/test_torchinductor.py | 14 +++++++++ test/test_torchfuzz_repros.py | 39 ------------------------- torch/_inductor/codecache.py | 6 ++++ torch/_inductor/codegen/common.py | 11 +++++-- torch/_inductor/codegen/triton_utils.py | 5 ++++ 5 files changed, 34 insertions(+), 41 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ed8993a1c9a39..d0ff5799ac417 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14424,6 +14424,20 @@ def fn(x): self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),)) + @skip_if_halide + @requires_cuda_and_triton + def test_unbacked_float_item(self): + def fn(x, max_val): + return torch.clamp(x, 0, max_val.item()) + + self.common( + fn, + ( + torch.randn(10, 20, 30, device=self.device), + torch.tensor(5.0, device=self.device), + ), + ) + # end of class CommonTemplate - add new tests here diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 988bcf8de273c..c278378e2cc4a 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -422,45 +422,6 @@ def foo(arg0): out_compiled.sum().backward() print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #163674") - def test_fuzzer_issue_163674(self): - torch.manual_seed(0) - - def foo(arg0, arg1, arg2): - t0 = arg0 # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float16, device=cuda - t1 = t0.clone() - t1.zero_() # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float16, device=cuda - t2 = arg1 # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda - t3 = arg2 # size=(), stride=(), dtype=float32, device=cuda - t4 = t2.clone() - t4.fill_( - t3.item() - ) # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda - t5 = torch.pow( - t1, t4 - ) # size=(79488, 1, 3, 1), stride=(3, 3, 1, 1), dtype=float32, device=cuda - t6 = t5.reshape( - (96, 69, 36) - ) # size=(96, 69, 36), stride=(2484, 36, 1), dtype=float32, device=cuda - output = t6 - return output - - arg0 = torch.rand( - [79488, 1, 3, 1], dtype=torch.float16, device="cuda", requires_grad=True - ) - arg1 = torch.rand( - [79488, 1, 3, 1], dtype=torch.float32, device="cuda", requires_grad=True - ) - arg2 = torch.rand([], dtype=torch.float32, device="cuda", requires_grad=True) - - out_eager = foo(arg0, arg1, arg2) - out_eager.sum().backward() - print("Eager Success! ✅") - compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True) - out_compiled = compiled_foo(arg0, arg1, arg2) - out_compiled.sum().backward() - print("Compile Success! 
✅") - if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 9583494299265..0177f6900c611 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache): throw std::runtime_error("expected int arg"); return reinterpret_cast(result); }} + template <> inline float parse_arg(PyObject* args, size_t n) {{ + auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1.0 && PyErr_Occurred())) + throw std::runtime_error("expected float arg"); + return static_cast(result); + }} {extra_parse_arg} diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 730c03f1c813c..3e9f174c810c5 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1732,9 +1732,15 @@ def cpp_argdefs( call_args.append(self.wrap_ptr_arg(outer, dtype)) arg_types.append(f"{cpp_dtype}*") for outer, inner in self.sizevars.items(): - arg_defs.append(f"const {INDEX_TYPE} {inner}") + if isinstance(outer, sympy.Symbol) and symbol_is_type( + outer, (SymT.UNBACKED_FLOAT) + ): + arg_defs.append(f"const float {inner}") + arg_types.append("const float") + else: + arg_defs.append(f"const {INDEX_TYPE} {inner}") + arg_types.append(f"const {INDEX_TYPE}") call_args.append(self.wrap_size_arg(outer)) - arg_types.append(f"const {INDEX_TYPE}") if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) assert not self.workspace_args, "Workspace not supported on CPU " @@ -2353,6 +2359,7 @@ def rename_indexing( SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, + SymT.UNBACKED_FLOAT, ), ) } diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 2a2706ad5720b..75a34813c876b 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -4,6 +4,7 @@ import sympy import torch +from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config from ..runtime.hints import AttrsDescriptorWrapper @@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: return "constexpr" elif isinstance(arg.expr, (float, sympy.Float)): return "fp32" + elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type( + arg.expr, (SymT.UNBACKED_FLOAT) + ): + return "fp32" elif isinstance(arg.expr, bool): return "i1" From 03fd2b796e6d3469da23504365a20f4e91ff9117 Mon Sep 17 00:00:00 2001 From: VieEeEw <34191413+VieEeEw@users.noreply.github.com> Date: Thu, 6 Nov 2025 09:16:29 -0800 Subject: [PATCH 133/651] [Flight Recorder] Reverted to include stack traces for dump pipe triggered FR dump (#167023) [Flight Recorder] Reverted to include stack traces for dump pipe triggered FR dump (#167023) Summary: We should also retry if include stacktraces failed. 
Changed was introduced in https://github.com/pytorch/pytorch/pull/164591 Test Plan: eyes Reviewed By: fduwjj Differential Revision: D86248484 --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 29ccc115cc94d..ccb1e7466157b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1883,7 +1883,7 @@ void ProcessGroupNCCL::HeartbeatMonitor::runLoop() { LOG(INFO) << pg_->logPrefix() << "Dump signal received through pipe, triggering FR dump."; futures.emplace_back(std::async(std::launch::async, [this, onlyActive]() { - return this->pg_->dumpDebuggingInfo(false, onlyActive); + return this->pg_->dumpDebuggingInfo(true, onlyActive); })); } } From 0ed41194205226e40d0835ac6c7db3ee1337032d Mon Sep 17 00:00:00 2001 From: amdfaa <107946068+amdfaa@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:23:23 +0000 Subject: [PATCH 134/651] [ROCm][CI] Run rocm.yml and inductor-rocm.yml every 3rd hour (#167220) Even after [reducing frequency of rocm.yml and inductor-rocm.yml to per hour](https://github.com/pytorch/pytorch/pull/166870), we are still observing queueing on MI2xx runners as of Nov 6 2025 10:30AM CST: {DFECE929-174D-4EE4-9448-D43AA1AF0B53} We think it's because we had to move the periodic.yml workflow runs to the MI210 runners in light of the Cirrascale runners not being available: https://github.com/pytorch/pytorch/issues/166866. We observe [increased queueing](https://hud.pytorch.org/queue_time_analysis?dateRange=7&startDate=2025-10-30T16%3A00%3A48.381Z&endDate=2025-11-06T16%3A00%3A48.381Z&granularity=hour&chartType=bar&repos=pytorch%2Fpytorch&category=machine_type&machineTypes=linux.rocm.gpu.2&items=linux.rocm.gpu.2) after the point where we added periodic jobs to the MI210 runners. linux rocm gpu 2_queueing This PR temproarily changes the rocm.yml and inductor-rocm.yml workflows to run on a 3-hourly basis rather than every hour, until the Cirrascale outage is resolved. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167220 Approved by: https://github.com/jeffdaily --- .github/workflows/inductor-rocm.yml | 2 +- .github/workflows/rocm.yml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index b2ff53a645481..8dbc785e20f16 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -2,7 +2,7 @@ name: inductor-rocm on: schedule: - - cron: 0 * * * * + - cron: 0 */3 * * * push: branches: - release/* diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index ffe6efbe0433c..6f37d3e4f65a4 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -9,7 +9,8 @@ on: workflow_dispatch: schedule: - cron: 29 8 * * * # about 1:29am PDT - - cron: 0 * * * * + - cron: 0 */3 * * * + concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} From ea7add4837bbb4295a02a16e7becaa390b2ea729 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 5 Nov 2025 15:05:07 -0800 Subject: [PATCH 135/651] fix static_input_indices subclass remapping under training (#167127) We have some logic figure out "given which inputs have static indices in the pre-subclass-desugaring graph, figure out the static indices in the post-subclass-desugaring graph", and it was busted for training. Separately, we should probably not have to do this logic at all - as @eellison mentioned, inputs/outputs in the graph are less likely to be tweaked through graph passes, so it would be more convenient and less hassle if we just stashed if a given input was static directly on the Descriptor for it. I did not end up doing that in this PR though. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167127 Approved by: https://github.com/ezyang --- test/dynamo/test_subclasses.py | 40 +++++++++++++++++++ .../_aot_autograd/graph_capture_wrappers.py | 9 +++-- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 5d31fa28880a6..25c0da48f602f 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -2169,6 +2169,46 @@ def fn(t0, t1, t2): fn(torch.ones(4), x, torch.ones(4)) + @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + def test_subclass_parameters_are_static_under_training(self): + from collections.abc import Callable + from typing import Any, Optional + + from torch._inductor.compile_fx import compile_fx + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.utils import BoxedBool + + def inner_compile( + gm: torch.fx.GraphModule, + example_inputs: list[torch.Tensor], + cudagraphs: Optional[BoxedBool] = None, + static_input_idxs: Optional[list[int]] = None, + is_backward: bool = False, + graph_id: Optional[int] = None, + cpp_wrapper: bool = False, + aot_mode: bool = False, + is_inference: bool = False, + boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, + layout_opt: Optional[bool] = None, + extern_node_serializer: Optional[Callable[[list[Any]], Any]] = None, + ): + # Important bit: there are 3 params: linear.weight.a, linear.weight.b, linear.bias, + # which are the first 3 args of the graph. 
+ self.assertEqual(static_input_idxs, [0, 1, 2]) + return gm + + compiler = functools.partial(compile_fx, inner_compile=inner_compile) + + mod = torch.nn.Linear(4, 4) + w_a = torch.randn(4, 4) + w_b = torch.randn(4, 4) + w = torch.nn.Parameter(TwoTensor(w_a, w_b).requires_grad_()) + mod.weight = w + + mod = torch.compile(mod, backend=compiler) + + mod(torch.randn(4)) + # copied from common_utils.py::NestedTensorTestCase def assertEqualIgnoringNestedInts(self, a, b): # unbinding NJTs allows us to compare them as essentially equal without diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index d81b0e9d0bd24..bc4dc87ddeced 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -1272,15 +1272,18 @@ def inner_fw_only(*args): args_unwrapped = (primals_unwrapped_pair[0], tangents_unwrapped_pair[0]) args_descs_unwrapped = (primals_unwrapped_pair[1], tangents_unwrapped_pair[1]) + remapped_static_indices = remap_unwrapped_subclass_arg_indices( + args[0], meta.static_input_indices + ) else: args_unwrapped, args_descs_unwrapped = unwrap_tensor_subclasses( # type: ignore[assignment] args, # type: ignore[arg-type] args_descs, # type: ignore[arg-type] append_symints=True, ) - remapped_static_indices = remap_unwrapped_subclass_arg_indices( - args, meta.static_input_indices - ) + remapped_static_indices = remap_unwrapped_subclass_arg_indices( + args, meta.static_input_indices + ) if is_joint_structure: primals_unwrapped = args_unwrapped[0] # type: ignore[assignment] From 73078f305fb91d8f3e0101aac1c8eccfc98be839 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 6 Nov 2025 17:55:19 +0000 Subject: [PATCH 136/651] Add missing super().setUp() (#167163) In a trunk failure today, we saw the same test running on both trunk and slow shards. The reason is that this test didn't invoke `super().setUp()`, so all the test features like slow and disabled test didn't apply to them. I use Claude to find all test classes with a `setUp()` method that didn't called `super().setUp()` and patch all of them. 
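
The fix applied in each file is mechanical; a sketch of the intended pattern (the class name and fixture here are made up):

```python
import tempfile

from torch.testing._internal.common_utils import TestCase


class MyTest(TestCase):
    def setUp(self):
        super().setUp()  # without this, features like slow/disabled-test handling don't apply
        self.temp_dir = tempfile.mkdtemp()
```
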
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167163 Approved by: https://github.com/malfet --- test/ao/sparsity/test_scheduler.py | 1 + test/backends/xeon/test_launch.py | 1 + test/custom_backend/test_custom_backend.py | 1 + test/custom_operator/test_custom_ops.py | 1 + .../checkpoint/_experimental/test_builder.py | 1 + .../_experimental/test_checkpoint_process.py | 1 + .../_experimental/test_checkpoint_reader.py | 1 + .../_experimental/test_checkpoint_writer.py | 1 + .../_experimental/test_checkpointer.py | 2 ++ .../checkpoint/_experimental/test_staging.py | 1 + .../checkpoint/test_quantized_hf_storage.py | 1 + .../elastic/multiprocessing/test_api.py | 1 + test/distributed/launcher/test_api.py | 1 + test/distributed/test_c10d_gloo.py | 1 + test/distributed/test_c10d_nccl.py | 1 + test/distributed/test_run.py | 1 + test/distributed/test_serialization.py | 1 + test/distributed/test_store.py | 1 + test/export/test_lift_unlift.py | 2 ++ test/export/test_sparse.py | 2 +- test/export/test_upgrader.py | 1 + test/functorch/dim/test_getsetitem.py | 1 + test/functorch/test_ac_logging.py | 1 + test/fx/test_fx_split_node_finder.py | 1 + test/fx/test_graph_pickler.py | 2 +- test/fx/test_net_min_base.py | 1 + test/inductor/test_augmented_graph_helper.py | 1 + test/inductor/test_compile_subprocess.py | 2 +- test/inductor/test_mkldnn_pattern_matcher.py | 2 +- test/inductor/test_ordered_set.py | 24 +++++++++++++++++++ .../test_torchinductor_dynamic_shapes.py | 2 +- test/onnx/exporter/test_building.py | 1 + ...st_registraion.py => test_registration.py} | 1 + test/profiler/test_cpp_thread.py | 2 ++ test/test_cuda_primary_ctx.py | 1 + test/test_cuda_sanitizer.py | 2 ++ test/test_cuda_trace.py | 1 + test/test_monitor.py | 1 + test/test_nnapi.py | 1 + test/test_sparse.py | 2 +- test/test_torchfuzz_repros.py | 1 + test/test_utils.py | 1 + test/test_weak.py | 1 + 43 files changed, 70 insertions(+), 6 deletions(-) rename test/onnx/internal/{test_registraion.py => test_registration.py} (99%) diff --git a/test/ao/sparsity/test_scheduler.py b/test/ao/sparsity/test_scheduler.py index 0477b70fd8783..e7d4e8df90be5 100644 --- a/test/ao/sparsity/test_scheduler.py +++ b/test/ao/sparsity/test_scheduler.py @@ -75,6 +75,7 @@ def test_lambda_scheduler(self): class TestCubicScheduler(TestCase): def setUp(self): + super().setUp() self.model_sparse_config = [ {"tensor_fqn": "0.weight", "sparsity_level": 0.8}, {"tensor_fqn": "2.weight", "sparsity_level": 0.4}, diff --git a/test/backends/xeon/test_launch.py b/test/backends/xeon/test_launch.py index bab1006015212..311f622fafda2 100644 --- a/test/backends/xeon/test_launch.py +++ b/test/backends/xeon/test_launch.py @@ -11,6 +11,7 @@ @unittest.skipIf(not IS_LINUX, "Only works on linux") class TestTorchrun(TestCase): def setUp(self): + super().setUp() self._test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__) def tearDown(self): diff --git a/test/custom_backend/test_custom_backend.py b/test/custom_backend/test_custom_backend.py index 5faf5fa94d5ad..269cc98418c86 100644 --- a/test/custom_backend/test_custom_backend.py +++ b/test/custom_backend/test_custom_backend.py @@ -11,6 +11,7 @@ class TestCustomBackend(TestCase): def setUp(self): + super().setUp() # Load the library containing the custom backend. 
self.library_path = get_custom_backend_library_path() torch.ops.load_library(self.library_path) diff --git a/test/custom_operator/test_custom_ops.py b/test/custom_operator/test_custom_ops.py index 24bc4db520a89..e66ca04ec5c32 100644 --- a/test/custom_operator/test_custom_ops.py +++ b/test/custom_operator/test_custom_ops.py @@ -18,6 +18,7 @@ class TestCustomOperators(TestCase): def setUp(self): + super().setUp() self.library_path = get_custom_op_library_path() ops.load_library(self.library_path) diff --git a/test/distributed/checkpoint/_experimental/test_builder.py b/test/distributed/checkpoint/_experimental/test_builder.py index 9b2ba937eb4fd..64aacaf8c00cc 100644 --- a/test/distributed/checkpoint/_experimental/test_builder.py +++ b/test/distributed/checkpoint/_experimental/test_builder.py @@ -22,6 +22,7 @@ class TestMakeCheckpointer(TestCase): def setUp(self) -> None: + super().setUp() # Create a temporary directory for checkpoints self.temp_dir = tempfile.mkdtemp() diff --git a/test/distributed/checkpoint/_experimental/test_checkpoint_process.py b/test/distributed/checkpoint/_experimental/test_checkpoint_process.py index 1220d5f07235b..161dd1a80c3e1 100644 --- a/test/distributed/checkpoint/_experimental/test_checkpoint_process.py +++ b/test/distributed/checkpoint/_experimental/test_checkpoint_process.py @@ -161,6 +161,7 @@ def test_custom_options(self) -> None: class TestCheckpointProcess(TestCase): def setUp(self) -> None: + super().setUp() """Set up common test fixtures.""" self.rank_info = RankInfo( global_world_size=1, diff --git a/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py b/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py index 88feb0bffee5d..70d1d30facd70 100644 --- a/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py +++ b/test/distributed/checkpoint/_experimental/test_checkpoint_reader.py @@ -14,6 +14,7 @@ class TestCheckpointReader(TestCase): def setUp(self): + super().setUp() # Create a temporary directory for test checkpoints self.temp_dir = tempfile.mkdtemp() diff --git a/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py b/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py index c5141c6a1730e..959f1c9e7572d 100644 --- a/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py +++ b/test/distributed/checkpoint/_experimental/test_checkpoint_writer.py @@ -52,6 +52,7 @@ def test_custom_values(self): class TestCheckpointWriter(TestCase): def setUp(self): + super().setUp() # Create a temporary directory for test checkpoints self.temp_dir = tempfile.mkdtemp() diff --git a/test/distributed/checkpoint/_experimental/test_checkpointer.py b/test/distributed/checkpoint/_experimental/test_checkpointer.py index 62fde0b3166df..fbd19ff9eafad 100644 --- a/test/distributed/checkpoint/_experimental/test_checkpointer.py +++ b/test/distributed/checkpoint/_experimental/test_checkpointer.py @@ -52,6 +52,7 @@ class TestCheckpointer(TestCase): """Parameterized tests that work with both sync and async checkpointers.""" def setUp(self): + super().setUp() # Create a temporary directory for checkpoints self.temp_dir = tempfile.mkdtemp() @@ -397,6 +398,7 @@ class TestAsyncCheckpointerSpecific(TestCase): """Tests specific to AsyncCheckpointer functionality.""" def setUp(self): + super().setUp() # Create a temporary directory for checkpoints self.temp_dir = tempfile.mkdtemp() diff --git a/test/distributed/checkpoint/_experimental/test_staging.py 
b/test/distributed/checkpoint/_experimental/test_staging.py index 3fdb3bc022f25..5c4a1733fde03 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -12,6 +12,7 @@ class TestDefaultStager(TestCase): def setUp(self) -> None: + super().setUp() # Create a test state dictionary with various data types self.state_dict = { "model": torch.nn.Linear(10, 5).state_dict(), diff --git a/test/distributed/checkpoint/test_quantized_hf_storage.py b/test/distributed/checkpoint/test_quantized_hf_storage.py index c8ee756aaf3f4..da15cff68018c 100644 --- a/test/distributed/checkpoint/test_quantized_hf_storage.py +++ b/test/distributed/checkpoint/test_quantized_hf_storage.py @@ -15,6 +15,7 @@ class TestQuantizedHfStorage(TestCase): def setUp(self): + super().setUp() """Set up common test fixtures.""" self.temp_dir = tempfile.TemporaryDirectory() self.path = self.temp_dir.name diff --git a/test/distributed/elastic/multiprocessing/test_api.py b/test/distributed/elastic/multiprocessing/test_api.py index 9b145777e1457..109dc5b557d12 100644 --- a/test/distributed/elastic/multiprocessing/test_api.py +++ b/test/distributed/elastic/multiprocessing/test_api.py @@ -21,6 +21,7 @@ class SignalHandlingTest(TestCase): def setUp(self): + super().setUp() # Save original environment variable if it exists self.original_signals_env = os.environ.get( "TORCHELASTIC_SIGNALS_TO_HANDLE", None diff --git a/test/distributed/launcher/test_api.py b/test/distributed/launcher/test_api.py index e6e778fe2ff32..04cc17912cf48 100644 --- a/test/distributed/launcher/test_api.py +++ b/test/distributed/launcher/test_api.py @@ -16,6 +16,7 @@ class LauncherApiTest(TestCase): def setUp(self): + super().setUp() # Save original environment variable if it exists self.original_signals_env = os.environ.get( "TORCHELASTIC_SIGNALS_TO_HANDLE", None diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index ffd48407abd01..3e39cb4cbebbc 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2357,6 +2357,7 @@ def forward(self, x, use_fc3=True): class ReducerTest(TestCase): def setUp(self): + super().setUp() self.file = tempfile.NamedTemporaryFile(delete=False) world_size = 1 self.store = c10d.FileStore(self.file.name, world_size) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index ef7ed5282816f..512808757c40c 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -252,6 +252,7 @@ class ProcessGroupNCCLNoGPUTest(TestCase): MAIN_PROCESS_RANK = 0 def setUp(self): + super().setUp() self.rank = self.MAIN_PROCESS_RANK self.world_size = 1 self.file = tempfile.NamedTemporaryFile(delete=False) diff --git a/test/distributed/test_run.py b/test/distributed/test_run.py index 659241dbcbe99..509c08cbf0c35 100644 --- a/test/distributed/test_run.py +++ b/test/distributed/test_run.py @@ -17,6 +17,7 @@ class RunTest(TestCase): def setUp(self): + super().setUp() # Save original environment variable if it exists self.original_signals_env = os.environ.get( "TORCHELASTIC_SIGNALS_TO_HANDLE", None diff --git a/test/distributed/test_serialization.py b/test/distributed/test_serialization.py index 3adb099aa7a3b..6c1d82b5c18da 100644 --- a/test/distributed/test_serialization.py +++ b/test/distributed/test_serialization.py @@ -25,6 +25,7 @@ def __eq__(self, other: "MyClass") -> bool: class TestSerialization(TestCase): def setUp(self) -> None: + super().setUp() # 
disable debug asserts self._old_debug = os.environ.get(DEBUG_ENV) os.environ[DEBUG_ENV] = "0" diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index a6b69eeb8b93e..5e063d373ffb5 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -317,6 +317,7 @@ def _create_store(self): class PrefixStoreTest(TestCase): def setUp(self): + super().setUp() # delete is false as FileStore will automatically clean up the file self.file = tempfile.NamedTemporaryFile(delete=False) diff --git a/test/export/test_lift_unlift.py b/test/export/test_lift_unlift.py index af892a96feb5e..4ab1d17ca898f 100644 --- a/test/export/test_lift_unlift.py +++ b/test/export/test_lift_unlift.py @@ -138,6 +138,7 @@ def gen_graph_signature(self) -> ExportGraphSignature: class TestLift(TestCase): def setUp(self): + super().setUp() load_torchbind_test_lib() def test_lift_basic(self): @@ -360,6 +361,7 @@ def forward(self, x): class ConstantAttrMapTest(TestCase): def setUp(self): + super().setUp() load_torchbind_test_lib() def test_dict_api(self): diff --git a/test/export/test_sparse.py b/test/export/test_sparse.py index e94601cd5af87..c8d799a0254b0 100644 --- a/test/export/test_sparse.py +++ b/test/export/test_sparse.py @@ -96,7 +96,7 @@ def forward(self, x): ) class TestSparseProp(TestCase): def setUp(self): - TestCase.setUp(self) + super().setUp() def assertEqualMeta(self, x, y): self.assertIsInstance(x, FakeTensor) diff --git a/test/export/test_upgrader.py b/test/export/test_upgrader.py index 0c36b28750f90..88f4c4e2fa435 100644 --- a/test/export/test_upgrader.py +++ b/test/export/test_upgrader.py @@ -8,6 +8,7 @@ class TestUpgrader(TestCase): def setUp(self) -> None: + super().setUp() # Register example upgraders dynamically torch._C._export.register_example_upgraders() diff --git a/test/functorch/dim/test_getsetitem.py b/test/functorch/dim/test_getsetitem.py index ae7ed0283c753..d91078deafd74 100644 --- a/test/functorch/dim/test_getsetitem.py +++ b/test/functorch/dim/test_getsetitem.py @@ -8,6 +8,7 @@ class TestGetSetItem(TestCase): """Comprehensive tests for first-class dimension indexing operations.""" def setUp(self): + super().setUp() """Set up common test fixtures.""" self.batch, self.height, self.width = dims(3) diff --git a/test/functorch/test_ac_logging.py b/test/functorch/test_ac_logging.py index cb65f028a00f3..4ac195c826545 100644 --- a/test/functorch/test_ac_logging.py +++ b/test/functorch/test_ac_logging.py @@ -13,6 +13,7 @@ class TestAcLogging(TestCase): def setUp(self) -> None: + super().setUp() self.graph: MagicMock = MagicMock(spec=Graph) self.node1: MagicMock = MagicMock(spec=Node) self.node2: MagicMock = MagicMock(spec=Node) diff --git a/test/fx/test_fx_split_node_finder.py b/test/fx/test_fx_split_node_finder.py index a139626968ca5..8916140aa24a3 100644 --- a/test/fx/test_fx_split_node_finder.py +++ b/test/fx/test_fx_split_node_finder.py @@ -27,6 +27,7 @@ def sup_f(x): class TestFxSplitNodeFinder(TestCase): def setUp(self): + super().setUp() self.save_path = sys.path[:] self.tmpdir = tempfile.mkdtemp() sys.path.insert(0, self.tmpdir) diff --git a/test/fx/test_graph_pickler.py b/test/fx/test_graph_pickler.py index ae299140d48a7..d37ebc1108a23 100644 --- a/test/fx/test_graph_pickler.py +++ b/test/fx/test_graph_pickler.py @@ -66,7 +66,7 @@ class GraphPicklerCpuTests(TestCase): class TestGraphPickler(TestCase): def setUp(self): torch._dynamo.reset() - TestCase.setUp(self) + super().setUp() self._stack = contextlib.ExitStack() self._stack.enter_context( diff --git 
a/test/fx/test_net_min_base.py b/test/fx/test_net_min_base.py index 75382304e1950..7e164e7262902 100644 --- a/test/fx/test_net_min_base.py +++ b/test/fx/test_net_min_base.py @@ -14,6 +14,7 @@ class TestNetMinBaseBlock(TestCase): def setUp(self) -> None: + super().setUp() # Setup test fixtures for each test method class SimpleModule(torch.nn.Module): diff --git a/test/inductor/test_augmented_graph_helper.py b/test/inductor/test_augmented_graph_helper.py index 92dcfa1b37b85..b9406b0cf8550 100644 --- a/test/inductor/test_augmented_graph_helper.py +++ b/test/inductor/test_augmented_graph_helper.py @@ -13,6 +13,7 @@ class TestAugmentedGraphHelper(TestCase): def setUp(self): """Create a simple graph structure for testing.""" + super().setUp() # Create a torch.fx.Graph with multiple nodes self.graph = fx.Graph() diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index dc730e408b706..6c8a3367cd1b5 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -70,7 +70,7 @@ def setUp(self): torch._dynamo.reset() FxCompile._reset_stats() - TestCase.setUp(self) + super().setUp() self._stack = contextlib.ExitStack() self._stack.enter_context( diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 709b1fe7f0798..c135d05f060f1 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -142,7 +142,7 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): class TestPatternMatcherBase(TestCase): def setUp(self): - TestCase.setUp(self) + super().setUp() self.ctx_stack = contextlib.ExitStack() self.ctx_stack.enter_context(config.patch({"freezing": True})) diff --git a/test/inductor/test_ordered_set.py b/test/inductor/test_ordered_set.py index c588018fcf667..debd621b0659c 100644 --- a/test/inductor/test_ordered_set.py +++ b/test/inductor/test_ordered_set.py @@ -57,6 +57,7 @@ class TestJointOps(TestCase): basetype = OrderedSet def setUp(self): + super().setUp() self.word = word = "simsalabim" self.otherword = "madagascar" self.letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -851,6 +852,7 @@ def test_issue_37219(self): class TestBasicOpsEmpty(TestBasicOps, TestCase): def setUp(self): + super().setUp() self.case = "empty OrderedSet" self.values = [] self.OrderedSet = OrderedSet(self.values) @@ -864,6 +866,7 @@ def setUp(self): class TestBasicOpsSingleton(TestBasicOps, TestCase): def setUp(self): + super().setUp() self.case = "unit OrderedSet (number)" self.values = [3] self.OrderedSet = OrderedSet(self.values) @@ -883,6 +886,7 @@ def test_not_in(self): class TestBasicOpsTuple(TestBasicOps, TestCase): def setUp(self): + super().setUp() self.case = "unit OrderedSet (tuple)" self.values = [(0, "zero")] self.OrderedSet = OrderedSet(self.values) @@ -902,6 +906,7 @@ def test_not_in(self): class TestBasicOpsTriple(TestBasicOps, TestCase): def setUp(self): + super().setUp() self.case = "triple OrderedSet" self.values = [0, "zero", operator.add] self.OrderedSet = OrderedSet(self.values) @@ -915,6 +920,7 @@ def setUp(self): class TestBasicOpsString(TestBasicOps, TestCase): def setUp(self): + super().setUp() self.case = "string OrderedSet" self.values = ["a", "b", "c"] self.OrderedSet = OrderedSet(self.values) @@ -931,6 +937,7 @@ def test_repr(self): class TestBasicOpsBytes(TestBasicOps, TestCase): def setUp(self): + super().setUp() self.case = "bytes OrderedSet" self.values = [b"a", 
b"b", b"c"] self.OrderedSet = OrderedSet(self.values) @@ -947,6 +954,7 @@ def test_repr(self): class TestBasicOpsMixedStringBytes(TestBasicOps, TestCase): def setUp(self): + super().setUp() warnings.simplefilter("ignore", BytesWarning) self.case = "string and bytes OrderedSet" self.values = ["a", "b", b"a", b"b"] @@ -1018,6 +1026,7 @@ def test_constructor(self): class TestBinaryOps(TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((2, 4, 6)) def test_eq(self): # SF bug 643115 @@ -1093,6 +1102,7 @@ def test_sym_difference_non_overlap(self): class TestUpdateOps(TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((2, 4, 6)) def test_union_subset(self): @@ -1181,6 +1191,7 @@ def test_difference_method_call(self): class TestMutate(TestCase): def setUp(self): + super().setUp() self.values = ["a", "b", "c"] self.OrderedSet = OrderedSet(self.values) @@ -1469,6 +1480,7 @@ def test_difference(self): class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((1, 2, 3)) self.other = 19 self.otherIsIterable = False @@ -1479,6 +1491,7 @@ def setUp(self): class TestOnlySetsDict(TestOnlySetsInBinaryOps, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((1, 2, 3)) self.other = {1: 2, 3: 4} self.otherIsIterable = True @@ -1489,6 +1502,7 @@ def setUp(self): class TestOnlySetsOperator(TestOnlySetsInBinaryOps, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((1, 2, 3)) self.other = operator.add self.otherIsIterable = False @@ -1499,6 +1513,7 @@ def setUp(self): class TestOnlySetsTuple(TestOnlySetsInBinaryOps, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((1, 2, 3)) self.other = (2, 4, 6) self.otherIsIterable = True @@ -1509,6 +1524,7 @@ def setUp(self): class TestOnlySetsString(TestOnlySetsInBinaryOps, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet((1, 2, 3)) self.other = "abc" self.otherIsIterable = True @@ -1519,6 +1535,8 @@ def setUp(self): class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, TestCase): def setUp(self): + super().setUp() + def gen(): for i in range(0, 10, 2): # noqa: UP028 yield i @@ -1556,6 +1574,7 @@ def test_deep_copy(self): class TestCopyingEmpty(TestCopying, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet() @@ -1564,6 +1583,7 @@ def setUp(self): class TestCopyingSingleton(TestCopying, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet(["hello"]) @@ -1572,6 +1592,7 @@ def setUp(self): class TestCopyingTriple(TestCopying, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet(["zero", 0, None]) @@ -1580,6 +1601,7 @@ def setUp(self): class TestCopyingTuple(TestCopying, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet([(1, 2)]) @@ -1588,6 +1610,7 @@ def setUp(self): class TestCopyingNested(TestCopying, TestCase): def setUp(self): + super().setUp() self.OrderedSet = OrderedSet([((1, 2), (3, 4))]) @@ -1598,6 +1621,7 @@ def setUp(self): class TestIdentities(TestCase): def setUp(self): + super().setUp() self.a = OrderedSet("abracadabra") self.b = OrderedSet("alacazam") diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 5eaa007a8a1cb..3e13c6a9ca1ff 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -159,7 +159,7 @@ def setUp(self): if not 
HAS_GPU: self.skipTest("Triton not available") torch._dynamo.reset() - TestCase.setUp(self) + super().setUp() # this should be in setUpClass, but device-generic tests # don't work with setUpClass well (non-deterministically the wrong setUpClass is resolved), # so put it in test setUp, it's cheap diff --git a/test/onnx/exporter/test_building.py b/test/onnx/exporter/test_building.py index fdccf04c1d0af..8600ab44a67b7 100644 --- a/test/onnx/exporter/test_building.py +++ b/test/onnx/exporter/test_building.py @@ -14,6 +14,7 @@ class TestOpRecorder(common_utils.TestCase): def setUp(self): + super().setUp() self.opset_version = 17 self.opset = onnxscript.values.Opset("", self.opset_version) self.recorder = _building.OpRecorder(opset=self.opset, constant_farm={}) diff --git a/test/onnx/internal/test_registraion.py b/test/onnx/internal/test_registration.py similarity index 99% rename from test/onnx/internal/test_registraion.py rename to test/onnx/internal/test_registration.py index fcc4cdeedd92f..8d90553ac2181 100644 --- a/test/onnx/internal/test_registraion.py +++ b/test/onnx/internal/test_registration.py @@ -49,6 +49,7 @@ def test_dispatch_opset_version_returns_correct_version( class TestOverrideDict(common_utils.TestCase): def setUp(self): + super().setUp() self.override_dict: registration.OverrideDict[str, int] = ( registration.OverrideDict() ) diff --git a/test/profiler/test_cpp_thread.py b/test/profiler/test_cpp_thread.py index 9dbecf994a4fa..b4fcf49ad84d5 100644 --- a/test/profiler/test_cpp_thread.py +++ b/test/profiler/test_cpp_thread.py @@ -88,6 +88,7 @@ def tearDownClass(cls): torch.testing._internal.common_utils.remove_cpp_extensions_build_root() def setUp(self) -> None: + super().setUp() if not torch.cuda.is_available(): self.skipTest("Test machine does not have cuda") global device @@ -230,6 +231,7 @@ def tearDownClass(cls): torch.testing._internal.common_utils.remove_cpp_extensions_build_root() def setUp(self) -> None: + super().setUp() if not torch.xpu.is_available(): self.skipTest("Test machine does not have xpu") global device diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 7ce0b19ce884f..60d4f36e0c16e 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -24,6 +24,7 @@ class TestCudaPrimaryCtx(TestCase): ) def setUp(self): + super().setUp() for device in range(torch.cuda.device_count()): # Ensure context has not been created beforehand self.assertFalse( diff --git a/test/test_cuda_sanitizer.py b/test/test_cuda_sanitizer.py index 6d2ecc36a093c..e5dae52354a72 100644 --- a/test/test_cuda_sanitizer.py +++ b/test/test_cuda_sanitizer.py @@ -143,6 +143,7 @@ def event_id(i: int) -> EventId: class TestEventHandler(TestCase): def setUp(self): + super().setUp() self.handler = csan.EventHandler() def kernel_launch( @@ -397,6 +398,7 @@ def test_event_synchronize(self): class TestMessages(TestCase): def setUp(self): + super().setUp() self.handler = csan.EventHandler() def test_ensure_exists(self): diff --git a/test/test_cuda_trace.py b/test/test_cuda_trace.py index 124b0ac41b871..0794683f4ef26 100644 --- a/test/test_cuda_trace.py +++ b/test/test_cuda_trace.py @@ -20,6 +20,7 @@ @torch.testing._internal.common_utils.markDynamoStrictTest class TestCudaTrace(TestCase): def setUp(self): + super().setUp() torch._C._activate_gpu_trace() self.mock = unittest.mock.MagicMock() diff --git a/test/test_monitor.py b/test/test_monitor.py index cf9cecc356f87..19d4a6cf2dc25 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -111,6 +111,7 
@@ def test_wait_counter(self) -> None: @skipIfTorchDynamo("Really weird error") class TestMonitorTensorboard(TestCase): def setUp(self): + super().setUp() global SummaryWriter, event_multiplexer try: from tensorboard.backend.event_processing import ( diff --git a/test/test_nnapi.py b/test/test_nnapi.py index d8a6392d72f1b..6f8d487507f46 100644 --- a/test/test_nnapi.py +++ b/test/test_nnapi.py @@ -28,6 +28,7 @@ def nhwc(t): ) class TestNNAPI(TestCase): def setUp(self): + super().setUp() # Avoid saturation in fbgemm torch.backends.quantized.engine = "qnnpack" diff --git a/test/test_sparse.py b/test/test_sparse.py index bd612cf0faaad..f1ed24667e133 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -178,7 +178,7 @@ def run(self, result=None): class TestSparse(TestSparseBase): def setUp(self): - TestCase.setUp(self) + super().setUp() self.index_tensor = lambda *args, **kwargs: torch.tensor(*args, **kwargs, dtype=torch.int64) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index c278378e2cc4a..61ecd49c2c477 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -21,6 +21,7 @@ class TestFuzzerCompileIssues(TestCase): def setUp(self): """Configure common test settings.""" + super().setUp() torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._inductor.config.emulate_precision_casts = True diff --git a/test/test_utils.py b/test/test_utils.py index 40cc969f11665..f6bdc156c122e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -649,6 +649,7 @@ def test_import_hipify(self): class TestHipifyTrie(TestCase): def setUp(self): + super().setUp() from torch.utils.hipify import hipify_python self.trie = hipify_python.Trie() diff --git a/test/test_weak.py b/test/test_weak.py index 629ed12db3267..e46268852c983 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -582,6 +582,7 @@ def _full_mapping(self, data): return x def setUp(self): + super().setUp() if IS_MACOS: raise unittest.SkipTest("non-portable load_library call used in test") From 7729de07d30f2d5f49b89e329f10917d40378495 Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Thu, 30 Oct 2025 06:39:43 +0000 Subject: [PATCH 137/651] Build libgomp (gcc-13) from src on AArch64 (#166549) This improves thread-scaling on AArch64 (see details on #155795) Fixes: #155795 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166549 Approved by: https://github.com/malfet --- .ci/docker/common/install_libgomp.sh | 56 ++++++++++++++++++++ .ci/docker/manywheel/Dockerfile_2_28_aarch64 | 4 ++ 2 files changed, 60 insertions(+) create mode 100644 .ci/docker/common/install_libgomp.sh diff --git a/.ci/docker/common/install_libgomp.sh b/.ci/docker/common/install_libgomp.sh new file mode 100644 index 0000000000000..308915ec4f618 --- /dev/null +++ b/.ci/docker/common/install_libgomp.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# Script used only in CD pipeline + +set -ex + +# install dependencies +dnf -y install gmp-devel libmpc-devel texinfo flex bison + +cd /usr/local/src +# fetch source for gcc 13 +git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0 + +mkdir -p gcc-13.3.0/build-gomp +cd gcc-13.3.0/build-gomp + +# configure gcc build +# I got these flags by: +# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container +# dnf install -y dnf-plugins-core rpmdevtools +# dnf download --source libgomp +# 2. extracting the gcc.spec from the source. 
+# rpmdev-extract gcc-xx.src.rpm +# 3. extracting optflags and ld_flags from gcc.spec: +# rpm --eval '%{optflags}' +# rpm --eval '%{build_ldflags}' +# +# I had to remove the following flags because they didn't compile for this version of libgomp: +# -Werror=format-security +# -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1 +# -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1 +# +# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed. + +OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\ +' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\ +' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\ +' -fstack-protector-strong -fasynchronous-unwind-tables'\ +' -fstack-clash-protection' + +LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now' + +CFLAGS="$OPT_FLAGS" \ +CXXFLAGS="$OPT_FLAGS" \ +LDFLAGS="$LDFLAGS" \ +../configure \ + --prefix=/usr \ + --libdir=/usr/lib64 \ + --enable-languages=c,c++ \ + --disable-multilib \ + --disable-bootstrap \ + --enable-libgomp + +# only build libgomp +make -j$(nproc) all-target-libgomp + +make install-target-libgomp \ No newline at end of file diff --git a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 index 768db09929361..78ee09d128cb0 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 @@ -50,6 +50,10 @@ RUN rm install_ninja.sh ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH +# Build a newer version of libgomp than that supported in in Almalinux 8. +COPY ./common/install_libgomp.sh install_libgomp.sh +RUN bash ./install_libgomp.sh && rm install_libgomp.sh + # git236+ would refuse to run git commands in repos owned by other users # Which causes version check to fail, as pytorch repo is bind-mounted into the image # Override this behaviour by treating every folder as safe From 7206668f7c0f4f9c0542fb4190540ca25d423b51 Mon Sep 17 00:00:00 2001 From: mohsinm-dev Date: Thu, 6 Nov 2025 18:52:18 +0000 Subject: [PATCH 138/651] Update torch.var documentation to use modern API (#167209) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Fix outdated unbiased parameter references in normalization module documentation. Replace deprecated torch.var(input, unbiased=False/True) with modern torch.var(input, correction=0/1) API throughout BatchNorm, InstanceNorm, LayerNorm, and GroupNorm docstrings. ## Changes - torch/nn/modules/batchnorm.py: Updated 4 instances across BatchNorm1d, BatchNorm2d, BatchNorm3d, and SyncBatchNorm - torch/nn/modules/instancenorm.py: Updated 3 instances across InstanceNorm1d, InstanceNorm2d, and InstanceNorm3d - torch/nn/modules/normalization.py: Updated 2 instances in LayerNorm and GroupNorm ## Test plan Mathematical behavior remains identical: unbiased=False ≡ correction=0 (biased estimator), unbiased=True ≡ correction=1 (unbiased estimator). Documentation now uses consistent modern API terminology with no functional changes to code behavior. 
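As a quick sanity check of the equivalence claimed above (not part of the patch, which only touches docstrings; the deprecated `unbiased=` spelling may emit a warning on newer builds):

```python
import torch

x = torch.randn(16)

# Biased estimator (divide by N): unbiased=False matches correction=0.
assert torch.equal(torch.var(x, unbiased=False), torch.var(x, correction=0))

# Unbiased estimator (divide by N - 1): unbiased=True matches correction=1.
assert torch.equal(torch.var(x, unbiased=True), torch.var(x, correction=1))
```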
Fixes #166804 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167209 Approved by: https://github.com/albanD --- torch/nn/modules/batchnorm.py | 14 +++++++------- torch/nn/modules/instancenorm.py | 6 +++--- torch/nn/modules/normalization.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 6a78aba2ad7db..2ac05f2e8f933 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -292,9 +292,9 @@ class BatchNorm1d(_BatchNorm): of size `C` (where `C` is the number of features or channels of the input). By default, the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the variance is calculated via the biased estimator, - equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the + equivalent to ``torch.var(input, correction=0)``. However, the value stored in the moving average of the variance is calculated via the unbiased estimator, equivalent to - ``torch.var(input, unbiased=True)``. + ``torch.var(input, correction=1)``. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during @@ -404,9 +404,9 @@ class BatchNorm2d(_BatchNorm): of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the standard-deviation is calculated via the biased estimator, equivalent to - ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the + ``torch.var(input, correction=0)``. However, the value stored in the moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to - ``torch.var(input, unbiased=True)``. + ``torch.var(input, correction=1)``. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during @@ -515,9 +515,9 @@ class BatchNorm3d(_BatchNorm): of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the standard-deviation is calculated via the biased estimator, equivalent to - ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the + ``torch.var(input, correction=0)``. However, the value stored in the moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to - ``torch.var(input, unbiased=True)``. + ``torch.var(input, correction=1)``. Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during @@ -628,7 +628,7 @@ class SyncBatchNorm(_BatchNorm): By default, the elements of :math:`\gamma` are sampled from :math:`\mathcal{U}(0, 1)` and the elements of :math:`\beta` are set to 0. The standard-deviation is calculated via the biased estimator, equivalent to - `torch.var(input, unbiased=False)`. + `torch.var(input, correction=0)`. 
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index da3d3658553f0..058ffb3ed9aa9 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -141,7 +141,7 @@ class InstanceNorm1d(_InstanceNorm): for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``. The variance is calculated via the biased estimator, equivalent to - `torch.var(input, unbiased=False)`. + `torch.var(input, correction=0)`. By default, this layer uses instance statistics computed from input data in both training and evaluation modes. @@ -256,7 +256,7 @@ class InstanceNorm2d(_InstanceNorm): for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `C` (where `C` is the input size) if :attr:`affine` is ``True``. The standard-deviation is calculated via the biased estimator, equivalent to - `torch.var(input, unbiased=False)`. + `torch.var(input, correction=0)`. By default, this layer uses instance statistics computed from input data in both training and evaluation modes. @@ -372,7 +372,7 @@ class InstanceNorm3d(_InstanceNorm): for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size C (where C is the input size) if :attr:`affine` is ``True``. The standard-deviation is calculated via the biased estimator, equivalent to - `torch.var(input, unbiased=False)`. + `torch.var(input, correction=0)`. By default, this layer uses instance statistics computed from input data in both training and evaluation modes. diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index 1474de008c185..60bd561bfd0e4 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -119,7 +119,7 @@ class LayerNorm(Module): :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. The variance is calculated via the biased estimator, equivalent to - `torch.var(input, unbiased=False)`. + `torch.var(input, correction=0)`. .. note:: Unlike Batch Normalization and Instance Normalization, which applies @@ -253,7 +253,7 @@ class GroupNorm(Module): per-channel affine transform parameter vectors of size :attr:`num_channels` if :attr:`affine` is ``True``. The variance is calculated via the biased estimator, equivalent to - `torch.var(input, unbiased=False)`. + `torch.var(input, correction=0)`. This layer uses statistics computed from input data in both training and evaluation modes. From aaea391b6226c200c61e34be182d4ffe3f329650 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 6 Nov 2025 18:57:27 +0000 Subject: [PATCH 139/651] [annotate][export] Add annotation to assertion nodes in export (#167171) Fixes #166906 ``` python test/export/test_export.py -k test_annotate_on_assert ``` The assertions are not marked with annotation because these nodes are created in `apply_runtime_assertion_pass`. Currently the annotation will only be added if the nodes are created during tracing. So we need to manually add the annotation. Nodes added in `apply_runtime_assertion_pass` will have the same annotation as the input node to the assertion. 
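For context, the repro from the new test (lightly adapted here so it runs standalone); printing the resulting `ExportedProgram` produces the annotated graph shown next:

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x, y):
        with torch.fx.traceback.annotate({"moo": 0}):
            x = torch.cat([x, x])
            b = y.item()
            torch._check(b >= x.shape[0])
            return x * b


with torch.fx.traceback.preserve_node_meta():
    ep = export(
        M(),
        (torch.randn(3), torch.tensor(6)),
        dynamic_shapes={"x": {0: Dim("b")}, "y": None},
    )
print(ep)
```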
Output graph: Note that `_assert_scalar_default_1` is not annotated becayse it's an assertion on the size of `x` which is not annotated. ``` ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "f32[s77]", y: "i64[]"): # No stacktrace found for following nodes sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0) # Annotation: {'moo': 0} File: /data/users/shangdiy/pytorch/test/export/test_export.py:729 in forward, code: x = torch.cat([x, x]) cat: "f32[2*s77]" = torch.ops.aten.cat.default([x, x]); x = None # Annotation: {'moo': 0} File: /data/users/shangdiy/pytorch/test/export/test_export.py:730 in forward, code: b = y.item() item: "Sym(u0)" = torch.ops.aten.item.default(y); y = None ge_1: "Sym(u0 >= 4)" = item >= 4 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 4 on node 'ge_1'"); ge_1 = _assert_scalar_default = None # No stacktrace found for following nodes mul_1: "Sym(2*s77)" = 2 * sym_size_int_1; sym_size_int_1 = None le: "Sym(2*s77 <= u0)" = mul_1 <= item; mul_1 = None _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression 2*s77 <= u0 on node 'le'"); le = _assert_scalar_default_1 = None # Annotation: {'moo': 0} File: /data/users/shangdiy/pytorch/test/export/test_export.py:732 in forward, code: return x * b mul: "f32[2*s77]" = torch.ops.aten.mul.Tensor(cat, item); cat = item = None return (mul,) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167171 Approved by: https://github.com/angelayi --- test/export/test_export.py | 28 ++++++++++++++++++++++++++++ torch/fx/passes/runtime_assert.py | 6 ++++++ 2 files changed, 34 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index cdc18b1d4c564..25f1cec03bd7c 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -721,6 +721,34 @@ def example_inputs(self): ) self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id) + def test_annotate_on_assert(self): + # nodes added in `apply_runtime_assertion_pass` will be annotated + class M(torch.nn.Module): + def forward(self, x, y): + with torch.fx.traceback.annotate({"moo": 0}): + x = torch.cat([x, x]) + b = y.item() + torch._check(b >= x.shape[0]) + return x * b + + with torch.fx.traceback.preserve_node_meta(): + ep = torch.export.export( + M(), + (torch.randn(3), torch.tensor(6)), + dynamic_shapes={"x": {0: Dim("b")}, "y": None}, + ) + + custom_metadata = torch.fx.traceback._get_custom_metadata(ep.module()) + self.assertExpectedInline( + str(custom_metadata), + """\ +('call_function', 'cat', {'moo': 0}) +('call_function', 'item', {'moo': 0}) +('call_function', 'ge_1', {'moo': 0}) +('call_function', '_assert_scalar_default', {'moo': 0}) +('call_function', 'mul', {'moo': 0})""", + ) + @requires_gpu def test_flex_attention_export(self): from torch.nn.attention.flex_attention import create_block_mask, flex_attention diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 1d3b0b33e7bce..e475a5bc9b6df 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -165,6 +165,7 @@ def _node_metadata_hook( node: torch.fx.Node, stack_trace: Optional[str] = None, nn_module_stack: Optional[dict[str, Any]] = None, + custom: Optional[dict[str, Any]] = None, ) -> None: fake_args = pytree.tree_map( lambda arg: ( @@ -188,6 +189,8 @@ def _node_metadata_hook( node.meta["stack_trace"] = stack_trace if nn_module_stack is not None: 
node.meta["nn_module_stack"] = nn_module_stack + if custom is not None: + node.meta["custom"] = custom # Track asserts/checks we've added added_asserts: set[sympy.Expr] = set() @@ -615,6 +618,9 @@ def convert(s): _node_metadata_hook, stack_trace=node.meta.get("stack_trace"), nn_module_stack=node.meta.get("nn_module_stack"), + # nodes added in `apply_runtime_assertion_pass` will have the same annotation + # as the input node to the assertion + custom=node.meta.get("custom"), ), ): if (min_val := convert(vr.lower)) is not None: From 9fef18e31dd722863b2c28a2dcaba859cd73802f Mon Sep 17 00:00:00 2001 From: Chinmay Dattanand Kuchinad <40351312+chinmaydk99@users.noreply.github.com> Date: Thu, 6 Nov 2025 19:08:11 +0000 Subject: [PATCH 140/651] [ROCm] Enable multi-arch compilation and unit tests for AOT Inductor (#166357) ## Summary This PR adds multi-architecture kernel compilation support for ROCm in PyTorch's AOT Inductor module, enabling a single compiled model to run across multiple AMD GPU architectures (MI200, MI300, MI350, etc.) without recompilation. ## Implementation - **Multi-arch compilation pipeline**: Compiles LLVM IR to multiple GPU architectures and bundles them using `clang-offload-bundler` - **Architecture detection**: Automatically detects target architectures from `torch.cuda.get_arch_list()`, with overrides via `PYTORCH_ROCM_ARCH` environment variable - **ROCm-specific utilities**: New `rocm_multiarch_utils.py` module handles ROCm toolchain integration - **Test infrastructure**: Adapted AOT Inductor tests to support both CUDA and ROCm compilation paths ## Testing Successfully tested on: - MI200 - MI300 **Enabled tests:** - `test_simple_multi_arch` - `test_compile_after_package_multi_arch` - `test_compile_with_exporter` - `test_compile_with_exporter_weights` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166357 Approved by: https://github.com/jeffdaily --- test/inductor/test_aot_inductor.py | 10 +- test/inductor/test_aot_inductor_package.py | 14 +- torch/_inductor/codecache.py | 121 ++++++--- torch/_inductor/rocm_multiarch_utils.py | 264 +++++++++++++++++++ torch/_inductor/runtime/triton_heuristics.py | 31 ++- torch/export/experimental/__init__.py | 6 +- torch/export/experimental/_utils.py | 40 ++- 7 files changed, 423 insertions(+), 63 deletions(-) create mode 100644 torch/_inductor/rocm_multiarch_utils.py diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 8f009f30a0a60..4f7eb86e8ce47 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -246,12 +246,12 @@ def forward(self, x): "toolchain doesn't support ptx to fatbin", ) @skipIfMPS - @skipIfRocm # Skip embed_kernel_binary == True for now as it shows random # failure on CI @common_utils.parametrize("embed_kernel_binary", [False]) @unittest.skipIf( - _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + torch.version.hip is None and _get_torch_cuda_version() < (12, 6), + "Test is only supported on CUDA 12.6+", ) def test_simple_multi_arch(self, embed_kernel_binary): if self.device != GPU_TYPE: @@ -281,7 +281,11 @@ def forward(self, x, y): _, code = run_and_get_cpp_code( AOTIRunnerUtil.compile, model, example_inputs ) - file_extension = ".spv" if self.device == "xpu" else ".fatbin" + file_extension = ( + ".spv" + if self.device == "xpu" + else (".hsaco" if torch.version.hip else ".fatbin") + ) FileCheck().check(file_extension).run(code) def test_small_constant(self): diff --git 
a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index d8b9ad5473bae..9c1f9802bc3ea 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -28,7 +28,7 @@ load_weights_to_pt2_contents, ) from torch.testing._internal.common_cuda import _get_torch_cuda_version -from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm, skipIfXpu +from torch.testing._internal.common_utils import IS_FBCODE, skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -315,10 +315,10 @@ def forward(self, x, y): self.assertTrue(torch.allclose(actual, expected)) @unittest.skipIf( - _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + torch.version.hip is None and _get_torch_cuda_version() < (12, 6), + "Test is only supported on CUDA 12.6+", ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") - @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary def test_compile_after_package_multi_arch(self): if self.device != GPU_TYPE: @@ -457,10 +457,10 @@ def forward(self, x): self.assertTrue(a_path.exists()) @unittest.skipIf( - _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + torch.version.hip is None and _get_torch_cuda_version() < (12, 6), + "Test is only supported on CUDA 12.6+", ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") - @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_with_exporter(self): @@ -515,10 +515,10 @@ def default(*args, **kwargs): ) @unittest.skipIf( - _get_torch_cuda_version() < (12, 6), "Test is only supported on CUDA 12.6+" + torch.version.hip is None and _get_torch_cuda_version() < (12, 6), + "Test is only supported on CUDA 12.6+", ) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") - @skipIfRocm # doesn't support multi-arch binary @skipIfXpu # doesn't support multi-arch binary @torch._inductor.config.patch("test_configs.use_libtorch", True) def test_compile_with_exporter_weights(self): diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 0177f6900c611..f36953d2a3337 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1680,30 +1680,42 @@ def set( basename, _ = get_name_and_dir_from_output_file_path(bin_path) if config.aot_inductor.emit_multi_arch_kernel: - bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"} - assert bin_type in bin_type_to_ext, ( - "multi_arch_kernel_binary only supported in CUDA/XPU" + bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv", "hsaco": ".hsaco"} + assert bin_type in bin_type_to_ext.keys(), ( + "multi_arch_kernel_binary only supported in CUDA/XPU/ROCm" ) base_path, _ = os.path.splitext(bin_path) bin_path = base_path + bin_type_to_ext[bin_type] asm_path: str = "" + + # Kernel assembly/IR requirements for AOT Inductor: + # - CUDA/XPU: Always require PTX/SPV + # - ROCm multi-arch: Require LLVM IR (.ll) for bundle compilation if ( config.aot_inductor.emit_multi_arch_kernel or config.aot_inductor.package_cpp_only ): - assert asm, "Missing kernel assembly code" - assert asm_type, "Missing kernel assembly type" - _, asm_path = write( - asm, - asm_type, - hash_type=asm_type, - specified_dir=split_aot_inductor_output_path( - config.aot_inductor.output_path - )[0], - # make sure asm file has the same basename - key=basename, - ) + # Allow ROCm 
single-arch to skip (asm=None OK), require for everything else + if torch.version.hip is None or (asm and asm_type): + assert asm, "Missing kernel assembly code" + assert asm_type, "Missing kernel assembly type" + + # Cache directory mapping: asm_type → hash_type + # Problem: LLVM IR extension ".ll" isn't a recognized cache category + # Solution: Map to "code" (generic category for non-standard formats) + # Recognized categories: "ptx", "amdgcn", "spv", "code" + hash_kind = asm_type if asm_type in {"amdgcn", "ptx", "spv"} else "code" + + _, asm_path = write( + asm, + asm_type, + hash_type=hash_kind, + specified_dir=split_aot_inductor_output_path( + config.aot_inductor.output_path + )[0], + key=basename, + ) params[get_cpp_wrapper_cubin_path_name()] = bin_path params["asm"] = asm_path @@ -2383,28 +2395,57 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: config.aot_inductor.emit_multi_arch_kernel and device_type == "cuda" ): - current_arch = _nvcc_arch_as_compile_option() - cmd = ( - # pyrefly: ignore [unbound-name] - f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " - # Triton only allows generating PTX version as same as the current arch - f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " - # Include SASS for the current specific arch - f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " - ) - try: - subprocess.run( - cmd.split(), - capture_output=True, - text=True, - check=True, + if torch.version.hip is None: + current_arch = _nvcc_arch_as_compile_option() + cmd = ( + # pyrefly: ignore [unbound-name] + f"{_cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " + # Triton only allows generating PTX version as same as the current arch + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + # Include SASS for the current specific arch + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " ) - except subprocess.CalledProcessError as e: - print( - f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", - file=sys.stderr, + try: + subprocess.run( + cmd.split(), + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print( + f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", + file=sys.stderr, + ) + raise + + else: + # ROCm multi-arch: compile LLVM IR to multi-arch bundle + from torch._inductor.rocm_multiarch_utils import ( + compile_multiarch_bundle_from_llvm_ir, ) - raise + + if not os.path.exists(asm_file): + raise RuntimeError( + f"Multi-arch ROCm compilation requires LLVM IR file, " + f"but {asm_file} not found. " + f"Ensure asm_type='ll' is captured in triton_heuristics.py" + ) + + # Compile for multiple archs and bundle them + success = compile_multiarch_bundle_from_llvm_ir( + llvm_ir_path=asm_file, + output_bundle_path=cubin_file, + target_archs=None, + ) + + if not success: + raise RuntimeError( + f"Failed to compile multi-arch bundle for kernel {kernel_name}. " + f"Check that ROCm toolchain is available and LLVM IR is valid." 
+ ) + + log.info("Created multi-arch bundle: %s", cubin_file) if config.aot_inductor.embed_kernel_binary: # Embed cubin files into model.so using objcopy @@ -2471,10 +2512,18 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: generated_files.append(consts_o) so_builder.save_src_to_cmake(cmake_path, consts_o) - if config.aot_inductor.emit_multi_arch_kernel: + # Different CMake strategies for CUDA vs ROCm: + # - CUDA: Save asm for CMake to recompile (user has nvcc) + # - ROCm: Link pre-compiled bundle (user may lack dev tools) + if ( + config.aot_inductor.emit_multi_arch_kernel + and torch.version.hip is None + ): so_builder.save_kernel_asm_to_cmake(cmake_path, asm_files) generated_files.extend(asm_files) else: + # ROCm multi-arch + all single-arch: Link pre-compiled objects + # Bundle already embedded in .o files - just link into .so obj_srcs = [*gpu_kernels_o, *cubins_o] generated_files.extend(obj_srcs) for obj in obj_srcs: diff --git a/torch/_inductor/rocm_multiarch_utils.py b/torch/_inductor/rocm_multiarch_utils.py new file mode 100644 index 0000000000000..a1a6103e10915 --- /dev/null +++ b/torch/_inductor/rocm_multiarch_utils.py @@ -0,0 +1,264 @@ +""" +ROCm Multi-Architecture Support Utilities +Compile LLVM IR to multi-arch bundles that HIP can load automatically. +""" + +import os +import subprocess +from typing import Optional + +import torch +from torch.utils.cpp_extension import _join_rocm_home, ROCM_HOME + + +def get_rocm_compiler() -> str: + """ + Get path to ROCm's clang compiler. + Uses PyTorch's ROCM_HOME detection. + + Returns: + Path to clang compiler + + Raises: + RuntimeError: If ROCm is not found + """ + if ROCM_HOME is None: + raise RuntimeError( + "ROCm installation not found. " + "PyTorch was not built with ROCm support or ROCM_HOME is not set." + ) + + # ROCm's clang is at /llvm/bin/clang + clang_path = _join_rocm_home("llvm", "bin", "clang") + + if not os.path.exists(clang_path): + raise RuntimeError( + f"ROCm clang not found at {clang_path}. ROCM_HOME is set to {ROCM_HOME}" + ) + + return clang_path + + +def get_rocm_bundler() -> str: + """ + Get path to clang-offload-bundler. + Uses PyTorch's ROCM_HOME detection. + + Returns: + Path to bundler + + Raises: + RuntimeError: If bundler is not found + """ + if ROCM_HOME is None: + raise RuntimeError( + "ROCm installation not found. " + "PyTorch was not built with ROCm support or ROCM_HOME is not set." + ) + + # Bundler is at /llvm/bin/clang-offload-bundler + bundler_path = _join_rocm_home("llvm", "bin", "clang-offload-bundler") + + if not os.path.exists(bundler_path): + raise RuntimeError( + f"clang-offload-bundler not found at {bundler_path}. " + f"ROCM_HOME is set to {ROCM_HOME}" + ) + + return bundler_path + + +def get_rocm_target_archs() -> list[str]: + """ + Get target architectures from environment or config. 
+ Returns: List of architecture strings (e.g., ['gfx90a', 'gfx942']) + """ + # Check PYTORCH_ROCM_ARCH environment variable + env_archs = os.environ.get("PYTORCH_ROCM_ARCH", "").strip() + if env_archs: + archs = [arch.strip() for arch in env_archs.replace(";", ",").split(",")] + archs = [arch for arch in archs if arch] + if archs: + return archs + + # Try to get from inductor config + try: + from torch._inductor import config + + if hasattr(config, "rocm") and hasattr(config.rocm, "target_archs"): + archs = config.rocm.target_archs + if archs: + return archs + + except Exception: + pass + + return torch.cuda.get_arch_list() + + +def compile_llvm_ir_to_code_object( + llvm_ir_path: str, output_path: str, target_arch: str +) -> bool: + """ + Compile unbundled LLVM IR to a single-arch code object. + + Args: + llvm_ir_path: Path to .ll file + output_path: Where to write .hsaco file + target_arch: Target architecture (e.g., 'gfx90a') + + Returns: + True if successful + """ + if not os.path.exists(llvm_ir_path): + return False + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + try: + clang = get_rocm_compiler() + except RuntimeError: + return False + + # Using clang and not hipcc since we are not compiling source code + # Instead we use the LLVM IR (.ll) provided by triton + cmd = [ + clang, + "-target", + "amdgcn-amd-amdhsa", + f"-mcpu={target_arch}", + llvm_ir_path, + "-o", + output_path, + ] + + try: + subprocess.run(cmd, capture_output=True, text=True, check=True) + + if not os.path.exists(output_path): + return False + + return True + + except subprocess.CalledProcessError: + return False + + +def create_multiarch_bundle(code_objects: dict, output_bundle_path: str) -> bool: + """ + Bundle multiple architecture code objects into a single multi-arch bundle. + + Uses clang-offload-bundler to create a fat binary that HIP runtime can load. + The runtime automatically selects the correct architecture at load time. + + Args: + code_objects: Dict mapping architecture to code object path + output_bundle_path: Path for output bundle + + Returns: + True if successful + """ + if not code_objects: + return False + + os.makedirs(os.path.dirname(output_bundle_path), exist_ok=True) + + try: + bundler = get_rocm_bundler() + except RuntimeError: + return False + + # Build targets and inputs lists for clang-offload-bundler + targets = ["host-x86_64-unknown-linux-gnu"] + + # We include a dummy host entry to satisfy the bundler format + inputs = ["/dev/null"] + + for arch, path in sorted(code_objects.items()): + if not os.path.exists(path): + continue + # hipv4 = HIP version 4 code object format + # amdgcn-amd-amdhsa = target triple for ROCm/HSA runtime + # arch = specific GPU (gfx90a, gfx942, etc.) + targets.append(f"hipv4-amdgcn-amd-amdhsa--{arch}") + inputs.append(path) + + if len(inputs) == 1: # Only host, no device code + return False + + cmd = [ + bundler, + "--type=o", + # CRITICAL: HIP runtime expects 4096-byte alignment for loading bundles + # Without this, hipModuleLoadData gives segmentation fault + "-bundle-align=4096", # CRITICAL: Required by HIP runtime! 
+ f"--targets={','.join(targets)}", + ] + + for input_file in inputs: + cmd.append(f"--input={input_file}") + + cmd.append(f"--output={output_bundle_path}") + + try: + subprocess.run(cmd, capture_output=True, text=True, check=True) + + if not os.path.exists(output_bundle_path): + return False + + return True + + except subprocess.CalledProcessError: + return False + + +def compile_multiarch_bundle_from_llvm_ir( + llvm_ir_path: str, output_bundle_path: str, target_archs: Optional[list[str]] = None +) -> bool: + """ + Complete workflow: LLVM IR → multiple code objects → bundle. + + This is the main entry point for multi-arch compilation. + + Args: + llvm_ir_path: Path to .ll file + output_bundle_path: Where to write bundle + target_archs: Optional list of architectures + + Returns: + True if successful + """ + if target_archs is None: + # Get architectures from environment variable or config + target_archs = get_rocm_target_archs() + + # Step 1: Compile LLVM IR to code object for each architecture + code_objects = {} + temp_dir = os.path.dirname(output_bundle_path) + kernel_name = os.path.splitext(os.path.basename(llvm_ir_path))[0] + + for arch in target_archs: + # Create temporary single-architecture code object + # Format: kernel_name_gfx90a.co, kernel_name_gfx942.co, etc. + co_path = os.path.join(temp_dir, f"{kernel_name}_{arch}.co") + + # Compile with clang backend: LLVM IR → GPU machine code + if compile_llvm_ir_to_code_object(llvm_ir_path, co_path, arch): + code_objects[arch] = co_path + + if not code_objects: + return False + + # Step 2: Bundle all code objects together + # Uses clang-offload-bundler to create fat binary + success = create_multiarch_bundle(code_objects, output_bundle_path) + + # Step 3: Clean up temporary single-arch code objects + # The bundle contains all the code, so intermediates are no longer needed + for co_path in code_objects.values(): + try: + os.remove(co_path) + except Exception: + pass + + return success diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 2e0a0dba9092e..b38cdcb71fa23 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1171,15 +1171,36 @@ def save_gpu_kernel(self, stream, launcher): launcher.bin.metadata, "threads_per_warp", 32 ) + from torch._inductor import config from torch._inductor.codecache import CudaKernelParamCache bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") binary = launcher.bin.asm[bin_type] - # Also store asm code which can be used for debugging and generating cpp package - asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get( - self.device_props.type, None - ) - asm = launcher.bin.asm.get(asm_type, None) + + # ROCm multi-arch: capture LLVM IR + if torch.version.hip and config.aot_inductor.emit_multi_arch_kernel: + # Multi-arch ROCm: Capture LLVM IR for cross-architecture compilation + asm_type = "ll" + + # llir is the key to obtain LLVM IR from triton + asm = launcher.bin.asm.get("llir", None) + + # CRITICAL: Multi-arch compilation cannot proceed without LLVM IR + # Fail fast with clear error message pointing to the issue + if not asm: + available_keys = list(launcher.bin.asm.keys()) + raise RuntimeError( + f"ROCm multi-arch requires LLVM IR, but none found. " + f"Available keys: {available_keys}. " + f"Triton may need to be patched to emit LLVM IR." 
+ ) + + # Everything else: capture architecture-specific assembly + else: + asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get( + self.device_props.type, None + ) + asm = launcher.bin.asm.get(asm_type, None) CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type) self.cuda_kernel_saved = True diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index ec5e73cad85d4..0dabd98016a1b 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -420,13 +420,15 @@ def _compiled_and_package( path = Path(base_directory) / f"{name}_input_{i}.pt" torch.save(t, path) - cmake_file_str = _get_make_file(package_name, model_names, use_cuda) + # Detect if ROCm is being used + is_hip = torch.version.hip is not None + cmake_file_str = _get_make_file(package_name, model_names, use_cuda, is_hip) with open(Path(base_directory) / "CMakeLists.txt", "w") as file: file.write(cmake_file_str) main_file_str = _get_main_cpp_file( - package_name, model_names, use_cuda, example_inputs_map + package_name, model_names, use_cuda, example_inputs_map, is_hip ) with open(Path(base_directory) / "main.cpp", "w") as file: file.write(main_file_str) diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py index 67bda0c34ce4f..3f45f337fe912 100644 --- a/torch/export/experimental/_utils.py +++ b/torch/export/experimental/_utils.py @@ -13,6 +13,7 @@ def _get_main_cpp_file( model_names: list[str], cuda: bool, example_inputs_map: typing.Optional[dict[str, int]], + is_hip: bool, ) -> str: """ Generates a main.cpp file for AOTInductor standalone models in the specified package. @@ -43,12 +44,20 @@ def _get_main_cpp_file( ] ) if cuda: - ib.writelines( - [ - "#include ", - "#include ", - ] - ) + if is_hip: + ib.writelines( + [ + "#include ", + ] + ) + + else: + ib.writelines( + [ + "#include ", + "#include ", + ] + ) for model_name in model_names: ib.writeline( @@ -181,7 +190,9 @@ def _get_main_cpp_file( return ib.getvalue() -def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str: +def _get_make_file( + package_name: str, model_names: list[str], cuda: bool, is_hip: bool +) -> str: ib = IndentedBuffer() ib.writelines( @@ -200,7 +211,10 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str ib.writeline("find_package(Torch REQUIRED)") if cuda: - ib.writeline("find_package(CUDA REQUIRED)") + if is_hip: + ib.writeline("find_package(hip REQUIRED)") + else: + ib.writeline("find_package(CUDA REQUIRED)") ib.newline() for model_name in model_names: @@ -208,12 +222,18 @@ def _get_make_file(package_name: str, model_names: list[str], cuda: bool) -> str ib.writeline("\nadd_executable(main main.cpp)") if cuda: - ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)") + if is_hip: + ib.writeline("target_compile_definitions(main PRIVATE USE_HIP)") + else: + ib.writeline("target_compile_definitions(main PRIVATE USE_CUDA)") model_libs = " ".join(model_names) ib.writeline(f"target_link_libraries(main PRIVATE torch {model_libs})") if cuda: - ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") + if is_hip: + ib.writeline("target_link_libraries(main PRIVATE hip::host)") + else: + ib.writeline("target_link_libraries(main PRIVATE cuda ${CUDA_LIBRARIES})") return ib.getvalue() From 8523a64c4be123331bec4aae6c7a694d8e82df1a Mon Sep 17 00:00:00 2001 From: Taras Date: Thu, 6 Nov 2025 19:13:33 +0000 Subject: [PATCH 141/651] Fix python -m build: error: 
unrecognized arguments: --no-build-isolation (#166848) Fixes #166326 The PR fixes the following error: ``` python -m build: error: unrecognized arguments: --no-build-isolation ``` The regression has been introduced in the [commit](https://github.com/pytorch/pytorch/commit/50d418f69fbed31420208395f9f0172fc46e45fe#diff-e5a6ba9ea3717e5913cd885e81f143937ea727282edd6939479a2a60b1051bf5R73) in the scope of [PR](https://github.com/pytorch/pytorch/pull/156712). Pull Request resolved: https://github.com/pytorch/pytorch/pull/166848 Approved by: https://github.com/seemethere --- .ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 b/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 index a165f2a222caf..f0eabed170d25 100644 --- a/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 +++ b/.ci/pytorch/win-test-helpers/arm64/build_pytorch.ps1 @@ -70,7 +70,7 @@ sccache --zero-stats sccache --show-stats # Build the wheel -python -m build --wheel --no-build-isolation +python -m build --wheel --no-isolation if ($LASTEXITCODE -ne 0) { exit 1 } # Install the wheel locally From ba2e6b0b4f1718767762d7b20558d4de943be71b Mon Sep 17 00:00:00 2001 From: Chinmay Kuchinad Date: Thu, 6 Nov 2025 19:29:32 +0000 Subject: [PATCH 142/651] [ROCm] Enable StaticCudaLauncher for ROCm (#166492) This PR enables ROCm/HIP support for PyTorch's StaticCudaLauncher, which provides static compilation and launching of Triton kernels. The implementation has been tested on AMD MI300 and MI200 hardware. **Changes** **Python (torch/_inductor/runtime/)** - static_cuda_launcher.py: Added ROCm detection, .hsaco binary support, and ROCm-specific scratch parameter handling - triton_heuristics.py: Updated device type checks to support both cuda and hip **C++ (torch/csrc/)** - Module.cpp: Enabled StaticCudaLauncher for ROCm builds - inductor/static_cuda_launcher.cpp: Added HIP API equivalents for all CUDA driver calls - inductor/static_cuda_launcher.h: Updated header guard **Tests (test/inductor/)** - test_static_cuda_launcher.py: Removed @skipIfRocm decorators and updated binary file handling **Enabled Unit Tests** All tests in test/inductor/test_static_cuda_launcher.py now pass on ROCm: 1. test_basic 2. test_unsigned_integers 3. test_signed_integers 4. test_basic_1arg 5. test_constexpr 6. test_implied_constant 7. test_kernel_no_args 8. test_high_shared_mem 9. test_too_high_shared_mem 10. test_kernel_empty_tensor 11. test_kernel_many_args 12. test_basic_compile 13. test_incompatible_code 14. test_static_launch_user_defined_triton_kernels 15. test_empty_tensor 16. test_any 17. test_disable_static_cuda_launcher In addition to this, the following tests from test/inductor/test_codecache.py also pass: 1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False 2. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False 3. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True 4. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_False 5. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_False 6. 
test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_True_use_static_cuda_launcher_True The following tests are skipped since triton bundling is necessary for StaticCudaLauncher: 1. test_remote_cache_load_function_device_cuda_float32_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True 2. test_remote_cache_load_function_device_cuda_bfloat16_dynamic_False_bundle_triton_False_use_static_cuda_launcher_True Pull Request resolved: https://github.com/pytorch/pytorch/pull/166492 Approved by: https://github.com/jeffdaily --- test/inductor/test_codecache.py | 9 +- test/inductor/test_static_cuda_launcher.py | 21 +--- .../_inductor/runtime/static_cuda_launcher.py | 55 ++++++++-- torch/_inductor/runtime/triton_heuristics.py | 11 +- torch/csrc/Module.cpp | 2 +- torch/csrc/inductor/static_cuda_launcher.cpp | 102 ++++++++++++++++-- torch/csrc/inductor/static_cuda_launcher.h | 2 +- 7 files changed, 161 insertions(+), 41 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 46f1ca031bf83..c90d2ccec83d5 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -475,14 +475,17 @@ def test_remote_cache_load_function( if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if ( + device == "cuda" + and torch.version.hip is None + and dtype == torch.bfloat16 + and not SM80OrLater + ): raise unittest.SkipTest("requires SM80 or later") if use_static_cuda_launcher and not (device == "cuda" and bundle_triton): raise unittest.SkipTest( "Static cuda launcher requires cuda and triton bundling" ) - if use_static_cuda_launcher and TEST_WITH_ROCM: - raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM") def fn(x, y): return (x * 2, y @ y) diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index 654bfd269f761..ec9586197d085 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -12,7 +12,6 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase -from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -39,8 +38,9 @@ def write_cubin_to_tmp(self, kernel: CompiledKernel) -> str: # Just used by tests for now. # TODO: derive cubin_path from wherever triton stores the cubin file on disk. tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False) + binary_key = "hsaco" if torch.version.hip else "cubin" with tmp_file: - tmp_file.write(kernel.asm["cubin"]) + tmp_file.write(kernel.asm[binary_key]) self.tmp_files.append(tmp_file) return tmp_file.name @@ -64,7 +64,6 @@ def _make_launcher( result.load_kernel(device_interface.current_device()) return result - @skipIfRocm def test_basic(self): @triton.jit def simple_kernel(arg0, arg1): @@ -91,7 +90,6 @@ def simple_kernel(arg0, arg1): # 2. triton relies on inspect.get_source to get the type annotations # so I can't even use exec() to generate the test cases. 
# So we'll just make a few kernels by hand - @skipIfRocm def test_unsigned_integers(self): @triton.jit def unsigned_integers( @@ -115,7 +113,6 @@ def unsigned_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_signed_integers(self): @triton.jit def signed_integers( @@ -139,7 +136,6 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_basic_1arg(self): @triton.jit def simple_kernel_1_arg(arg0): @@ -164,7 +160,6 @@ def simple_kernel_1_arg(arg0): ) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_constexpr(self): # Constexprs are compiled directly into the cubin file, # so we never need to pass it to StaticCudaLauncher. @@ -193,7 +188,6 @@ def kernel_constexpr(arg0, CONSTANT: tl.constexpr): ) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_implied_constant(self): """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr""" @@ -246,7 +240,6 @@ def triton_red_fused_any_isinf_0( launcher.run(1, 1, 1, stream, arg0, arg2, 128) self.assertEqual(arg1, arg2) - @skipIfRocm def test_kernel_no_args(self): # Just an easy way to test incompatible number of arguments @triton.jit @@ -259,7 +252,6 @@ def kernel_no_op(): stream = device_interface.get_raw_stream(device_interface.current_device()) launcher.run(1, 1, 1, stream) - @skipIfRocm def test_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -283,7 +275,6 @@ def simple_kernel(arg0, arg1): launcher.run(1, 1, 1, stream, new_arg0, arg1) self.assertEqual(new_arg0, arg0) - @skipIfRocm def test_too_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -303,7 +294,6 @@ def simple_kernel(arg0, arg1): lambda: self._make_launcher(compiled_kernel), ) - @skipIfRocm def test_kernel_empty_tensor(self): # Triton kernel generated by torch.compile of the following: # @torch.compile() @@ -364,7 +354,6 @@ def triton_poi_fused_cat_0( launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel) self.assertEqual(buf0, buf1) - @skipIfRocm def test_kernel_many_args(self): N = 200 # Make 200 arguments @@ -405,7 +394,6 @@ class TestStaticTritonCompileResult(TestCase): Tests static cuda launcher with torch.compile() """ - @skipIfRocm def test_basic_compile(self): @torch.compile def foo(x, y): @@ -415,7 +403,6 @@ def foo(x, y): y = torch.randn(10, device="cuda") self.assertEqual(foo(x, y), x + y) - @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch("compile_threads", 1) def test_incompatible_code(self): @@ -438,7 +425,6 @@ def foo(x): lambda: foo(x), ) - @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch( {"compile_threads": 1, "static_launch_user_defined_triton_kernels": True} @@ -460,7 +446,6 @@ def foo(x): x2 = x.clone().detach_() self.assertEqual(foo(x), x2 + 5) - @skipIfRocm def test_empty_tensor(self): @torch.compile() def foo(x, y): @@ -472,7 +457,6 @@ def foo(x, y): result = foo(x, y) self.assertEqual(result, torch.cat(((x * 4), y + 10))) - @skipIfRocm def test_any(self): def fn(x): return ( @@ -492,7 +476,6 @@ def fn(x): compiled_result = compiled_fn(arg) self.assertEqual(eager_result, compiled_result) - @skipIfRocm def test_disable_static_cuda_launcher(self): @torch.compile def fn(x, y): diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index 
f48f351ce823a..4eede8631e9ce 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -38,7 +38,20 @@ def __init__(self, kernel: CompiledKernel) -> None: # pyrefly: ignore [missing-attribute] self.name = kernel.src.fn.__name__ # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm.get("cubin", None) + if "hsaco" in kernel.asm: + # pyrefly: ignore [missing-attribute] + self.cubin_raw = kernel.asm["hsaco"] + self.is_rocm = True + # pyrefly: ignore [missing-attribute] + elif "cubin" in kernel.asm: + # pyrefly: ignore [missing-attribute] + self.cubin_raw = kernel.asm["cubin"] + self.is_rocm = False + else: + raise RuntimeError( + "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm" + ) + # pyrefly: ignore [missing-attribute] self.cubin_path = kernel._cubin_path @@ -245,12 +258,42 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher - # Add a None if triton wants extra parameters for scratch spaces arg_tys = self.arg_tys - for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: - if has_scratch: - arg_tys = arg_tys + "O" - args = (*args, None) + + if self.is_rocm: + # ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both + # global_scratch and profile_scratch parameters in the kernel signature, + # even when the kernel doesn't use them (i.e., when has_*_scratch is False). + # + # This differs fundamentally from CUDA, where these parameters are only + # present in the signature if the corresponding has_*_scratch flag is True. + # + # The flags indicate whether memory will be allocated/used: + # - has_global_scratch: Whether global scratch workspace is needed + # - has_profile_scratch: Whether profiling instrumentation is enabled + # + # However, regardless of flag values, we MUST always pass both parameters + # to match the HIP kernel ABI. Passing None is safe: + # + # - If scratch is not needed (has_*_scratch=False or scratch_size=0): + # The None becomes nullptr, which the kernel never dereferences + # + # - If scratch is needed (has_*_scratch=True and scratch_size>0): + # The None becomes nullptr initially, but the HIP runtime intercepts + # the kernel launch, allocates the required scratch memory based on + # kernel metadata, and replaces the nullptr with a valid pointer before + # the kernel actually executes + # + # Not passing both parameters causes segmentation faults because the kernel + # expects them at specific positions in the argument array. 
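+            # Hence: always append two pointer slots ("OO") and two placeholder None args, regardless of the flag values.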
+ arg_tys = arg_tys + "OO" + args = (*args, None, None) + + else: + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) # pyrefly: ignore [bad-argument-type] assert len(args) == len(arg_tys) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index b38cdcb71fa23..d60cda3fae7bf 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1599,9 +1599,8 @@ def can_statically_launch( return None def check_can_launch() -> StaticallyLaunchedCudaKernel: - if triton_meta.get("device_type") != "cuda": - # Only cuda kernels - raise CannotStaticallyLaunchKernel("Non-cuda device") + if triton_meta.get("device_type") not in ("cuda", "hip"): + raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device") if torch._inductor.config.cpp_wrapper: # If we're running with cpp wrapper, it doesn't @@ -1627,10 +1626,11 @@ def check_can_launch() -> StaticallyLaunchedCudaKernel: "static launch does not support launch attributes" ) + binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin" cubin_location = os.path.join( triton_cache_dir(triton_meta.get("device", 0)), triton_hash_to_path_key(kernel.hash), - f"{kernel.src.fn.__name__}.cubin", + f"{kernel.src.fn.__name__}.{binary_ext}", ) if not os.path.exists(cubin_location): @@ -1662,10 +1662,11 @@ def reload_cubin_path(self): When loading from cache on disk, we want to reload cubin files from their appropriate location on disc. """ + binary_ext = "hsaco" if torch.version.hip else "cubin" cubin_location = os.path.join( triton_cache_dir(self.compile_meta.get("device", 0)), triton_hash_to_path_key(self.kernel.hash), - f"{self.kernel.name}.cubin", + f"{self.kernel.name}.{binary_ext}", ) if not os.path.exists(cubin_location): if self.kernel.cubin_raw is not None: diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index ad37abe3b560b..0c32e6028bc69 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2159,7 +2159,7 @@ PyObject* initModule() { #ifdef USE_CUDA torch::cuda::initModule(module); #endif -#if defined(USE_CUDA) && !defined(USE_ROCM) +#if defined(USE_CUDA) ASSERT_TRUE(StaticCudaLauncher_init(module)); #endif #ifdef USE_MPS diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index 59916b6763bfa..35d11c8651323 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -1,7 +1,4 @@ -#if defined(USE_CUDA) && !defined(USE_ROCM) -// We disable this file from being hipified because there are CUDA drivers hip -// has not implemented yet. Also, we're passing in a cubin file directly, so it -// would take more work to support ROCM anyway. +#if defined(USE_CUDA) || defined(USE_ROCM) #include #include @@ -16,6 +13,11 @@ #include #include #include + +#if defined(USE_ROCM) +#include +#endif + /** Implements a static launcher for triton compiled CUDA kernels. 
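  On ROCm builds the same entry points map onto the HIP driver API equivalents (hipModuleLoad, hipModuleLaunchKernel, hipFuncGetAttribute, ...).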
Given a path to a cubin file, a function name, and some metadata, @@ -56,8 +58,14 @@ const at::cuda::NVRTC& nvrtc() { CUdeviceptr getPointer(PyObject* obj) { CUdeviceptr data_ptr = 0; + if (THPUtils_checkLong(obj)) { +#if defined(USE_ROCM) + data_ptr = reinterpret_cast(THPUtils_unpackUInt64(obj)); +#else data_ptr = THPUtils_unpackUInt64(obj); +#endif + return data_ptr; } if (obj == Py_None) { @@ -73,13 +81,25 @@ CUdeviceptr getPointer(PyObject* obj) { TORCH_CHECK( THPUtils_checkLong(ret), "data_ptr method of Pointer object must return 64-bit int"); + +#if defined(USE_ROCM) + data_ptr = reinterpret_cast(THPUtils_unpackUInt64(ret)); +#else data_ptr = THPUtils_unpackUInt64(ret); +#endif + if (!data_ptr) return data_ptr; CUdeviceptr dev_ptr = 0; +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute( + &dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute( &dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); +#endif + return dev_ptr; } @@ -98,6 +118,15 @@ CUfunction loadKernel( } CUmodule mod = nullptr; CUfunction func = nullptr; + +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str())); + AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str())); + int shared_optin = 0; + AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( + &shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device)); + +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str())); AT_CUDA_DRIVER_CHECK( nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str())); @@ -106,6 +135,9 @@ CUfunction loadKernel( &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); + +#endif + // Shared memory logic from triton/third-party/nvidia/backend/driver.c // If we're using more than 48 KB of shared memory, and we have // access to more than 48 KB of shared memory on the device, @@ -124,6 +156,21 @@ CUfunction loadKernel( " Reducing block sizes or `num_stages` may help."); if (sharedMemBytes > SHARED_MEM_STATIC_MAX && shared_optin > SHARED_MEM_STATIC_MAX) { +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared)); + int shared_total = 0, shared_static = 0; + AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( + &shared_total, + hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, + device)); + AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( + &shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); + AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + +#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED)); int shared_total = 0, shared_static = 0; @@ -137,6 +184,7 @@ CUfunction loadKernel( func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); +#endif } return func; } @@ -152,6 +200,27 @@ inline void launchKernel( cudaStream_t stream) { // cta_args is always 1 for inductor generated triton kernels, // so we don't need to figure out grid dimension here +#if defined(USE_ROCM) + int device = 0; + AT_CUDA_DRIVER_CHECK(hipGetDevice(&device)); + int warp_size = 0; + AT_CUDA_DRIVER_CHECK( + hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device)); + + AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel( + func, + gridX, + gridY, + gridZ, + warp_size * numWarps, // blockDim.x + 1, // blockDim.y + 1, // blockDim.z + sharedMemBytes, + stream, + args, + nullptr)); + +#else 
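  // CUDA: unchanged launch path through the cached NVRTC driver bindings.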
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( func, gridX, @@ -164,6 +233,7 @@ inline void launchKernel( stream, args, nullptr)); +#endif } template @@ -269,11 +339,20 @@ PyObject* load_kernel(PyObject* self, PyObject* args) { CUdevice device = static_cast(device_ptr); // NOLINT CUfunction func = nullptr; func = loadKernel(filePath, funcName, sharedMemBytes, device); - // Taken from triton/nvidia/backend/driver.c + +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK( + hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func)); + AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( + &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); + +#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func)); AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute( &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); + +#endif n_spills /= 4; // Return a tuple of CUFunction, n_regs, n_spills return Py_BuildValue( @@ -299,7 +378,6 @@ PyObject* launch_kernel_inner( std::array argStorage = {}; std::array kernelArgs = {}; parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data()); - launchKernel( func, gridX, @@ -386,13 +464,25 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) { Py_RETURN_NONE; } CUcontext pctx = nullptr; +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx)); +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); +#endif + if (!pctx) { // Ensure device context exists CUdevice device = 0; +#if defined(USE_ROCM) + AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0)); + AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device)); + AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx)); +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0)); AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device)); AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx)); + +#endif } CUfunction func = reinterpret_cast(func_ptr); // NOLINT cudaStream_t cudaStream = reinterpret_cast(stream); // NOLINT diff --git a/torch/csrc/inductor/static_cuda_launcher.h b/torch/csrc/inductor/static_cuda_launcher.h index 517036b9975e6..6f3980172275b 100644 --- a/torch/csrc/inductor/static_cuda_launcher.h +++ b/torch/csrc/inductor/static_cuda_launcher.h @@ -1,5 +1,5 @@ #pragma once -#if defined(USE_CUDA) && !defined(USE_ROCM) +#if defined(USE_CUDA) #include #include From c9b2db73ca0c823e338f430054b990d9fea274a1 Mon Sep 17 00:00:00 2001 From: Jeddie Ji Date: Thu, 6 Nov 2025 19:31:15 +0000 Subject: [PATCH 143/651] [Sigmoid][Delta Update][2/N] update delta update api to load original value first before casting to target dtype (#167039) Summary: The current delta update has a strong assumption that the non-lowered weights share the same tensor dtype from the lowered version. This is not true by design. When dtype mismatches the data loading will load the data into unexpected dtype which introduces undefined behavior. This diff aims to close the gap by always load tensor by its original dtype first then cast to desired dtype. Test Plan: No more NaN values! 
{P2022339213} Reviewed By: kqfu Differential Revision: D86181685 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167039 Approved by: https://github.com/henryoier --- torch/nativert/executor/Weights.cpp | 186 +++++++++++++++++----------- torch/nativert/executor/Weights.h | 6 +- 2 files changed, 117 insertions(+), 75 deletions(-) diff --git a/torch/nativert/executor/Weights.cpp b/torch/nativert/executor/Weights.cpp index 4a64935945c4f..ea1f1498b5fb5 100644 --- a/torch/nativert/executor/Weights.cpp +++ b/torch/nativert/executor/Weights.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -55,92 +56,128 @@ Weights::Weights( const std::unordered_map& constantPaths, std::string_view constantPathPrefix, std::function skipSizeCheck, - std::function skipDtypeCheck) + std::function skipDtypeCheck, + std::shared_ptr>> maybeNewWeightsMeta) : graph_(graph), weightsMeta_(graph->weightsMeta()), version_(globalVersion_++), skipSizeCheck_(std::move(skipSizeCheck)), skipDtypeCheck_(std::move(skipDtypeCheck)) { - auto loadAndInsert = - [&](const std::string& tensorName, - std::string_view pathPrefix, - const std::unordered_map& tensorPaths, - bool isUsed) { - auto pathIt = tensorPaths.find(tensorName); - TORCH_CHECK( - pathIt != tensorPaths.end(), - "Couldn't find ", - tensorName, - " in tensorPaths"); + auto loadAndInsert = [&](const std::string& tensorName, + std::string_view pathPrefix, + const std::unordered_map& + tensorPaths, + bool isUsed, + std::shared_ptr>> + maybeNewWeightsMeta) { + auto pathIt = tensorPaths.find(tensorName); + TORCH_CHECK( + pathIt != tensorPaths.end(), + "Couldn't find ", + tensorName, + " in tensorPaths"); - const std::string tensorPath = std::string{pathPrefix} + pathIt->second; - VLOG(1) << "Loading weight from: " << tensorPath; - TORCH_CHECK( - pytorchStreamReader->hasRecord(tensorPath), - tensorPath, - " not found"); - - auto [tensorData, tensorDataSize] = - pytorchStreamReader->getRecord(tensorPath); - - // TODO: We now have two copies of metadata for weights, one in - // model definition /models/.json, another in - // /extra/xl_weights/_model_param_config.json - // Currently, we only use the metadata from model definition. - std::optional tensorMeta; - if (weightsMeta_.find(tensorName) != weightsMeta_.end()) { - tensorMeta = weightsMeta_.at(tensorName); - } else { - TORCH_CHECK(false, "Tensor meta not found for: ", tensorName); - } + const std::string tensorPath = std::string{pathPrefix} + pathIt->second; + VLOG(1) << "Loading weight from: " << tensorPath; + TORCH_CHECK( + pytorchStreamReader->hasRecord(tensorPath), tensorPath, " not found"); + + auto [tensorData, tensorDataSize] = + pytorchStreamReader->getRecord(tensorPath); + + // TODO: We now have two copies of metadata for weights, one in + // model definition /models/.json, another in + // /extra/xl_weights/_model_param_config.json + // Currently, we only use the metadata from model definition. 
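+    // Base (lowered) metadata is always required; metadata from maybeNewWeightsMeta, when present, only changes how the raw bytes are interpreted before the final cast back to the base dtype.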
+ std::optional tensorMeta; + if (weightsMeta_.find(tensorName) != weightsMeta_.end()) { + tensorMeta = weightsMeta_.at(tensorName); + } else { + TORCH_CHECK( + false, + "Tensor meta not found for: ", + tensorName, + " in base weights."); + } + std::optional newTensorMeta; + if (maybeNewWeightsMeta) { + if (stateDictPaths.find(tensorName) == stateDictPaths.end()) { + TORCH_CHECK(false, "Tensor name not found in state dict paths"); + } - if (tensorDataSize == 0 && tensorMeta->numel() > 0) { - VLOG(1) << "Tensor " << tensorName - << " does not have data and create on Meta device"; - allValues_[tensorName] = at::empty_strided( - tensorMeta->sizes(), - tensorMeta->strides(), - tensorMeta->asTensorOptions().device(at::kMeta)); - return; - } + std::string paramName = stateDictPaths.at(tensorName); + if (maybeNewWeightsMeta->find(paramName) != maybeNewWeightsMeta->end()) { + newTensorMeta = *maybeNewWeightsMeta->at(paramName); + } else { + TORCH_CHECK( + false, + "Tensor meta not found for: ", + tensorName, + " in new weights from: ", + paramName); + } + } + std::optional curTensorMeta = + newTensorMeta ? newTensorMeta : tensorMeta; + + if (tensorDataSize == 0 && tensorMeta->numel() > 0) { + VLOG(1) << "Tensor " << tensorName + << " does not have data and create on Meta device"; + allValues_[tensorName] = at::empty_strided( + curTensorMeta->sizes(), + curTensorMeta->strides(), + curTensorMeta->asTensorOptions().device(at::kMeta)); + return; + } - if (!isUsed) { - VLOG(1) << "Tensor " << tensorName << " is not used during inference"; - auto targetDevice = tensorMeta->device(); - allValues_[tensorName] = - at::scalar_tensor(0, at::TensorOptions().device(targetDevice)); - return; - } + if (!isUsed) { + VLOG(1) << "Tensor " << tensorName << " is not used during inference"; + auto targetDevice = curTensorMeta->device(); + allValues_[tensorName] = + at::scalar_tensor(0, at::TensorOptions().device(targetDevice)); + return; + } - size_t bytesPerEntry = - c10::scalarTypeToTypeMeta(tensorMeta->dtype()).itemsize(); - auto device = tensorData.device(); - auto storage = c10::Storage( - c10::Storage::use_byte_size_t(), - at::detail::computeStorageNbytes( - tensorMeta->sizes(), tensorMeta->strides(), bytesPerEntry), - std::move(tensorData), // ownership is transferred - nullptr, - false); - const auto tensorOptions = at::TensorOptions(device) - .dtype(tensorMeta->dtype()) - .requires_grad(false); - auto tensor = - at::empty({0}, tensorOptions) - .set_(storage, 0, tensorMeta->sizes(), tensorMeta->strides()); - - auto targetDevice = tensorMeta->device(); - VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice; - if (!isSameDevice(targetDevice, tensor.device())) { - tensor = tensor.to(targetDevice); - } + size_t bytesPerEntry = + c10::scalarTypeToTypeMeta(curTensorMeta->dtype()).itemsize(); + auto device = tensorData.device(); + auto storage = c10::Storage( + c10::Storage::use_byte_size_t(), + at::detail::computeStorageNbytes( + curTensorMeta->sizes(), curTensorMeta->strides(), bytesPerEntry), + std::move(tensorData), // ownership is transferred + nullptr, + false); + const auto tensorOptions = at::TensorOptions(device) + .dtype(curTensorMeta->dtype()) + .requires_grad(false); + auto tensor = + at::empty({0}, tensorOptions) + .set_(storage, 0, curTensorMeta->sizes(), curTensorMeta->strides()); + + auto targetDevice = tensorMeta->device(); + VLOG(1) << "Loading weight " << tensorName << " on " << targetDevice; + if (!isSameDevice(targetDevice, tensor.device())) { + tensor = tensor.to(targetDevice); + 
} + if (tensor.dtype() != tensorMeta->dtype()) { + tensor = tensor.to(tensorMeta->dtype()); + } - allValues_[tensorName] = tensor; - }; + allValues_[tensorName] = tensor; + }; auto loadAndInsertParamsBuffers = [&](const auto& tensorName, bool isUsed) { return loadAndInsert( - std::string(tensorName), stateDictPathPrefix, stateDictPaths, isUsed); + std::string(tensorName), + stateDictPathPrefix, + stateDictPaths, + isUsed, + maybeNewWeightsMeta); }; size_t weightIndex = 0; @@ -190,7 +227,8 @@ Weights::Weights( std::string(constantName), constantPathPrefix, constantPaths, - isUsed); + isUsed, + nullptr); weightIndex++; } else { TORCH_CHECK(false, "Unknown constant path: ", fileName); diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h index 39653d0bed561..acc3379198354 100644 --- a/torch/nativert/executor/Weights.h +++ b/torch/nativert/executor/Weights.h @@ -45,7 +45,11 @@ class Weights { const std::unordered_map& constantPaths, std::string_view constantPathPrefix, std::function skipSizeCheck = {}, - std::function skipDtypeCheck = {}); + std::function skipDtypeCheck = {}, + std::shared_ptr>> maybeNewWeightsMeta = + nullptr); at::Tensor at(const std::string& name) const; at::Tensor& at(const std::string& name); From 77b70970f70d53de71b9703ad4c3199d714c535a Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Thu, 6 Nov 2025 19:55:38 +0000 Subject: [PATCH 144/651] [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167182) Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036, which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs. Test Plan: Inductor test (fbcode): `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"` Tritonbench (fbcode): `clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Tritonbench(oss): `clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Unit Tests(oss): `clear; python test/inductor/test_cutedsl_grouped_mm.py` Differential Revision: D86376880 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167182 Approved by: https://github.com/mlazos, https://github.com/jananisriram --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 + setup.py | 34 ++ test/inductor/test_cutedsl_grouped_mm.py | 154 ++++++++ torch/_inductor/config.py | 4 + torch/_inductor/kernel/mm_common.py | 7 + torch/_inductor/kernel/mm_grouped.py | 90 +++-- .../templates/cutedsl_mm_grouped.py.jinja | 333 ++++++++++++++++++ 
.../_inductor/template_heuristics/cutedsl.py | 141 ++++++++ torch/_inductor/utils.py | 78 ++++ 10 files changed, 811 insertions(+), 33 deletions(-) create mode 100644 test/inductor/test_cutedsl_grouped_mm.py create mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja create mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 26996b5a32d56..9ae2578758939 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index d1b3b17445dac..3b4323051073a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py +torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index 31e78d0245d93..dd8a52cbeb7c7 100644 --- a/setup.py +++ b/setup.py @@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +def mirror_inductor_external_kernels() -> None: + """ + Copy external kernels into Inductor so they are importable. + """ + paths = [ + ( + CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", + CWD + / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", + ), + ] + for new_path, orig_path in paths: + # Create the dirs involved in new_path if they don't exist + if not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + + # Copy the files from the orig location to the new location + if orig_path.is_file(): + shutil.copyfile(orig_path, new_path) + continue + if orig_path.is_dir(): + if new_path.exists(): + # copytree fails if the tree exists already, so remove it. 
+ shutil.rmtree(new_path) + shutil.copytree(orig_path, new_path) + continue + raise RuntimeError( + "Check the file paths in `mirror_inductor_external_kernels()`" + ) + + # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1616,6 +1647,8 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() + mirror_inductor_external_kernels() + ( ext_modules, cmdclass, @@ -1649,6 +1682,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", + "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py new file mode 100644 index 0000000000000..c26def3a54099 --- /dev/null +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -0,0 +1,154 @@ +# Owner(s): ["module: inductor"] + + +import unittest + +import torch +from torch import Tensor +from torch._inductor import config +from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch +from torch._inductor.test_case import run_tests, TestCase as InductorTestCase +from torch._inductor.utils import ensure_cute_available +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +@unittest.skipIf( + not (ensure_cute_available() and is_datacenter_blackwell_arch()), + "CuTeDSL library or Blackwell device not available", +) +@instantiate_parametrized_tests +class TestCuTeDSLGroupedGemm(InductorTestCase): + def _get_inputs( + self, + group_size: int, + M_hint: int, + K: int, + N: int, + device: str, + dtype: torch.dtype, + alignment: int = 16, + ) -> tuple[Tensor, Tensor, Tensor]: + # --- Random, tile-aligned M sizes --- + M_sizes = ( + torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) + * alignment + ) + + M_total = torch.sum(M_sizes).item() + + # --- Construct input tensors --- + A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 + B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 + + # --- Build offsets (no leading zero, strictly increasing) --- + offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) + + return (A, B, offsets) + + @parametrize("group_size", (2, 8)) + @parametrize("M_hint", (256, 1024)) + @parametrize("K", (64, 128)) + @parametrize("N", (128, 256)) + def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): + device = "cuda" + dtype = torch.bfloat16 + + A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # Eager execution + c_eager = grouped_gemm_fn(A, B, offsets) + + # Test with Cute backend + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) + @parametrize("layout_B", ("contiguous", "broadcasted")) + def 
test_grouped_gemm_assorted_layouts( + self, + layout_A: str, + layout_B: str, + ): + device = "cuda" + dtype = torch.bfloat16 + + G, K, N = 8, 64, 128 + M_sizes = [128] * G + sum_M = sum(M_sizes) + offsets = torch.tensor( + [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device + ) + + A_base = torch.randn(sum_M, K, device=device, dtype=dtype) + A = A_base + + if layout_A == "offset": + # allocate bigger buffer than needed, use nonzero storage offset + storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) + offset = 128 # skip first 128 elements + A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) + elif layout_A == "padded": + # simulate row pitch > K (row_stride = K + pad) + row_pitch = K + 8 + storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) + A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) + elif layout_A == "view": + A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) + A = A_storage.view(sum_M, K) + assert A._base is not None + assert A.shape == (sum_M, K) + + B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 + + if layout_B == "broadcasted": + # Broadcast B across groups (zero stride along G) + B = B[0].expand(G, K, N) + assert B.stride(0) == 0 + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # --- eager --- + c_eager = grouped_gemm_fn(A, B, offsets) + + # --- compiled (CUTE backend) --- + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 2d9e180db54f5..8b996967749f5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -541,6 +541,10 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +cutedsl_enable_autotuning: bool = ( + os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" +) + # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index b95073e769f31..eb22b95af2afc 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence +from functools import partial +from pathlib import Path from typing import Any import torch @@ -12,6 +14,7 @@ from .. 
import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox +from ..utils import load_template log = logging.getLogger(__name__) @@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True + + +_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 881c14fd43d0d..c81ec607661bc 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import logging -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters +from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl +from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -22,11 +24,13 @@ get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, + load_kernel_template, persistent_grouped_mm_grid, ) @@ -513,6 +517,11 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) +cutedsl_grouped_mm_template = CuteDSLTemplate( + name="grouped_gemm_cutedsl", + source=load_kernel_template("cutedsl_mm_grouped"), +) + def grouped_mm_args( mat1: TensorBox, @@ -714,43 +723,44 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
+ if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -788,6 +798,22 @@ def _tuned_grouped_mm_common( **config.kwargs, ) + if use_blackwell_cutedsl_grouped_mm( + mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result + ): + for config in get_groupgemm_configs(): + kwargs = dict( + ACC_DTYPE="cutlass.Float32", + ) + + cutedsl_grouped_mm_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + **asdict(config), + ) + input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 0000000000000..989f297c5f80f --- /dev/null +++ b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. 
At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. +# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. + """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. 
+ """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = input_b.shape + + dev = 
input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + 
max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 0000000000000..db337b9d8a271 --- /dev/null +++ b/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
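+
+    Each tuple below is expanded positionally into CuTeGemmConfig, i.e. it is
+    read as (TILE_M, TILE_N, CLUSTER_M, CLUSTER_N, USE_2_CTA,
+    TENSORMAP_UPDATE_MODE).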
+ """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. + """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 9579dbb3536e3..cd0f3643d37f7 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,6 +1975,84 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() +@functools.lru_cache(maxsize=1) +def ensure_cute_available() -> bool: + """Check if CuTeDSL is importable; cache the result for reuse. + + Call ensure_cute_available.cache_clear() after installing CuTeDSL + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("cutlass.cute") is not None + except ImportError: + return False + + +def use_blackwell_cutedsl_grouped_mm( + mat_a: Any, + mat_b: Any, + layout: Layout, + a_is_2d: bool, + b_is_2d: bool, + offs: Optional[Any], + bias: Optional[Any], + scale_result: Optional[Any], +) -> bool: + """ + Returns True if we can use the blackwell kernel for grouped mm. + Required conditions: + 1. CuTeDSL backend is enabled + 2. CuTeDSL is available + 3. We are on a blackwell arch + 4. The dtype is bf16 + 5. Max autotune or max autotune gemm is enabled + 6. A, B, and the output are 16B aligned + 7. We are not using dynamic shapes + 8. A is 2d + 9. B is 3d + 10. Offsets are provided + 11. 
Bias and Scale are not provided + """ + if not ensure_cute_available(): + return False + + if not _use_autotune_backend("CUTEDSL"): + return False + + from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch + + if not is_gpu(layout.device.type): + return False + + if not is_datacenter_blackwell_arch(): + return False + + layout_dtypes = [torch.bfloat16] + if not _use_template_for_gpu(layout, layout_dtypes): + return False + + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Checks for 16B ptr and stride alignment + if not can_use_tma(mat_a, mat_b, output_layout=layout): + return False + + if any(is_dynamic(x) for x in [mat_a, mat_b]): + return False + + if not a_is_2d or b_is_2d: + return False + + if offs is None: + return False + + if bias is not None or scale_result is not None: + return False + + return True + + def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From 2e83ae2de7aafbb5c72a28f73b99e009d0bd603f Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 6 Nov 2025 06:19:30 -0800 Subject: [PATCH 145/651] [pp] Add reduce_grad Action (#166449) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166449 Approved by: https://github.com/wconstab, https://github.com/sanketpurandare --- test/distributed/pipelining/test_schedule.py | 40 +++++++++++ test/distributed/test_composability.py | 1 + torch/distributed/pipelining/schedules.py | 72 ++++++++++++++++---- torch/distributed/pipelining/stage.py | 12 +++- 4 files changed, 109 insertions(+), 16 deletions(-) diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index 714ab8f659111..36d334d18b02c 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -21,6 +21,7 @@ from torch.distributed.pipelining._utils import generate_stage_to_rank_mapping from torch.distributed.pipelining.schedules import ( _Action, + _add_reduce_grad, _add_send_recv, _add_unshard_reshard, _format_pipeline_order, @@ -574,6 +575,45 @@ def test_unshard_reshard(self, test_info): ), ) + @parametrize( + "test_info", + [ + { + "compute": ["0F0", "0F1", " ", "0B0", "0B1"], + "comms": ["0F0", "0F1", "0B0", "0B1", "0REDUCE_GRAD"], + }, + { + "compute": ["0F0", "0F1", "1F0", "1F1", "1B0", "1B1", "0B0", "0B1"], + "comms": [ + "0F0", + "0F1", + "1F0", + "1F1", + "1B0", + "1B1", + "1REDUCE_GRAD", + "0B0", + "0B1", + "0REDUCE_GRAD", + ], + }, + ], + ) + def test_reduce_grad(self, test_info): + compute_sch = self._parse_actions(test_info["compute"]) + expected_comms_sch = self._parse_actions(test_info["comms"]) + + comms_sch = _add_reduce_grad(compute_sch, 2) + for expected, actual in zip(expected_comms_sch, comms_sch, strict=True): + self.assertEqual( + expected, + actual, + ( + f"Mismatch: expected action {expected} but found {actual}." 
+ f"\nWhole Schedule: {comms_sch}" + ), + ) + @parametrize( "test_info", [ diff --git a/test/distributed/test_composability.py b/test/distributed/test_composability.py index 3508a43cb548f..566a63d67302d 100644 --- a/test/distributed/test_composability.py +++ b/test/distributed/test_composability.py @@ -499,6 +499,7 @@ def create_schedule(computation_types, microbatch_index=None): [ _ComputationType.UNSHARD, _ComputationType.FORWARD, + _ComputationType.REDUCE_GRAD, # Contains final fsdp post_backward ], microbatch_index=0, ) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 44569427f8db2..d84857ef474af 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -54,6 +54,7 @@ class _ComputationType(Enum): RECV_B = 9 FULL_BACKWARD = 10 OVERLAP_F_B = 11 + REDUCE_GRAD = 12 def __str__(self): str_map = { @@ -68,6 +69,7 @@ def __str__(self): _ComputationType.RECV_B: "RECV_B", _ComputationType.FULL_BACKWARD: "B", _ComputationType.OVERLAP_F_B: "OVERLAP_F_B", + _ComputationType.REDUCE_GRAD: "REDUCE_GRAD", } return str_map[self] @@ -95,6 +97,8 @@ def from_str(action): return _ComputationType.FULL_BACKWARD elif action == "OVERLAP_F_B": return _ComputationType.OVERLAP_F_B + elif action == "REDUCE_GRAD": + return _ComputationType.REDUCE_GRAD else: raise RuntimeError(f"Invalid computation type {action}") @@ -110,6 +114,7 @@ def from_str(action): RECV_B = _ComputationType.RECV_B FULL_BACKWARD = _ComputationType.FULL_BACKWARD OVERLAP_F_B = _ComputationType.OVERLAP_F_B +REDUCE_GRAD = _ComputationType.REDUCE_GRAD # Convenience shorthand for compute actions only since they are used in 'simple schedule format' F = FORWARD @@ -119,7 +124,7 @@ def from_str(action): # Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) _action_regex = re.compile( - r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" + r"(\d+)(F|I|B|W|UNSHARD|RESHARD|REDUCE_GRAD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" ) @@ -645,10 +650,6 @@ def step( args_split, kwargs_split, targets_split, losses, return_outputs ) - # Stage post processing - grad_scale_factor = self._n_microbatches if self.scale_grads else 1 - self._stage._post_backward(grad_scale_factor) - # Return merged results per original format if self._stage.is_last and return_outputs: return self._merge_outputs(self._stage.output_chunks) @@ -809,6 +810,8 @@ def _step_microbatches( # Update losses if there is a container passed in self._update_losses(self._stage, losses) + self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: """ Returns the pipeline order for GPipe schedule. 
@@ -837,9 +840,9 @@ def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: for mb_idx in range(self._n_microbatches): actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx)) - pipeline_order[rank] = actions + pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches) - return pipeline_order + return pipeline_order # type: ignore[return-value] class Schedule1F1B(PipelineScheduleSingle): @@ -990,6 +993,8 @@ def _step_microbatches( # Return losses if there is a container passed in self._update_losses(self._stage, losses) + self._stage.perform_reduce_grad(self._n_microbatches if self.scale_grads else 1) + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: """ Returns the pipeline order for 1F1B schedule. @@ -1055,10 +1060,47 @@ def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: backward_mb += 1 remaining_backward -= 1 - pipeline_order[rank] = actions + pipeline_order[rank] = _add_reduce_grad(actions, self._n_microbatches) return pipeline_order +def _requires_reduce_grad(action_type: _ComputationType) -> bool: + return action_type in (W, B) + + +def _add_reduce_grad( + actions: list[Optional[_Action]], n_microbatches: int +) -> list[Optional[_Action]]: + """ + REDUCE_GRAD refers to joint across minibatches grad reduction. + reduce_grad frees memory and we want to schedule it just after the last "backward"-like stage. + """ + actions_with_reduce_grad: list[Optional[_Action]] = [] + cnt: dict[int, int] = defaultdict(int) + + def _leaf_action(a, to_schedule): + if _requires_reduce_grad(a.computation_type): + stage_index = a.stage_index + cnt[stage_index] += 1 + if cnt[stage_index] == n_microbatches: + to_schedule.append(stage_index) + + for a in actions: + if a is None: + continue + actions_with_reduce_grad.append(a) + schedule_reduce_grad_stage_idxs: list[int] = [] + if a.computation_type == OVERLAP_F_B and a.sub_actions is not None: + for sub_action in a.sub_actions: + _leaf_action(sub_action, schedule_reduce_grad_stage_idxs) + else: + _leaf_action(a, schedule_reduce_grad_stage_idxs) + + for stage_idx in schedule_reduce_grad_stage_idxs: + actions_with_reduce_grad.append(_Action(stage_idx, REDUCE_GRAD, None)) + return actions_with_reduce_grad + + def _add_unshard_reshard( compute_actions: list[Optional[_Action]], max_active_stages: int = 3, @@ -1596,12 +1638,6 @@ def step( args_split, kwargs_split, targets_split, losses, return_outputs ) - # Stage post processing - # TODO: remove this section and include as part of the schedule IR? 
- for stage in self._stages: - grad_scale_factor = self._n_microbatches if self.scale_grads else 1 - stage._post_backward(grad_scale_factor) - # Return merged results per original format for stage in self._stages: if stage.is_last and return_outputs: @@ -1917,6 +1953,10 @@ def _prepare_schedule_with_comms( self.pipeline_order_with_comms[rank] = _add_unshard_reshard( actions[rank] ) + self.pipeline_order_with_comms[rank] = _add_reduce_grad( # type: ignore[assignment] + self.pipeline_order_with_comms[rank], # type: ignore[arg-type] + self._n_microbatches, + ) self.pipeline_order_with_comms = _add_send_recv( self.pipeline_order_with_comms, @@ -2025,6 +2065,7 @@ def _perform_action(action: _Action) -> None: assert mb_index >= 0 or comp_type in ( UNSHARD, RESHARD, + REDUCE_GRAD, ), f"{action=} missing mb_index" stage_idx = action.stage_index stage = stage_index_to_stage[stage_idx] @@ -2179,6 +2220,9 @@ def _perform_action(action: _Action) -> None: mb_index, last_backward=last_backward, ) + elif comp_type == REDUCE_GRAD: + grad_scale_factor = self._n_microbatches if self.scale_grads else 1 + stage.perform_reduce_grad(grad_scale_factor) else: raise ValueError(f"{action=} is unknown or unsupported") diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 6274689945109..a232f5519c9ee 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -978,7 +978,14 @@ def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: return ops - def _post_backward(self, grad_scale_factor: int): + def perform_reduce_grad(self, grad_scale_factor: int): + """ + Called as a part of schedule IR. + REDUCE_GRAD action is scheduled after all microbatches W, B actions. + + Currently contains "post_backward" functionality for FSDP. + We can try to extract post_backward in a separate IR action in future. + """ # Manually call post backward for FSDP if isinstance(self.submod, FSDPModule): fsdp_module = self.submod @@ -1001,7 +1008,8 @@ def _post_backward(self, grad_scale_factor: int): distributed_state._root_post_backward_final_callback() # Call gradient scaling at the end of the backward pass # NOTE: this must happen after FSDP post_backward is FSDP is enabled - self.scale_grads(grad_scale_factor) + if grad_scale_factor != 1: + self.scale_grads(grad_scale_factor) class _PipelineStage(_PipelineStageBase): From 03dea563f48a410b5cded94dfab54093144d8889 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 6 Nov 2025 10:25:40 -0800 Subject: [PATCH 146/651] Add guidance on how to migrate kernels to the libtorch stable ABI (#167112) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167112 Approved by: https://github.com/janeyx99 --- docs/source/notes/libtorch_stable_abi.md | 103 +++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/docs/source/notes/libtorch_stable_abi.md b/docs/source/notes/libtorch_stable_abi.md index fff32d00cb449..5312dfe546072 100644 --- a/docs/source/notes/libtorch_stable_abi.md +++ b/docs/source/notes/libtorch_stable_abi.md @@ -46,6 +46,108 @@ These headers are promised to be ABI stable across releases and adhere to a stro Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable` which will handle all the rough edges of the C API for the user. 
+## Migrating your kernel to the LibTorch stable ABI + +If you'd like your kernel to be ABI stable with LibTorch, meaning you'd the ability to build for one version and run on another, your kernel must only use the limited stable ABI. This following section goes through some steps of migrating an existing kernel and APIs we imagine you would need to swap over. + +Firstly, instead of registering kernels through `TORCH_LIBRARY`, LibTorch ABI stable kernels must be registered via `STABLE_TORCH_LIBRARY`. Note that, for the time being, implementations registered via `STABLE_TORCH_LIBRARY` must be boxed unlike `TORCH_LIBRARY`. See the simple example below or our docs on [Stack-based APIs](stack-based-apis) for more details. For kernels that are registered via `pybind`, before using the stable ABI, it would be useful to migrate to register them via `TORCH_LIBRARY`. + +While previously your kernels might have included APIs from `` (for example, ``), they are now limited to including from the 3 categories of headers mentioned above (`torch/csrc/stable/*.h`, `torch/headeronly/*.h` and the stable C headers). This means that your extension should no longer use any utilities from the `at::` or `c10::` namespaces but instead use their replacements in `torch::stable` and `torch::headeronly`. To provide a couple examples of the necessary migrations: +- all uses of `at::Tensor` must be replaced with `torch::stable::Tensor` +- all uses of `TORCH_CHECK` must be replaced with `STD_TORCH_CHECK` +- all uses of `at::kCUDA` must be replaced with `torch::headeronly::kCUDA` etc. +- native functions such as `at::pad` must be replaced with `torch::stable::pad` +- native functions that are called as Tensor methods (e.g., `Tensor.pad`) must be replaced with the ATen variant through `torch::stable::pad`. + +As mentioned above, the LibTorch stable ABI is still under development. If there is any API or feature you would like to see added to the stable ABI/`torch::headeronly`/`torch::stable`, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues). + +Below is a simple example of migrating an existing kernel that uses `TORCH_LIBRARY` to the stable ABI (`TORCH_STABLE_LIBRARY`). For a larger end to end example you can take a look at the FA3 repository. Specifically the diff between [`flash_api.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api.cpp#L1) and the stable variant [`flash_api_stable.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api_stable.cpp#L1). 
+ + +### Original Version with `TORCH_LIBRARY` + +```cpp +// original_kernel.cpp - Using TORCH_LIBRARY (not stable ABI) +#include +#include + +namespace myops { + +// Simple kernel that adds a scalar value to each element of a tensor +at::Tensor add_scalar(const at::Tensor& input, double scalar) { + TORCH_CHECK(input.scalar_type() == at::kFloat, "Input must be float32"); + + return input.add(scalar); +} + +// Register the operator +TORCH_LIBRARY(myops, m) { + m.def("add_scalar(Tensor input, float scalar) -> Tensor", &add_scalar); +} + +// Register the implementation +TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) { + m.impl("add_scalar", &add_scalar); +} + +} // namespace myops +``` + +### Migrated Version with `STABLE_TORCH_LIBRARY` + +```cpp +// stable_kernel.cpp - Using STABLE_TORCH_LIBRARY (stable ABI) + +// (1) Don't include +// only include APIs from torch/csrc/stable, torch/headeronly and C-shims +#include +#include +#include +#include +#include +#include + +namespace myops { + +// Simple kernel that adds a scalar value to each element of a tensor +torch::stable::Tensor add_scalar(const torch::stable::Tensor& input, double scalar) { + // (2) use STD_TORCH_CHECK instead of TORCH_CHECK + STD_TORCH_CHECK( + // (3) use torch::headeronly::kFloat instead of at:kFloat + input.scalar_type() == torch::headeronly::kFloat, + "Input must be float32"); + + // (4) Use stable ops namespace instead of input.add + return torch::stable::add(input, scalar); +} + +// (5) Add Boxed wrapper required for STABLE_TORCH_LIBRARY +void boxed_add_scalar(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + // Extract arguments from stack using `to` + auto input = to(stack[0]); + auto scalar = to(stack[1]); + + // Call the actual kernel + auto result = add_scalar(input, scalar); + + // Put result back on stack using `from()` + // Stack slot 0 now holds the return value + stack[0] = from(result); +} + +// (6) Register the operator using STABLE_TORCH_LIBRARY +STABLE_TORCH_LIBRARY(myops, m) { + m.def("add_scalar(Tensor input, float scalar) -> Tensor", &boxed_add_scalar); +} + +// (7) Register the implementation using STABLE_TORCH_LIBRARY_IMPL +STABLE_TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) { + m.impl("add_scalar", &boxed_add_scalar); +} + +} // namespace myops +``` + ## How are objects passed across the ABI boundary when interacting with the dispatcher? @@ -109,6 +211,7 @@ There are two invariants for the stack: a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack. b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references. +(stack-based-apis)= ### Stack-based APIs The above is relevant in two places: From 096c9356def3cb1d2c50c64dad1cfced707f922d Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Thu, 6 Nov 2025 20:29:29 +0000 Subject: [PATCH 147/651] [CUDA][cuBLASLt] addmm -- enable 2D bias in the Lt path when followed by an activation (#165548) As per title. This one is based off [#163955](https://github.com/pytorch/pytorch/pull/163955), but I will rebase once it is merged. 
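A minimal sketch of the newly-enabled case (shapes/dtype are my own; it assumes the private
`torch._addmm_activation` binding, which is the addmm-with-activation entry point these
epilogues serve). With this change a 2D bias that is expandable to the `(M, N)` output can
stay on the cuBLASLt path when an activation is fused; for 2D bias the Lt bias pointer is
not used -- the bias is copied into the output via the `beta` path -- the gain is keeping
the fused activation epilogue.

```python
import torch

m1 = torch.randn(10, 50, device="cuda", dtype=torch.bfloat16)
m2 = torch.randn(50, 25, device="cuda", dtype=torch.bfloat16)
bias_2d = torch.randn(10, 25, device="cuda", dtype=torch.bfloat16)  # 2D bias matching (M, N)

# beta=1 keeps the input in the epilogue-fusable form checked by
# isInputCompliesAddmmCudaLt; use_gelu=False fuses ReLU, use_gelu=True fuses GELU.
out = torch._addmm_activation(bias_2d, m1, m2, beta=1, use_gelu=False)
```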
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165548 Approved by: https://github.com/eqy --- aten/src/ATen/cuda/CUDABlas.cpp | 25 +++++++---- aten/src/ATen/native/cuda/Blas.cpp | 68 ++++++++++++++++++++++++------ test/test_linalg.py | 14 +++--- 3 files changed, 76 insertions(+), 31 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 20f235076220f..9a55b058001da 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1597,7 +1597,7 @@ bool gemm_and_bias( } using opmath_t = at::opmath_type; - opmath_t beta_val = 0; // bias is added in epilogue + opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr cudaDataType_t abType = CUDA_R_32F; cudaDataType_t cType = CUDA_R_32F; @@ -1686,15 +1686,22 @@ bool gemm_and_bias( _syncCurrentWithCarveoutStream(stream, true); } #endif - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS; - if (activation == GEMMAndBiasActivationEpilogue::RELU) { - epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; - } else if (activation == GEMMAndBiasActivationEpilogue::GELU) { - epilogue = CUBLASLT_EPILOGUE_GELU_BIAS; - } + const auto epilogue = [&]() -> cublasLtEpilogue_t { + // The cuBLAS documentation indicates that + // *__BIAS = *_, + // but we keep it verbose here for clarity. + switch (activation) { + case GEMMAndBiasActivationEpilogue::RELU: + return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU; + case GEMMAndBiasActivationEpilogue::GELU: + return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU; + default: + return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT; + } + }(); + computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue); - if (bias != nullptr) { - computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue); + if (bias) { computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias); } diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 186f7d8a6a78a..2754d70cac013 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -147,14 +147,24 @@ static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) { /* * Check whether for the given input we want to enable the Lt interface */ -static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) { +static bool isInputCompliesAddmmCudaLt( + Tensor& result, + const Tensor& self, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& beta, + const Scalar& alpha, + Activation activation +) { + #ifdef USE_ROCM // Implies 2D bias which we currently not send through Lt. // TODO: this check is done pre col-major input preparation, // so, this condition can be ralexed in cases when a col-major // copy of result is needed. - if (result.is_same(self)) { + if (self.is_same(result) || self.dim() == 2) { return false; } + #endif #if defined(USE_ROCM) && ROCM_VERSION == 60400 // hipblaslt TT fp32 regression on ROCm 6.4, cannot use @@ -169,13 +179,33 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const #if defined(CUDA_VERSION) || defined(USE_ROCM) const auto scalar_type = mat1.scalar_type(); return (beta.toComplexDouble() == 1.0 + // NOTE: row-major result is important when bias is 1D. 
+ // This is because Lt broadcasts 1D bias over the columns + // while the aten::addmm API broadcasts it over the rows, + // and this is in conjuction with the data preparation + // procedure that does not transpose arguments with + // col-major result. For col-major result we need + // to explicitly transpose the problem so that bias is + // correctly applied. + // TODO: enable col-major result if needed. + // TODO: no need to check result's layout when + // !result.is_same(self) and self.dim() == 2, because + // self needs to be copied into result and the bias ptr + // will be ignored. && result.dim() == 2 && result.is_contiguous() - // Conditions for bias to be fusable && ( - self.is_contiguous() && - // NOTE: fine to have 1-len dims to the left from the right-most one - (self.dim() == 1 || self.squeeze().dim() == 1) && - self.sizes().back() == mat2_sizes[1] + ( // Conditions for bias to be fusable -- implies direct Lt path without copies. + self.is_contiguous() && + // NOTE: fine to have 1-len dims to the left from the right-most one + (self.dim() == 1 || self.squeeze().dim() == 1) && + self.sizes().back() == mat2_sizes[1] + ) + || ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self), + // and we need to copy self into result otherwise, so the self's layout becomes irrelevant. + // See also TODO from above. + activation != Activation::None && // Lt is faster when activation is fused + (self.dim() == 2 && at::is_expandable_to(self.sizes(), {mat1_sizes[0], mat2_sizes[1]})) + ) ) && ( // some dtype restrictions #ifndef USE_ROCM @@ -270,7 +300,16 @@ bool launchGemmAndBiasCublasLt( const Scalar& alpha, Activation activation = Activation::None ) { - const auto* self_ptr = self.const_data_ptr(); + // We apply bias in the epilogue only when it is 1D, + // or when it can be squeezed to 1D. + // self_ptr == nullptr implies ignore bias epilogue + // and use standard gemm-like API. + const auto* self_ptr = [&]() -> auto { + if (self.dim() == 1 || self.squeeze().dim() == 1) { + return self.const_data_ptr(); + } + return static_cast(nullptr); + }(); const auto tuning_ctx = at::cuda::tunable::getTuningContext(); if (tuning_ctx->IsTunableOpEnabled()) { @@ -356,7 +395,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt; #endif // Condition on the input - disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt; + disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt; // } at::ScalarType scalar_type = mat1.scalar_type(); @@ -366,19 +405,20 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma if (!result.is_same(self)) { at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]}); + // We use bias ptr in the Lt path only when bias is 1D + const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt; const auto self_maybe_expanded = [&]() -> c10::MaybeOwned { - if (disable_addmm_cuda_lt) { - // When in non-Lt path we do expand self even before + if (!use_bias_ptr_lt) { + // We do expand self even before // check for beta != 0.0 to make sure that // test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_* // runs green. 
return expand_size(self, result.sizes(), "addmm"); } - // copy next, should broadcast return c10::MaybeOwned::borrowed(self); }(); - // We copy bias when in the non-Lt path - if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) { + // We do not copy bias only when we need the bias ptr + if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) { // NOTE: self should broadcast over result at::native::copy_(result, *self_maybe_expanded); } diff --git a/test/test_linalg.py b/test/test_linalg.py index 41a223763d474..9168964369920 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -7329,11 +7329,11 @@ def _test_addmm_impl(self, func, activation, device, dtype): m2 = torch.randn(50, 25, device=device).to(dtype) self._test_addmm_addmv(func, M, m1, m2, activation=activation) - # vector-shaped bias (or with 1-len dims on the left from the leading dim) + # vector (or with 1-len dims in shape[:-1])/matrix-shaped bias # and beta=1 result in epilogue fusion in CUDA V = torch.randn(25, device=device).to(dtype) - self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) - self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation) + for c in (V, V.unsqueeze(0), M): + self._test_addmm_addmv(func, c, m1, m2, beta=1, activation=activation) # Test 0-strided M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) @@ -7357,12 +7357,10 @@ def maybe_transpose(cond, m): M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype)) m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype)) m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype)) - self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation) - if t1: - # use vector/(1 by k)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) - self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,) - self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,) + for c, beta in itertools.product((M, V, V.unsqueeze(0)), (0, 1)): + # beta=1 to test epilogue fusions with either vector or matrix input + self._test_addmm_addmv(func, c, m1, m2, beta=beta, transpose_out=t4, activation=activation) @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) From d19f36bea1c47d47ace6bae43d3b34dd858de145 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Thu, 6 Nov 2025 20:38:56 +0000 Subject: [PATCH 148/651] [BE][Ez]: Update fmtlib submodule to 12.1.0 (#166983) Fixed some compiler idiosyncrasies, improves CPP support, bugfixes, and performance optimizations. This is a header only minor library change so should be low risk and improve the performance of our formatting/loggers. Also allows fmtlib to be used in more constexpr contexts. 
Full changelog here: https://github.com/fmtlib/fmt/releases/tag/12.1.0 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166983 Approved by: https://github.com/atalman --- third_party/fmt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/fmt b/third_party/fmt index e424e3f2e607d..407c905e45ad7 160000 --- a/third_party/fmt +++ b/third_party/fmt @@ -1 +1 @@ -Subproject commit e424e3f2e607da02742f73db84873b8084fc714c +Subproject commit 407c905e45ad75fc29bf0f9bb7c5c2fd3475976f From 888958ad6c0f51123a9bf51e931cd45d16f5d267 Mon Sep 17 00:00:00 2001 From: jmaczan Date: Thu, 6 Nov 2025 21:00:44 +0000 Subject: [PATCH 149/651] Prevent torch._check causing graph breaks (#164676) Handle `torch._check` in `TorchInGraphFunctionVariable.call_function`. Basically, it has two arguments - a predicate (bool) and a message (callable). If predicate is a constant, evaluate `torch._check`. If predicate is true, it just will compile and nothing happens. If predicate is false, `torch._check` will raise an exception. If predicate is not constant, we manually emit a proxy. I tried to build as_proxy() inside NestedUserFunctionVariable, but failed to, that's why I create it here. I try to extract message. If it's a function, I retrieve it. If not, set it to None. Maybe we could extract it if message is a closure, but not sure how Fixes #163668 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164676 Approved by: https://github.com/williamwen42, https://github.com/mlazos Co-authored-by: William Wen --- test/dynamo/test_misc.py | 164 ++++++++++++++++++++++++ torch/_dynamo/graph_break_registry.json | 13 ++ torch/_dynamo/trace_rules.py | 2 +- torch/_dynamo/variables/torch.py | 82 +++++++++++- 4 files changed, 259 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b3e9df6a25cf3..0db7043b02c21 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1428,6 +1428,170 @@ def f(x): self.assertRaises(torch._dynamo.exc.UserError, lambda: f(torch.tensor([3]))) + def test_check_compiles_when_predicate_true_and_message_has_no_closure(self): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3") + return x + 1 + + x = torch.randn(4) + torch._dynamo.maybe_mark_dynamic(x, 0) + + y = f(x) + self.assertEqual(y.shape, x.shape) + + def test_check_compiles_when_predicate_true_constant_and_message_has_no_closure( + self, + ): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3") + return x + 1 + + x = torch.randn(4) + + y = f(x) + self.assertEqual(y.shape, x.shape) + + def test_check_compiles_when_predicate_true_constant_and_message_None(self): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3) + return x + 1 + + x = torch.randn(4) + + y = f(x) + self.assertEqual(y.shape, x.shape) + + def test_check_compiles_when_predicate_true_and_message_None(self): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3) + return x + 1 + + x = torch.randn(4) + torch._dynamo.maybe_mark_dynamic(x, 0) + + y = f(x) + self.assertEqual(y.shape, x.shape) + + def test_check_compiles_when_predicate_true_and_message_has_global(self): + global GLOBAL_INT + GLOBAL_INT = 1 + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3") + return 
x + 1 + + x = torch.randn(4) + torch._dynamo.maybe_mark_dynamic(x, 0) + + y = f(x) + self.assertEqual(y.shape, x.shape) + + def test_check_raises_at_runtime_when_predicate_false_and_message_has_global(self): + global GLOBAL_INT + GLOBAL_INT = 1 + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: f"{GLOBAL_INT} is not greater than 3") + return x + 1 + + x = torch.randn(3) + torch._dynamo.maybe_mark_dynamic(x, 0) + + with self.assertRaisesRegex( + RuntimeError, f"{GLOBAL_INT} is not greater than 3" + ): + f(x) + + def test_check_raises_at_runtime_when_predicate_false_and_message_None(self): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3) + return x + 1 + + x = torch.randn(3) + torch._dynamo.maybe_mark_dynamic(x, 0) + + with self.assertRaisesRegex(RuntimeError, None): + f(x) + + def test_check_raises_at_runtime_when_predicate_false_constant_and_message_None( + self, + ): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3) + return x + 1 + + x = torch.randn(3) + + with self.assertRaisesRegex(RuntimeError, None): + f(x) + + def test_check_raises_at_runtime_when_predicate_false_and_message_has_no_closure( + self, + ): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3") + return x + 1 + + x = torch.randn(3) + torch._dynamo.maybe_mark_dynamic(x, 0) + + with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"): + f(x) + + def test_check_raises_at_runtime_when_predicate_false_constant_and_message_has_no_closure( + self, + ): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: "Shape is not greater than 3") + return x + 1 + + x = torch.randn(3) + + with self.assertRaisesRegex(RuntimeError, "Shape is not greater than 3"): + f(x) + + def test_check_assert_error_at_runtime_when_predicate_false_and_message_has_closure( + self, + ): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3") + return x + 1 + + x = torch.randn(3) + torch._dynamo.maybe_mark_dynamic(x, 0) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()" + ): + f(x) + + def test_check_assert_error_at_runtime_when_predicate_true_and_message_has_closure( + self, + ): + @torch.compile(backend="eager", fullgraph=True) + def f(x): + torch._check(x.shape[0] > 3, lambda: f"{x.shape[0]} is not greater than 3") + return x + 1 + + x = torch.randn(4) + torch._dynamo.maybe_mark_dynamic(x, 0) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "Can't extract message from torch._check()" + ): + f(x) + def test_assert(self): @torch.compile def fn1(x): diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index c814c1ccf32bc..b21d81910abb1 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2937,5 +2937,18 @@ "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." 
] } + ], + "GB0288": [ + { + "Gb_type": "Can't extract message from torch._check()", + "Context": "str(message_vt)", + "Explanation": "The second argument of torch._check() must be a functiondefined within the torch.compile regionthat does not reference a non-local variable.", + "Hints": [ + "Make sure the message function is defined in the torch.compile region.", + "Remove any closure variables, e.g. ", + "remove references to closure variable `x` in `lambda: f'{x} failed check'`", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } ] } diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 6a162350039d7..7efc62ed9a289 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -180,6 +180,7 @@ "torch.compiler.is_exporting": TorchInGraphFunctionVariable, "torch._C._to_dlpack": SkipFunctionVariable, "torch.to_dlpack": SkipFunctionVariable, + "torch._check": TorchInGraphFunctionVariable, # We graph break on RNG state setters or getters like # `torch.get_rng_state` or `torch.set_rng_state`. These functions # are not aten operations and therefore they are completely ignored @@ -2343,7 +2344,6 @@ "torch._check_type", "torch._check_value", "torch._check_with", - "torch._check", "torch._compile._disable_dynamo", "torch._functorch.apis.chunk_vmap", "torch._functorch.batch_norm_replacement.batch_norm_without_running_stats", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index be28fe9269f44..30c1b8c2cf186 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -78,7 +78,7 @@ ) from .dicts import ConstDictVariable from .distributed import DistributedVariable, ProcessGroupVariable -from .functions import bind_args_cached +from .functions import bind_args_cached, NestedUserFunctionVariable from .lists import ListVariable, TupleVariable from .torch_function import ( can_dispatch_torch_function, @@ -1318,6 +1318,86 @@ def handle_set_default_device( return ConstantVariable.create(None) + @register(torch._check) + def handle_check(self, tx: "InstructionTranslator", *args, **kwargs): + predicate_vt = None + message_vt = None + + if args: + predicate_vt = args[0] + rest_args = args[1:] + else: + rest_args = () + + if predicate_vt is None and "cond" in kwargs: + predicate_vt = kwargs.pop("cond") + + if rest_args: + message_vt = rest_args[0] + elif "message" in kwargs: + message_vt = kwargs.pop("message") + + if predicate_vt is None: + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + (), + {}, + ), + ) + + message_eager = None + message_graph_proxy = None + if message_vt is not None: + if ( + not isinstance(message_vt, NestedUserFunctionVariable) + or message_vt.has_closure() + ): + unimplemented_v2( + gb_type="Can't extract message from torch._check()", + context=str(message_vt), + explanation=( + "The second argument of torch._check() must be a function" + "defined within the torch.compile region" + "that does not reference a non-local variable." + ), + hints=[ + "Make sure the message function is defined in the torch.compile region.", + "Remove any closure variables, e.g. 
" + "remove references to closure variable `x` in `lambda: f'{x} failed check'`", + *graph_break_hints.SUPPORTABLE, + ], + ) + message_eager = message_vt.get_function() + + message_graph_proxy = tx.output.register_static_attr_and_return_proxy( + "_check_message", message_eager + ) + + if predicate_vt.is_python_constant(): + self.value(predicate_vt.as_python_constant(), message_eager) + return ConstantVariable.create(None) + + predicate_proxy = predicate_vt.as_proxy() + + proxy_args: tuple[Any, ...] + if message_graph_proxy is None: + proxy_args = (predicate_proxy,) + else: + proxy_args = (predicate_proxy, message_graph_proxy) + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + proxy_args, + {}, + ), + ) + return handlers def call_function( From ab1e734cd703caae3d2db7f9661177514a397b9b Mon Sep 17 00:00:00 2001 From: shunting314 Date: Mon, 3 Nov 2025 16:52:04 -0800 Subject: [PATCH 150/651] [ez] avoid log spam when random data is generated (#166919) It's annoying to see full screen of this warning when running fx_graph_runnable files saved in tlparse. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166919 Approved by: https://github.com/eellison --- torch/_dynamo/debug_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index e16fa11ed08f6..2acf517aba92f 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -47,7 +47,7 @@ from torch.utils._content_store import ContentStoreReader, ContentStoreWriter from . import config -from .utils import clone_inputs, get_debug_dir +from .utils import clone_inputs, get_debug_dir, warn_once if TYPE_CHECKING: @@ -617,7 +617,7 @@ def storage( # way would be very mysterious! Would have been better # not to store device in the serialized format... return storage - log.warning("could not load %s, generating random data instead", storage_hash) + warn_once(f"could not load {storage_hash}, generating random data instead") shape = (nbytes // dtype_hint.itemsize,) stride = _stride_or_default(None, shape=shape) return rand_strided(shape, stride, dtype_hint, device).untyped_storage() From 78827c5e002581f8305f03759ff6aa8579b2062a Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Wed, 5 Nov 2025 06:49:26 -0800 Subject: [PATCH 151/651] Distributed Autotuning (#163369) This is the initial prototype of distributed autotuning. It's intended to be a basis for iteration rather than the final end product. Currently when we run a SPMD program we compile the ranks independently. As a result the autotuning is repeated on every rank. So for a 8-GPU program with 8 matmul operators we'll autotune 64 (8*8) times. Distributed autotuning uses collectives to distribute the autotuning across the ranks so each rank autotunes 1/worldsize the total operators. So in our 8-GPU example we would only perform 8 autotunes total (one on each rank) rather than 64. There are several advantages: 1. Faster autotuning times - each CPU/GPU does less work total 2. Better determinism - currently it's possible for two ranks to choose different algorithms for the same operator. With distributed autotuning we choose the algorithm once for the entire program. Results: In testing using llama3 8B on torchtitan max-autotune time was reduced from 52s -> 26s and exhaustive-autotuning was reduced from 2009s -> 613s. Usage: The feature is controlled by the environment variable TORCHINDUCTOR_DISTRIBUTED_AUTOTUNE. 
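A minimal sketch of enabling the feature the way the new tests in this PR do. It assumes a
multi-process launch (e.g. torchrun) with one GPU per rank; the config knob added below is
also driven by the `TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM` environment variable.

```python
import torch
import torch.distributed as dist
import torch._inductor.config as inductor_config

dist.init_process_group("nccl")  # assumes env:// rendezvous, e.g. via torchrun
rank = dist.get_rank()

# Both flags are enabled together in the new tests; each rank then autotunes
# roughly 1/world_size of the GEMMs and the winning choices are all-gathered.
inductor_config.max_autotune_gemm = True
inductor_config.distributed_max_autotune_gemm = True

@torch.compile
def f(a, b, c):
    return (
        torch.sum((a @ b) + 1.0)
        + torch.sum(torch.relu(b @ c))
        + torch.sum(c @ a)
    )

a = torch.randn(1024, 1024, device=rank, dtype=torch.bfloat16)
b = torch.randn(1024, 2048, device=rank, dtype=torch.bfloat16)
c = torch.randn(2048, 1024, device=rank, dtype=torch.bfloat16)
print(f(a, b, c))
```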
Co-authored-by: @PaulZhang12 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163369 Approved by: https://github.com/PaulZhang12 --- test/distributed/test_dynamo_distributed.py | 114 ++++++ torch/_inductor/compile_fx.py | 8 +- torch/_inductor/config.py | 8 + torch/_inductor/distributed_autotune.py | 386 ++++++++++++++++++++ torch/_inductor/kernel/mm.py | 7 +- torch/_inductor/scheduler.py | 21 +- torch/_inductor/virtualized.py | 15 + 7 files changed, 551 insertions(+), 8 deletions(-) create mode 100644 torch/_inductor/distributed_autotune.py diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index b75fb91379f9c..61186034c746f 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -2,6 +2,7 @@ import contextlib import copy import functools +import logging import random import unittest from contextlib import contextmanager @@ -51,6 +52,9 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton +log = logging.getLogger(__name__) + + def reset_rng_state(): torch.manual_seed(1337) random.seed(1337) @@ -1200,6 +1204,116 @@ def f(x): for r in res[1:]: self.assertEqual(res[0], r) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @patch.object(torch._dynamo.config, "enable_compiler_collectives", True) + @patch.object(torch._inductor.config, "max_autotune_gemm", True) + @patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True) + def test_multiproc_autotune(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(a, b, c): + res = ( + torch.sum((a @ b) + 1.0) + + torch.sum(torch.relu(b @ c)) + + torch.sum(c @ a) + ) + + return res + + a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16) + b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16) + c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16) + + try: + f(a, b, c) + except Exception: + log.exception("Caught exception running f") + raise + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + + print(f"Result from {self.rank} is {f(a, b, c)}") + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @patch.object(torch._dynamo.config, "enable_compiler_collectives", True) + @patch.object(torch._inductor.config, "max_autotune_gemm", True) + @patch.object(torch._inductor.config, "distributed_max_autotune_gemm", True) + def test_multiproc_autotune_dynamic_shapes(self): + with _dynamo_dist_per_rank_init(self.rank, self.world_size): + torch._dynamo.utils.clear_compilation_metrics() + + @torch.compile() + def f(a, b, c): + res = ( + torch.sum((a @ b) + 1.0) + + torch.sum(torch.relu(b @ c)) + + torch.sum(c @ a) + ) + + return res + + a = torch.randn(1024, 1024, device=self.rank, dtype=torch.bfloat16) + b = torch.randn(1024, 2048, device=self.rank, dtype=torch.bfloat16) + c = torch.randn(2048, 1024, device=self.rank, dtype=torch.bfloat16) + + # Mark tensors as dynamic on dimension 0 + torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.mark_dynamic(a, 1) + torch._dynamo.mark_dynamic(b, 0) + torch._dynamo.mark_dynamic(b, 1) + torch._dynamo.mark_dynamic(c, 0) + torch._dynamo.mark_dynamic(c, 1) + + try: + f(a, b, c) + except Exception: + log.exception("Caught exception 
running f") + raise + + metrics = torch._dynamo.utils.get_compilation_metrics() + res = [None] * self.world_size + torch.distributed.all_gather_object(res, len(metrics)) + for r in res[1:]: + self.assertEqual(res[0], r) + + print(f"Result from {self.rank} is {f(a, b, c)}") + + # Store the initial compilation count + initial_compile_count = len(metrics) + + # # Test with different sizes to ensure dynamic shapes work without recompilation + a2 = torch.randn(512, 512, device=self.rank, dtype=torch.bfloat16) + b2 = torch.randn(512, 2048, device=self.rank, dtype=torch.bfloat16) + c2 = torch.randn(2048, 512, device=self.rank, dtype=torch.bfloat16) + + try: + result2 = f(a2, b2, c2) + print(f"Result2 from {self.rank} is {result2}") + except Exception: + log.exception("Caught exception running f with different sizes") + raise + + # Verify no recompilation occurred + metrics_after = torch._dynamo.utils.get_compilation_metrics() + final_compile_count = len(metrics_after) + self.assertEqual( + initial_compile_count, + final_compile_count, + "Expected no recompilation with dynamic shapes", + ) + + # Verify all ranks have the same compilation count + res_after = [None] * self.world_size + torch.distributed.all_gather_object(res_after, final_compile_count) + for r in res_after[1:]: + self.assertEqual(res_after[0], r) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_get_pg_attr(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 94df451b5f1d0..8ff19b8721067 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -104,7 +104,7 @@ from ..fx._lazy_graph_module import _use_lazy_graph_module from ..fx.graph import _PyTreeCodeGen from ..utils._triton import has_triton -from . import config, metrics +from . import config, distributed_autotune, metrics from .codegen.common import get_wrapper_codegen_for_device, init_backend_registration from .debug import DebugContext from .decomposition import select_decomp_table @@ -1431,7 +1431,11 @@ def codegen_and_compile( # We are going to start code generating runtime asserts, so make sure # you don't start adding new ones in the lowering process graph.freeze_runtime_asserts() - with V.set_graph_handler(graph), V.set_extern_kernel_nodes([]): + with ( + V.set_graph_handler(graph), + V.set_extern_kernel_nodes([]), + distributed_autotune.graph_context(), + ): graph.run(*example_inputs) output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = [] if graph.graph_outputs is not None: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 8b996967749f5..10e3d2bb5211a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -447,6 +447,14 @@ def prologue_fusion_enabled() -> bool: justknob="pytorch/inductor:use_experimental_benchmarker", ) +# Enable distributed autotuning. When this is enabled we will distribute the +# autotuning across distributed ranks in the same program group - so instead of +# each rank autotuning every kernel they only autotune 1/world size kernels and +# then share the results. 
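+# For example (from the PR description): an 8-GPU run with 8 matmuls performs
+# 8 autotunes in total (one per rank) instead of 64, and every rank ends up
+# with the same choice for a given operator.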
+distributed_max_autotune_gemm = ( + os.environ.get("TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM") == "1" +) + # enable slow autotuning passes to select algorithms max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" diff --git a/torch/_inductor/distributed_autotune.py b/torch/_inductor/distributed_autotune.py new file mode 100644 index 0000000000000..af2d5bb9e9f11 --- /dev/null +++ b/torch/_inductor/distributed_autotune.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import contextlib +import dataclasses +from typing import Any, TYPE_CHECKING, Union +from unittest.mock import patch + +import sympy + +import torch._logging +import torch.distributed as dist +import torch.fx +from torch.utils._ordered_set import OrderedSet + +from . import config, select_algorithm +from .ir import ( + Buffer, + ChoiceCaller, + Layout, + MultiTemplateBuffer, + OperationBuffer, + ShapeAsConstantBuffer, + StorageBox, + TensorBox, +) +from .kernel_inputs import KernelInputs, MMKernelInputs +from .scheduler import SchedulerNode +from .virtualized import NullHandler, V + + +if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + +_DISTRIBUTED_AUTOTUNE_KEY = "distributed_autotune" + +_AUTOTUNE_PG: dist.ProcessGroup | None = None + + +@dataclasses.dataclass +class _DistributedAutotuneState: + """ + State used to track autotuning during a graph_context() + """ + + # This is the next operator index. Used to figure out which rank should do + # the autotuning. + autotuned_index: int = 0 + + # For debugging - used to make sure that we autotune the same number of + # local operators that we expected to. + autotuned_local_count: int = 0 + + +@dataclasses.dataclass +class _DistributedAutotuneInfo: + index: int + local: bool + + +def get_autotune_pg() -> dist.ProcessGroup | None: + if dist.is_available() and dist.is_initialized(): + global _AUTOTUNE_PG + if _AUTOTUNE_PG is None: + _AUTOTUNE_PG = dist.distributed_c10d._new_group_with_tag( + pg_tag="pt2_distributed_autotune_pg" + ) + return _AUTOTUNE_PG + + return None + + +def schedule(scheduler: torch._inductor.scheduler.Scheduler) -> None: + """ + Finish the distributed autotuning by propagating the autotuning results + between the ranks and then replacing the placeholder with the real Buffer. + """ + assert config.distributed_max_autotune_gemm + autotune_results = _autotune_local_nodes(scheduler) + choices_by_index = _sync(autotune_results) + _autotune_remote_nodes(scheduler, choices_by_index) + + +@contextlib.contextmanager +def graph_context() -> Generator[None, None, None]: + """ + Wrapped around processing a graph, sets up figuring out which ranks tune + which shapes. + """ + assert not isinstance( + V.get_distributed_autotune_state(check_poisoned=False), # type: ignore[call-arg] + _DistributedAutotuneState, + ) + V.set_distributed_autotune_state(_DistributedAutotuneState()) + try: + yield + finally: + V.set_distributed_autotune_state(NullHandler()) + + +def maybe_autotune_remote( + name: str, choices: list[ChoiceCaller], inputs: list[Buffer], layout: Layout +) -> TensorBox | ShapeAsConstantBuffer | None: + """ + Used by an op (like `mm`) to determine if the op should be autotuned + locally (returns None) or remotely (returns a placeholder Buffer). 
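+
+    Operators are assigned round-robin: the op whose running index satisfies
+    `index % autotune_pg.size() == autotune_pg.rank()` is tuned locally, while
+    every other rank receives a placeholder buffer that is resolved with the
+    winning choice when `schedule()` runs.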
+ """ + if not config.distributed_max_autotune_gemm: + return None + + if not (autotune_pg := get_autotune_pg()): + return None + + if len(choices) <= 1: + return None + + state = V.distributed_autotune_state + index = state.autotuned_index + state.autotuned_index += 1 + local = index % autotune_pg.size() == autotune_pg.rank() + + V.current_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] = _DistributedAutotuneInfo( + index, local + ) + if local: + state.autotuned_local_count += 1 + return None + + return torch._inductor.ir.TensorBox.create( + _DistributedAutotuneBuffer(name, inputs, layout) + ) + + +class _DistributedAutotuneBuffer(MultiTemplateBuffer): + """ + A MultiTemplateBuffer which represents a kernel being autotuned on a + different rank. When `schedule` is called this will be replaced by the + "real" buffer. + """ + + # Name of the kernel being autotuned. + _kernel_name: str + + def __init__( + self, + kernel_name: str, + inputs: list[Buffer], + layout: Layout, + ) -> None: + super().__init__( + layout, + inputs, + choice_timings_fn=self._dummy_choice_timings, + unfiltered_choices=[], + allowed_prologue_inps=OrderedSet({}), + ) + + self._kernel_name = kernel_name + + def _dummy_choice_timings( + self, _hint_override: int | None + ) -> dict[ChoiceCaller, float]: + # This should never get called. It means that a remote autotune was + # scheduled but never filled in. + raise NotImplementedError + + def autotune(self, ser_choice: _SerializedChoice) -> TensorBox: + """ + Given a _SerializedChoice (autotune results from another rank) + compute the final TensorBox. + """ + + from .select_algorithm import autotune_select_algorithm + + with patch.object(V.graph, "scheduler", None): + kernel_inputs = MMKernelInputs([*self.original_inputs]) + assert isinstance(self.layout, Layout) + choice = ser_choice.get_choice(self.layout, kernel_inputs) + buffer = autotune_select_algorithm( + self._kernel_name, + [choice], + kernel_inputs.nodes(), + self.layout, + ) + assert isinstance(buffer, TensorBox) + return buffer + + +# Can we make this async? +def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoice]: + """ + Perform the all_gather to collect the autotune results from all the ranks. + """ + + autotune_pg = get_autotune_pg() + assert autotune_pg + + # Perform allgather + all_states: list[list[_SerializedChoice]] = [None] * autotune_pg.size() # type: ignore[list-item] + torch.distributed.all_gather_object(all_states, autotune_results, group=autotune_pg) + + node_count = sum(len(x) for x in all_states) + # It's faster to briefly lie about the type than to unzip the results and append. + choices_by_index: list[_SerializedChoice] = [None] * node_count # type: ignore[list-item] + + check_count = 0 + for i, other_results in enumerate(all_states): + for choice in other_results: + assert isinstance(choice, _SerializedChoice) + assert choices_by_index[choice.index] is None + choices_by_index[choice.index] = choice + check_count += 1 + + assert node_count == check_count, f"count mismatch: {node_count} != {check_count}" + return choices_by_index + + +class _SerializedChoice: + """ + This is a serializer for the autotune choice. KernelTemplateChoice can't + be serialized directly (the template and inputs prevent this) so we need to + serialize it by parts and reconstruct later on. 
+ """ + + def __init__(self, index: int, choice: ChoiceCaller) -> None: + self.index = index + self.template_uid = _SerializedChoice._template_uid_from_choice(choice) + self.kwargs = self._compute_kwargs(choice.description) + + def get_choice(self, layout: Layout, inputs: KernelInputs) -> ChoiceCaller | None: + """ + Deserialize the ChoiceCaller and return it. + """ + + template = self._template_from_uid() + + kwargs = {**self.kwargs} + if "BLOCK_K" in kwargs: + # TODO: Do we really need to externally compute this value? If it's + # needed I'm surprised it's not just part of the original template + # description. + # This needs the actual 'k' to figure out the value. + k = inputs.nodes()[0].get_size()[1] + kwargs["EVEN_K"] = sympy.gcd(k, kwargs["BLOCK_K"]) == kwargs["BLOCK_K"] + + extra_kwargs: dict[str, Any] = {} + from .kernel_template_choice import ( + DictKernelTemplateParams, + KernelTemplateChoice, + ) + + params = DictKernelTemplateParams(kwargs) + ktc = KernelTemplateChoice(template, params, extra_kwargs, layout, inputs) + return ktc.choice + + @staticmethod + def _compute_kwargs(description: str) -> dict[str, Union[int, str, bool]]: + """ + Given a template description turn it into input kwargs. + """ + if not description: + return {} + + # TODO: It seems like it would be better if the template could provide + # this directly instead of having to parse a string. + kwargs: dict[str, Union[int, str, bool]] = {} + for cfg in description.split(","): + key, val = cfg.split("=", 1) + key, val = key.strip(), val.strip() + if val == "True": + kwargs[key] = True + elif val == "False": + kwargs[key] = False + elif val.isdigit(): + kwargs[key] = int(val) + else: + assert val.startswith("'") and val.endswith("'") + kwargs[key] = val[1:-1] + return kwargs + + @staticmethod + def _template_uid_from_choice(choice: ChoiceCaller) -> str: + """ + Given a ChoiceCaller figure out which template represents it. This + is reversed by _template_from_uid(). + """ + + # We need a better way to do this - right now we need to add each + # supported template directly. + if isinstance(choice, select_algorithm.ExternKernelCaller): + if choice.choice.name == "mm": + return "torch._inductor.kernel.mm.aten_mm" + else: + raise RuntimeError(f"TODO: kernel {choice.choice.name!r}") + elif isinstance(choice, select_algorithm.TritonTemplateCaller): + return "torch._inductor.kernel.mm.mm_template" + else: + raise RuntimeError(f"TODO: {type(choice)}") + + def _template_from_uid(self) -> Any: + """ + See _template_uid_from_choice(). + """ + parts = self.template_uid.split(".") + obj = globals()[parts[0]] + for k in parts[1:]: + obj = getattr(obj, k) + return obj + + +def _autotune_local_nodes( + scheduler: torch._inductor.scheduler.Scheduler, +) -> list[_SerializedChoice]: + """ + Go through the nodes in the scheduler and autotune the kernels which + should be autotuned by this rank. + """ + + autotune_results: list[_SerializedChoice] = [] + + for node in scheduler.nodes: + if not isinstance(node, SchedulerNode): + continue + + if (inner_node := node.node) is None: + continue + + if isinstance(inner_node, _DistributedAutotuneBuffer): + # This is marked for remote autotuning. 
+ continue + + if not isinstance(inner_node, MultiTemplateBuffer): + continue + + if (origin_node := inner_node.origin_node) is None: + continue + + if (meta := origin_node.meta) is None: + continue + + info = meta.get(_DISTRIBUTED_AUTOTUNE_KEY) + if info is None: + continue + + assert info.local + + # We force autotuning here + # Still takes advantage of async precompile + # We need all the configs before fusion + min_choice, _ = inner_node.get_min_choice() + + choice = _SerializedChoice(info.index, min_choice) + autotune_results.append(choice) + + state = V.distributed_autotune_state + assert len(autotune_results) == state.autotuned_local_count, ( + f"incorrect local autotuned nodes found ({len(autotune_results)} != {state.autotuned_local_count})" + ) + return autotune_results + + +def _autotune_remote_nodes( + scheduler: torch._inductor.scheduler.Scheduler, + choices_by_index: Sequence[_SerializedChoice], +) -> None: + """ + Go through the nodes in the scheduler and autotune the nodes that were + autotuned on remote ranks. + """ + + for i, node in enumerate(scheduler.nodes): + if isinstance(node, SchedulerNode) and isinstance( + (dist_node := node.node), _DistributedAutotuneBuffer + ): + assert dist_node.origin_node is not None + info = dist_node.origin_node.meta[_DISTRIBUTED_AUTOTUNE_KEY] + out_tensorbox = dist_node.autotune(choices_by_index[info.index]) + + out_storage = out_tensorbox.data + assert isinstance(out_storage, StorageBox) + out_buffer = out_storage.data + assert isinstance(out_buffer, OperationBuffer) + + assert out_buffer.layout == dist_node.layout + + scheduler._replace_node(out_buffer, dist_node, i, node) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 6a8657f86bf03..986ceb4405a14 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -19,7 +19,7 @@ from torch.nn.functional import ScalingType # type: ignore[attr-defined] from torch.torch_version import TorchVersion -from .. import config as inductor_config +from .. import config as inductor_config, distributed_autotune from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate @@ -1315,6 +1315,11 @@ def _to_dtype(x): # The future will be awaited at scheduling time in select_algorithm.py best_config_future = gen_best_config(mat1, mat2) + if box := distributed_autotune.maybe_autotune_remote( + name, choices, kernel_inputs.nodes(), layout + ): + return box + return autotune_select_algorithm( name, choices, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 2930a33b465a6..d7e3ed5a529d1 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -449,7 +449,6 @@ class SchedulerDonatedBuffer(SchedulerBuffer): class BaseSchedulerNode: ancestors: OrderedSet[str] - debug_device_str: Callable[[BaseSchedulerNode], list[str]] group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]] last_usage: OrderedSet[str] # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode. 
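# Hypothetical enablement sketch for the wiring above (maybe_autotune_remote
# in the mm lowering plus the scheduler hook), assuming a job launched with
# torchrun so each rank has its own device and process group; the env var
# names come from torch/_inductor/config.py, the model and shapes are
# stand-ins.
#
#   TORCHINDUCTOR_MAX_AUTOTUNE=1 \
#   TORCHINDUCTOR_DISTRIBUTED_MAX_AUTOTUNE_GEMM=1 \
#   torchrun --nproc-per-node=8 train.py
#
# train.py: compilation itself is unchanged; each rank autotunes roughly
# 1/8 of the GEMM kernels and picks up the rest from its peers via the
# all_gather performed in distributed_autotune.schedule().
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(512, 2048), torch.nn.Linear(2048, 512)
).cuda()
compiled = torch.compile(model)
out = compiled(torch.randn(64, 512, device="cuda"))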
@@ -461,21 +460,26 @@ class BaseSchedulerNode: max_order: int mpi_node: MemoryPlanningInfoForNode mutation_renames: dict[str, str] - node: Optional[ir.Operation] + node: Optional[ir.Operation] = None outputs: list[SchedulerBuffer] outputs_by_name: dict[str, SchedulerBuffer] override_estimated_runtime: Optional[float] = None read_writes: dependencies.ReadWrites unmet_dependencies: OrderedSet[Dep] + written: bool = False def __init__(self, scheduler: Scheduler) -> None: - self.scheduler = scheduler - self.debug_device_str = lambda *args, **kwargs: [] + self.scheduler: Scheduler = scheduler + self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = ( + lambda *args, **kwargs: [] + ) def _init_from_node(self, node: ir.Operation) -> None: self.node = node self.ancestors = OrderedSet() - self.last_usage = OrderedSet() # buffers that won't be used after this kernel + self.last_usage = OrderedSet[ + str + ]() # buffers that won't be used after this kernel self.written = False self.outputs = [ SchedulerBuffer( @@ -2643,6 +2647,12 @@ def _init(self, nodes: list[ir.Operation]) -> None: if config._pre_fusion_custom_pass is not None: self.nodes = config._pre_fusion_custom_pass(self.nodes) + if config.distributed_max_autotune_gemm: + from . import distributed_autotune + + distributed_autotune.schedule(self) + self.compute_ancestors() + self.nodes = self.fuse_nodes(self.nodes) if config._post_fusion_custom_pass is not None: self.nodes = config._post_fusion_custom_pass(self.nodes) @@ -3515,6 +3525,7 @@ def rename_deps(deps: OrderedSet[Dep]) -> OrderedSet[Dep]: new_scheduler_node.min_order = node.min_order new_scheduler_node.max_order = node.max_order + new_scheduler_node.ancestors = node.ancestors new_scheduler_node.last_usage = node.last_usage def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool: diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index b08cd5059baf8..f45e372e2b3a3 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -86,6 +86,8 @@ from torch._inductor.loop_body import InterpreterShim from torch._subclasses import FakeTensorMode + from .distributed_autotune import _DistributedAutotuneState + threadlocal = local() T = TypeVar("T") @@ -201,6 +203,9 @@ def get_index_dtype_as_torch_dtype(self): _local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( "local_buffer_context", NullHandler ) +_distributed_autotune_state: Virtualized[_DistributedAutotuneState] = Virtualized( + "distributed_autotune_state", NullHandler +) def _choices_default(): @@ -370,6 +375,12 @@ class _V: set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler set_choices_handler: Callable[[Any], Any] = _choices._set_handler + set_distributed_autotune_state: Callable[[Any], Any] = ( + _distributed_autotune_state._set_handler + ) + get_distributed_autotune_state: Callable[[], Any] = ( + _distributed_autotune_state._get_handler + ) @property def ops(self) -> OpsHandler[Any]: @@ -429,5 +440,9 @@ def local_buffer_context(self): def choices(self) -> InductorChoices: return _choices._get_handler() + @property + def distributed_autotune_state(self): + return _distributed_autotune_state._get_handler() + V = _V() From d144382dc96f109a6254c38734779e0a09fb7134 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 6 Nov 2025 21:21:40 +0000 Subject: [PATCH 152/651] Move enrich_profiler_metadata config import out of gm.recompile() 
(#167114) Fixes T243967987 Move `enrich_profiler_metadata` from `torch._dynamo.config` to `torch.fx.experimental._config`. We cannot import anything inside recompile(), it made some perf regress internally. We move the config so we can import it at the top of `graph_module.py` without causing any circular import. We also cannot delete the old config right now because some internal tests rely on copies of the old `graph_module.py` cpp file in unit tests. But I think we should be able to delete the old config soon after this PR lands. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167114 Approved by: https://github.com/angelayi --- test/test_fx.py | 6 +++--- torch/_dynamo/config.py | 7 ++----- torch/fx/experimental/_config.py | 8 +++++++- torch/fx/graph_module.py | 10 ++++++---- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 3ad21e64c8ce2..7b075c7f73381 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4251,7 +4251,7 @@ def fn(a, b, c, d): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_profiler_stack_trace_augmentation(self): """ Test that map_recorded_events_to_aten_ops_with_stack_trace correctly @@ -4307,7 +4307,7 @@ def forward(self, x): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_profiler_multiple_modules(self): """ Test that multiple compiled modules under the same profiler session @@ -4351,7 +4351,7 @@ def forward(self, x): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_profiler_nested_graph_modules(self): """ Test that nested graph modules (e.g., graph modules calling subgraphs) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 0c95408401c79..66142b196d630 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -739,11 +739,8 @@ def default_debug_dir_root() -> str: # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None -# Experimental: If True, graph module will register fx metadata during recompile() -enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] - default=False, - env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", -) +# Deprecated! Please use the config in torch/fx/experimental/_config instead. +enrich_profiler_metadata: bool = False if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/fx/experimental/_config.py b/torch/fx/experimental/_config.py index ce4296b6410c9..a537978db3834 100644 --- a/torch/fx/experimental/_config.py +++ b/torch/fx/experimental/_config.py @@ -2,6 +2,8 @@ import sys from typing import Optional +from torch.utils._config_module import Config, install_config_module + # [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors. no_data_dependent_graph_break = ( @@ -100,7 +102,11 @@ # Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. 
skip_dtype_check_in_meta_registrations = False -from torch.utils._config_module import install_config_module +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) install_config_module(sys.modules[__name__]) diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8360c96630d6c..ab33d7bf321c9 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -20,6 +20,7 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_importer from ._compatibility import compatibility +from .experimental import _config as fx_experimental_config from .graph import ( _BoxedCodeGen, _custom_builtins, @@ -858,14 +859,15 @@ def recompile(self) -> PythonCode: called after editing the contained ``graph``, otherwise the generated code of this ``GraphModule`` will be out of date. """ + # Do not import anything inside recompile, it might slow down the + # function and cause perf regression. Import outside of the method instead. if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - from torch._dynamo import config as dynamo_config - python_code = self._graph.python_code( - root_module="self", record_func=dynamo_config.enrich_profiler_metadata + root_module="self", + record_func=fx_experimental_config.enrich_profiler_metadata, ) self._code = python_code.src self._lineno_map = python_code._lineno_map @@ -874,7 +876,7 @@ def recompile(self) -> PythonCode: cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - if dynamo_config.enrich_profiler_metadata: + if fx_experimental_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation node_metadata: dict[int, dict[str, Any]] = {} for i, node in enumerate(self._graph.nodes): From c90a976370945af052bb7b0db86240fa6f321cd6 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Thu, 6 Nov 2025 21:31:54 +0000 Subject: [PATCH 153/651] Update pythoncapi_compat.h (#167138) Update to commit 44c8e14bbbb5d5135ae90957036a61397e4df577. 
Should slightly simplify https://github.com/pytorch/pytorch/pull/166342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167138 Approved by: https://github.com/albanD --- torch/csrc/utils/python_compat.h | 8 - torch/csrc/utils/pythoncapi_compat.h | 1420 +++++++++++++++++++++++++- 2 files changed, 1408 insertions(+), 20 deletions(-) diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index 16308dad4421d..8488d5d0917b5 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -33,14 +33,6 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) { #endif } -// Provided by CPython but getting the header for them is very hard -#if IS_PYTHON_3_11_PLUS -// NOLINTNEXTLINE(readability-redundant-declaration) -PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self); -#else -extern void _PyWeakref_ClearRef(PyWeakReference* self); -#endif - #ifdef __cplusplus } #endif diff --git a/torch/csrc/utils/pythoncapi_compat.h b/torch/csrc/utils/pythoncapi_compat.h index 05e80b5ee8607..bb45c18531106 100644 --- a/torch/csrc/utils/pythoncapi_compat.h +++ b/torch/csrc/utils/pythoncapi_compat.h @@ -7,7 +7,7 @@ // https://github.com/python/pythoncapi_compat // // Latest version: -// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h +// https://raw.githubusercontent.com/python/pythoncapi-compat/main/pythoncapi_compat.h // // SPDX-License-Identifier: 0BSD @@ -19,11 +19,15 @@ extern "C" { #endif #include +#include // offsetof() // Python 3.11.0b4 added PyFrame_Back() to Python.h #if PY_VERSION_HEX < 0x030b00B4 && !defined(PYPY_VERSION) # include "frameobject.h" // PyFrameObject, PyFrame_GetBack() #endif +#if PY_VERSION_HEX < 0x030C00A3 +# include // T_SHORT, READONLY +#endif #ifndef _Py_CAST @@ -33,11 +37,13 @@ extern "C" { // Static inline functions should use _Py_NULL rather than using directly NULL // to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, // _Py_NULL is defined as nullptr. -#if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ - || (defined(__cplusplus) && __cplusplus >= 201103) -# define _Py_NULL nullptr -#else -# define _Py_NULL NULL +#ifndef _Py_NULL +# if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ + || (defined(__cplusplus) && __cplusplus >= 201103) +# define _Py_NULL nullptr +# else +# define _Py_NULL NULL +# endif #endif // Cast argument to PyObject* type. @@ -45,6 +51,13 @@ extern "C" { # define _PyObject_CAST(op) _Py_CAST(PyObject*, op) #endif +#ifndef Py_BUILD_ASSERT +# define Py_BUILD_ASSERT(cond) \ + do { \ + (void)sizeof(char [1 - 2 * !(cond)]); \ + } while(0) +#endif + // bpo-42262 added Py_NewRef() to Python 3.10.0a3 #if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) @@ -68,6 +81,16 @@ static inline PyObject* _Py_XNewRef(PyObject *obj) #endif +// bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT) +static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) +{ + ob->ob_refcnt = refcnt; +} +#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt) +#endif + + // Py_SETREF() and Py_XSETREF() were added to Python 3.5.2. // It is excluded from the limited C API. 
#if (PY_VERSION_HEX < 0x03050200 && !defined(Py_SETREF)) && !defined(Py_LIMITED_API) @@ -104,6 +127,37 @@ static inline PyObject* _Py_XNewRef(PyObject *obj) # define Py_IsFalse(x) Py_Is(x, Py_False) #endif + +// bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) +static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type) +{ + ob->ob_type = type; +} +#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type) +#endif + + +// bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE) +static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) +{ + ob->ob_size = size; +} +#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size) +#endif + + +// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 || defined(PYPY_VERSION) +static inline PyCodeObject* PyFrame_GetCode(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + assert(frame->f_code != _Py_NULL); + return _Py_CAST(PyCodeObject*, Py_NewRef(frame->f_code)); +} +#endif + static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) { PyCodeObject *code = PyFrame_GetCode(frame); @@ -112,6 +166,15 @@ static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) } +// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyFrame_GetBack(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + return _Py_CAST(PyFrameObject*, Py_XNewRef(frame->f_back)); +} +#endif + #if !defined(PYPY_VERSION) static inline PyFrameObject* _PyFrame_GetBackBorrow(PyFrameObject *frame) { @@ -229,6 +292,26 @@ PyFrame_GetVarString(PyFrameObject *frame, const char *name) #endif +// bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || (defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030B0000) +static inline PyInterpreterState * +PyThreadState_GetInterpreter(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->interp; +} +#endif + + +// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyThreadState_GetFrame(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} +#endif + #if !defined(PYPY_VERSION) static inline PyFrameObject* _PyThreadState_GetFrameBorrow(PyThreadState *tstate) @@ -240,6 +323,35 @@ _PyThreadState_GetFrameBorrow(PyThreadState *tstate) #endif +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION) +static inline PyInterpreterState* PyInterpreterState_Get(void) +{ + PyThreadState *tstate; + PyInterpreterState *interp; + + tstate = PyThreadState_GET(); + if (tstate == _Py_NULL) { + Py_FatalError("GIL released (tstate is NULL)"); + } + interp = tstate->interp; + if (interp == _Py_NULL) { + Py_FatalError("no current interpreter"); + } + return interp; +} +#endif + + +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a6 +#if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline uint64_t PyThreadState_GetID(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->id; +} +#endif + // bpo-43760 added PyThreadState_EnterTracing() to Python 3.11.0a2 #if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) static inline void 
PyThreadState_EnterTracing(PyThreadState *tstate) @@ -269,6 +381,27 @@ static inline void PyThreadState_LeaveTracing(PyThreadState *tstate) #endif +// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1 +// PyObject_CallNoArgs() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallNoArgs) && PY_VERSION_HEX < 0x030900A1 +static inline PyObject* PyObject_CallNoArgs(PyObject *func) +{ + return PyObject_CallFunctionObjArgs(func, NULL); +} +#endif + + +// bpo-39245 made PyObject_CallOneArg() public (previously called +// _PyObject_CallOneArg) in Python 3.9.0a4 +// PyObject_CallOneArg() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallOneArg) && PY_VERSION_HEX < 0x030900A4 +static inline PyObject* PyObject_CallOneArg(PyObject *func, PyObject *arg) +{ + return PyObject_CallFunctionObjArgs(func, arg, NULL); +} +#endif + + // bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3 #if PY_VERSION_HEX < 0x030A00A3 static inline int @@ -294,6 +427,58 @@ PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value) #endif +// bpo-40024 added PyModule_AddType() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 +static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) +{ + const char *name, *dot; + + if (PyType_Ready(type) < 0) { + return -1; + } + + // inline _PyType_Name() + name = type->tp_name; + assert(name != _Py_NULL); + dot = strrchr(name, '.'); + if (dot != _Py_NULL) { + name = dot + 1; + } + + return PyModule_AddObjectRef(module, name, _PyObject_CAST(type)); +} +#endif + + +// bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6. +// bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2. +#if PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsTracked(PyObject* obj) +{ + return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj)); +} +#endif + +// bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6. +// bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final. +#if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsFinalized(PyObject *obj) +{ + PyGC_Head *gc = _Py_CAST(PyGC_Head*, obj) - 1; + return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED(gc)); +} +#endif + + +// bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) +static inline int _Py_IS_TYPE(PyObject *ob, PyTypeObject *type) { + return Py_TYPE(ob) == type; +} +#define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST(ob), type) +#endif + + // bpo-46906 added PyFloat_Pack2() and PyFloat_Unpack2() to Python 3.11a7. // bpo-11734 added _PyFloat_Pack2() and _PyFloat_Unpack2() to Python 3.6.0b1. 
// Python 3.11a2 moved _PyFloat_Pack2() and _PyFloat_Unpack2() to the internal @@ -401,7 +586,7 @@ static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj) return 0; } *pobj = Py_NewRef(obj); - return (*pobj != NULL); + return 1; } #endif @@ -420,6 +605,81 @@ static inline Py_ssize_t PyVectorcall_NARGS(size_t n) #endif +// gh-105922 added PyObject_Vectorcall() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 +static inline PyObject* +PyObject_Vectorcall(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames) +{ +#if PY_VERSION_HEX >= 0x030800B1 && !defined(PYPY_VERSION) + // bpo-36974 added _PyObject_Vectorcall() to Python 3.8.0b1 + return _PyObject_Vectorcall(callable, args, nargsf, kwnames); +#else + PyObject *posargs = NULL, *kwargs = NULL; + PyObject *res; + Py_ssize_t nposargs, nkwargs, i; + + if (nargsf != 0 && args == NULL) { + PyErr_BadInternalCall(); + goto error; + } + if (kwnames != NULL && !PyTuple_Check(kwnames)) { + PyErr_BadInternalCall(); + goto error; + } + + nposargs = (Py_ssize_t)PyVectorcall_NARGS(nargsf); + if (kwnames) { + nkwargs = PyTuple_GET_SIZE(kwnames); + } + else { + nkwargs = 0; + } + + posargs = PyTuple_New(nposargs); + if (posargs == NULL) { + goto error; + } + if (nposargs) { + for (i=0; i < nposargs; i++) { + PyTuple_SET_ITEM(posargs, i, Py_NewRef(*args)); + args++; + } + } + + if (nkwargs) { + kwargs = PyDict_New(); + if (kwargs == NULL) { + goto error; + } + + for (i = 0; i < nkwargs; i++) { + PyObject *key = PyTuple_GET_ITEM(kwnames, i); + PyObject *value = *args; + args++; + if (PyDict_SetItem(kwargs, key, value) < 0) { + goto error; + } + } + } + else { + kwargs = NULL; + } + + res = PyObject_Call(callable, posargs, kwargs); + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return res; + +error: + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return NULL; +#endif +} +#endif + + // gh-106521 added PyObject_GetOptionalAttr() and // PyObject_GetOptionalAttrString() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 @@ -664,7 +924,7 @@ static inline int PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg) { PyObject **dict = _PyObject_GetDictPtr(obj); - if (*dict == NULL) { + if (dict == NULL || *dict == NULL) { return -1; } Py_VISIT(*dict); @@ -675,7 +935,7 @@ static inline void PyObject_ClearManagedDict(PyObject *obj) { PyObject **dict = _PyObject_GetDictPtr(obj); - if (*dict == NULL) { + if (dict == NULL || *dict == NULL) { return; } Py_CLEAR(*dict); @@ -950,11 +1210,11 @@ static inline int PyTime_PerfCounter(PyTime_t *result) #endif // gh-111389 added hash constants to Python 3.13.0a5. These constants were -// added first as private macros to Python 3.4.0b1 and PyPy 7.3.9. +// added first as private macros to Python 3.4.0b1 and PyPy 7.3.8. 
#if (!defined(PyHASH_BITS) \ && ((!defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x030400B1) \ || (defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03070000 \ - && PYPY_VERSION_NUM >= 0x07090000))) + && PYPY_VERSION_NUM >= 0x07030800))) # define PyHASH_BITS _PyHASH_BITS # define PyHASH_MODULUS _PyHASH_MODULUS # define PyHASH_INF _PyHASH_INF @@ -1196,6 +1456,18 @@ PyUnicodeWriter_WriteUTF8(PyUnicodeWriter *writer, return res; } +static inline int +PyUnicodeWriter_WriteASCII(PyUnicodeWriter *writer, + const char *str, Py_ssize_t size) +{ + if (size < 0) { + size = (Py_ssize_t)strlen(str); + } + + return _PyUnicodeWriter_WriteASCIIString((_PyUnicodeWriter*)writer, + str, size); +} + static inline int PyUnicodeWriter_WriteWideChar(PyUnicodeWriter *writer, const wchar_t *str, Py_ssize_t size) @@ -1219,7 +1491,8 @@ PyUnicodeWriter_WriteSubstring(PyUnicodeWriter *writer, PyObject *str, Py_ssize_t start, Py_ssize_t end) { if (!PyUnicode_Check(str)) { - PyErr_Format(PyExc_TypeError, "expect str, not %T", str); + PyErr_Format(PyExc_TypeError, "expect str, not %s", + Py_TYPE(str)->tp_name); return -1; } if (start < 0 || start > end) { @@ -1266,6 +1539,1129 @@ static inline int PyLong_GetSign(PyObject *obj, int *sign) } #endif +// gh-126061 added PyLong_IsPositive/Negative/Zero() to Python in 3.14.0a2 +#if PY_VERSION_HEX < 0x030E00A2 +static inline int PyLong_IsPositive(PyObject *obj) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + return _PyLong_Sign(obj) == 1; +} + +static inline int PyLong_IsNegative(PyObject *obj) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + return _PyLong_Sign(obj) == -1; +} + +static inline int PyLong_IsZero(PyObject *obj) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + return _PyLong_Sign(obj) == 0; +} +#endif + + +// gh-124502 added PyUnicode_Equal() to Python 3.14.0a0 +#if PY_VERSION_HEX < 0x030E00A0 +static inline int PyUnicode_Equal(PyObject *str1, PyObject *str2) +{ + if (!PyUnicode_Check(str1)) { + PyErr_Format(PyExc_TypeError, "first argument must be str, not %s", + Py_TYPE(str1)->tp_name); + return -1; + } + if (!PyUnicode_Check(str2)) { + PyErr_Format(PyExc_TypeError, "second argument must be str, not %s", + Py_TYPE(str2)->tp_name); + return -1; + } + +#if PY_VERSION_HEX >= 0x030d0000 && !defined(PYPY_VERSION) + PyAPI_FUNC(int) _PyUnicode_Equal(PyObject *str1, PyObject *str2); + + return _PyUnicode_Equal(str1, str2); +#elif PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) + return _PyUnicode_EQ(str1, str2); +#elif PY_VERSION_HEX >= 0x03090000 && defined(PYPY_VERSION) + return _PyUnicode_EQ(str1, str2); +#else + return (PyUnicode_Compare(str1, str2) == 0); +#endif +} +#endif + + +// gh-121645 added PyBytes_Join() to Python 3.14.0a0 +#if PY_VERSION_HEX < 0x030E00A0 +static inline PyObject* PyBytes_Join(PyObject *sep, PyObject *iterable) +{ + return _PyBytes_Join(sep, iterable); +} +#endif + + +#if PY_VERSION_HEX < 0x030E00A0 +static inline Py_hash_t Py_HashBuffer(const void *ptr, Py_ssize_t len) +{ +#if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) + PyAPI_FUNC(Py_hash_t) _Py_HashBytes(const void *src, Py_ssize_t len); + + return _Py_HashBytes(ptr, len); +#else + Py_hash_t hash; + PyObject *bytes = PyBytes_FromStringAndSize((const char*)ptr, len); + if (bytes == NULL) { + return -1; + } + hash = PyObject_Hash(bytes); + 
Py_DECREF(bytes); + return hash; +#endif +} +#endif + + +#if PY_VERSION_HEX < 0x030E00A0 +static inline int PyIter_NextItem(PyObject *iter, PyObject **item) +{ + iternextfunc tp_iternext; + + assert(iter != NULL); + assert(item != NULL); + + tp_iternext = Py_TYPE(iter)->tp_iternext; + if (tp_iternext == NULL) { + *item = NULL; + PyErr_Format(PyExc_TypeError, "expected an iterator, got '%s'", + Py_TYPE(iter)->tp_name); + return -1; + } + + if ((*item = tp_iternext(iter))) { + return 1; + } + if (!PyErr_Occurred()) { + return 0; + } + if (PyErr_ExceptionMatches(PyExc_StopIteration)) { + PyErr_Clear(); + return 0; + } + return -1; +} +#endif + + +#if PY_VERSION_HEX < 0x030E00A0 +static inline PyObject* PyLong_FromInt32(int32_t value) +{ + Py_BUILD_ASSERT(sizeof(long) >= 4); + return PyLong_FromLong(value); +} + +static inline PyObject* PyLong_FromInt64(int64_t value) +{ + Py_BUILD_ASSERT(sizeof(long long) >= 8); + return PyLong_FromLongLong(value); +} + +static inline PyObject* PyLong_FromUInt32(uint32_t value) +{ + Py_BUILD_ASSERT(sizeof(unsigned long) >= 4); + return PyLong_FromUnsignedLong(value); +} + +static inline PyObject* PyLong_FromUInt64(uint64_t value) +{ + Py_BUILD_ASSERT(sizeof(unsigned long long) >= 8); + return PyLong_FromUnsignedLongLong(value); +} + +static inline int PyLong_AsInt32(PyObject *obj, int32_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(int) == 4); + int value = PyLong_AsInt(obj); + if (value == -1 && PyErr_Occurred()) { + return -1; + } + *pvalue = (int32_t)value; + return 0; +} + +static inline int PyLong_AsInt64(PyObject *obj, int64_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(long long) == 8); + long long value = PyLong_AsLongLong(obj); + if (value == -1 && PyErr_Occurred()) { + return -1; + } + *pvalue = (int64_t)value; + return 0; +} + +static inline int PyLong_AsUInt32(PyObject *obj, uint32_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(long) >= 4); + unsigned long value = PyLong_AsUnsignedLong(obj); + if (value == (unsigned long)-1 && PyErr_Occurred()) { + return -1; + } +#if SIZEOF_LONG > 4 + if ((unsigned long)UINT32_MAX < value) { + PyErr_SetString(PyExc_OverflowError, + "Python int too large to convert to C uint32_t"); + return -1; + } +#endif + *pvalue = (uint32_t)value; + return 0; +} + +static inline int PyLong_AsUInt64(PyObject *obj, uint64_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(long long) == 8); + unsigned long long value = PyLong_AsUnsignedLongLong(obj); + if (value == (unsigned long long)-1 && PyErr_Occurred()) { + return -1; + } + *pvalue = (uint64_t)value; + return 0; +} +#endif + + +// gh-102471 added import and export API for integers to 3.14.0a2. +#if PY_VERSION_HEX < 0x030E00A2 && PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) +// Helpers to access PyLongObject internals. +static inline void +_PyLong_SetSignAndDigitCount(PyLongObject *op, int sign, Py_ssize_t size) +{ +#if PY_VERSION_HEX >= 0x030C0000 + op->long_value.lv_tag = (uintptr_t)(1 - sign) | ((uintptr_t)(size) << 3); +#elif PY_VERSION_HEX >= 0x030900A4 + Py_SET_SIZE(op, sign * size); +#else + Py_SIZE(op) = sign * size; +#endif +} + +static inline Py_ssize_t +_PyLong_DigitCount(const PyLongObject *op) +{ +#if PY_VERSION_HEX >= 0x030C0000 + return (Py_ssize_t)(op->long_value.lv_tag >> 3); +#else + return _PyLong_Sign((PyObject*)op) < 0 ? 
-Py_SIZE(op) : Py_SIZE(op); +#endif +} + +static inline digit* +_PyLong_GetDigits(const PyLongObject *op) +{ +#if PY_VERSION_HEX >= 0x030C0000 + return (digit*)(op->long_value.ob_digit); +#else + return (digit*)(op->ob_digit); +#endif +} + +typedef struct PyLongLayout { + uint8_t bits_per_digit; + uint8_t digit_size; + int8_t digits_order; + int8_t digit_endianness; +} PyLongLayout; + +typedef struct PyLongExport { + int64_t value; + uint8_t negative; + Py_ssize_t ndigits; + const void *digits; + Py_uintptr_t _reserved; +} PyLongExport; + +typedef struct PyLongWriter PyLongWriter; + +static inline const PyLongLayout* +PyLong_GetNativeLayout(void) +{ + static const PyLongLayout PyLong_LAYOUT = { + PyLong_SHIFT, + sizeof(digit), + -1, // least significant first + PY_LITTLE_ENDIAN ? -1 : 1, + }; + + return &PyLong_LAYOUT; +} + +static inline int +PyLong_Export(PyObject *obj, PyLongExport *export_long) +{ + if (!PyLong_Check(obj)) { + memset(export_long, 0, sizeof(*export_long)); + PyErr_Format(PyExc_TypeError, "expected int, got %s", + Py_TYPE(obj)->tp_name); + return -1; + } + + // Fast-path: try to convert to a int64_t + PyLongObject *self = (PyLongObject*)obj; + int overflow; +#if SIZEOF_LONG == 8 + long value = PyLong_AsLongAndOverflow(obj, &overflow); +#else + // Windows has 32-bit long, so use 64-bit long long instead + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); +#endif + Py_BUILD_ASSERT(sizeof(value) == sizeof(int64_t)); + // the function cannot fail since obj is a PyLongObject + assert(!(value == -1 && PyErr_Occurred())); + + if (!overflow) { + export_long->value = value; + export_long->negative = 0; + export_long->ndigits = 0; + export_long->digits = 0; + export_long->_reserved = 0; + } + else { + export_long->value = 0; + export_long->negative = _PyLong_Sign(obj) < 0; + export_long->ndigits = _PyLong_DigitCount(self); + if (export_long->ndigits == 0) { + export_long->ndigits = 1; + } + export_long->digits = _PyLong_GetDigits(self); + export_long->_reserved = (Py_uintptr_t)Py_NewRef(obj); + } + return 0; +} + +static inline void +PyLong_FreeExport(PyLongExport *export_long) +{ + PyObject *obj = (PyObject*)export_long->_reserved; + + if (obj) { + export_long->_reserved = 0; + Py_DECREF(obj); + } +} + +static inline PyLongWriter* +PyLongWriter_Create(int negative, Py_ssize_t ndigits, void **digits) +{ + if (ndigits <= 0) { + PyErr_SetString(PyExc_ValueError, "ndigits must be positive"); + return NULL; + } + assert(digits != NULL); + + PyLongObject *obj = _PyLong_New(ndigits); + if (obj == NULL) { + return NULL; + } + _PyLong_SetSignAndDigitCount(obj, negative?-1:1, ndigits); + + *digits = _PyLong_GetDigits(obj); + return (PyLongWriter*)obj; +} + +static inline void +PyLongWriter_Discard(PyLongWriter *writer) +{ + PyLongObject *obj = (PyLongObject *)writer; + + assert(Py_REFCNT(obj) == 1); + Py_DECREF(obj); +} + +static inline PyObject* +PyLongWriter_Finish(PyLongWriter *writer) +{ + PyObject *obj = (PyObject *)writer; + PyLongObject *self = (PyLongObject*)obj; + Py_ssize_t j = _PyLong_DigitCount(self); + Py_ssize_t i = j; + int sign = _PyLong_Sign(obj); + + assert(Py_REFCNT(obj) == 1); + + // Normalize and get singleton if possible + while (i > 0 && _PyLong_GetDigits(self)[i-1] == 0) { + --i; + } + if (i != j) { + if (i == 0) { + sign = 0; + } + _PyLong_SetSignAndDigitCount(self, sign, i); + } + if (i <= 1) { + long val = sign * (long)(_PyLong_GetDigits(self)[0]); + Py_DECREF(obj); + return PyLong_FromLong(val); + } + + return obj; +} +#endif + + +#if 
PY_VERSION_HEX < 0x030C00A3 +# define Py_T_SHORT T_SHORT +# define Py_T_INT T_INT +# define Py_T_LONG T_LONG +# define Py_T_FLOAT T_FLOAT +# define Py_T_DOUBLE T_DOUBLE +# define Py_T_STRING T_STRING +# define _Py_T_OBJECT T_OBJECT +# define Py_T_CHAR T_CHAR +# define Py_T_BYTE T_BYTE +# define Py_T_UBYTE T_UBYTE +# define Py_T_USHORT T_USHORT +# define Py_T_UINT T_UINT +# define Py_T_ULONG T_ULONG +# define Py_T_STRING_INPLACE T_STRING_INPLACE +# define Py_T_BOOL T_BOOL +# define Py_T_OBJECT_EX T_OBJECT_EX +# define Py_T_LONGLONG T_LONGLONG +# define Py_T_ULONGLONG T_ULONGLONG +# define Py_T_PYSSIZET T_PYSSIZET + +# if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) +# define _Py_T_NONE T_NONE +# endif + +# define Py_READONLY READONLY +# define Py_AUDIT_READ READ_RESTRICTED +# define _Py_WRITE_RESTRICTED PY_WRITE_RESTRICTED +#endif + + +// gh-127350 added Py_fopen() and Py_fclose() to Python 3.14a4 +#if PY_VERSION_HEX < 0x030E00A4 +static inline FILE* Py_fopen(PyObject *path, const char *mode) +{ +#if 0x030400A2 <= PY_VERSION_HEX && !defined(PYPY_VERSION) + PyAPI_FUNC(FILE*) _Py_fopen_obj(PyObject *path, const char *mode); + + return _Py_fopen_obj(path, mode); +#else + FILE *f; + PyObject *bytes; +#if PY_VERSION_HEX >= 0x03000000 + if (!PyUnicode_FSConverter(path, &bytes)) { + return NULL; + } +#else + if (!PyString_Check(path)) { + PyErr_SetString(PyExc_TypeError, "except str"); + return NULL; + } + bytes = Py_NewRef(path); +#endif + const char *path_bytes = PyBytes_AS_STRING(bytes); + + f = fopen(path_bytes, mode); + Py_DECREF(bytes); + + if (f == NULL) { + PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, path); + return NULL; + } + return f; +#endif +} + +static inline int Py_fclose(FILE *file) +{ + return fclose(file); +} +#endif + + +#if 0x03080000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030E0000 && !defined(PYPY_VERSION) +static inline PyObject* +PyConfig_Get(const char *name) +{ + typedef enum { + _PyConfig_MEMBER_INT, + _PyConfig_MEMBER_UINT, + _PyConfig_MEMBER_ULONG, + _PyConfig_MEMBER_BOOL, + _PyConfig_MEMBER_WSTR, + _PyConfig_MEMBER_WSTR_OPT, + _PyConfig_MEMBER_WSTR_LIST, + } PyConfigMemberType; + + typedef struct { + const char *name; + size_t offset; + PyConfigMemberType type; + const char *sys_attr; + } PyConfigSpec; + +#define PYTHONCAPI_COMPAT_SPEC(MEMBER, TYPE, sys_attr) \ + {#MEMBER, offsetof(PyConfig, MEMBER), \ + _PyConfig_MEMBER_##TYPE, sys_attr} + + static const PyConfigSpec config_spec[] = { + PYTHONCAPI_COMPAT_SPEC(argv, WSTR_LIST, "argv"), + PYTHONCAPI_COMPAT_SPEC(base_exec_prefix, WSTR_OPT, "base_exec_prefix"), + PYTHONCAPI_COMPAT_SPEC(base_executable, WSTR_OPT, "_base_executable"), + PYTHONCAPI_COMPAT_SPEC(base_prefix, WSTR_OPT, "base_prefix"), + PYTHONCAPI_COMPAT_SPEC(bytes_warning, UINT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(exec_prefix, WSTR_OPT, "exec_prefix"), + PYTHONCAPI_COMPAT_SPEC(executable, WSTR_OPT, "executable"), + PYTHONCAPI_COMPAT_SPEC(inspect, BOOL, _Py_NULL), +#if 0x030C0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(int_max_str_digits, UINT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(interactive, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(module_search_paths, WSTR_LIST, "path"), + PYTHONCAPI_COMPAT_SPEC(optimization_level, UINT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(parser_debug, BOOL, _Py_NULL), +#if 0x03090000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(platlibdir, WSTR, "platlibdir"), +#endif + PYTHONCAPI_COMPAT_SPEC(prefix, WSTR_OPT, "prefix"), + PYTHONCAPI_COMPAT_SPEC(pycache_prefix, WSTR_OPT, "pycache_prefix"), + 
PYTHONCAPI_COMPAT_SPEC(quiet, BOOL, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(stdlib_dir, WSTR_OPT, "_stdlib_dir"), +#endif + PYTHONCAPI_COMPAT_SPEC(use_environment, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(verbose, UINT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(warnoptions, WSTR_LIST, "warnoptions"), + PYTHONCAPI_COMPAT_SPEC(write_bytecode, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(xoptions, WSTR_LIST, "_xoptions"), + PYTHONCAPI_COMPAT_SPEC(buffered_stdio, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(check_hash_pycs_mode, WSTR, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(code_debug_ranges, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(configure_c_stdio, BOOL, _Py_NULL), +#if 0x030D0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(cpu_count, INT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(dev_mode, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(dump_refs, BOOL, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(dump_refs_file, WSTR_OPT, _Py_NULL), +#endif +#ifdef Py_GIL_DISABLED + PYTHONCAPI_COMPAT_SPEC(enable_gil, INT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(faulthandler, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(filesystem_encoding, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(filesystem_errors, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(hash_seed, ULONG, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(home, WSTR_OPT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(import_time, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(install_signal_handlers, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(isolated, BOOL, _Py_NULL), +#ifdef MS_WINDOWS + PYTHONCAPI_COMPAT_SPEC(legacy_windows_stdio, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(malloc_stats, BOOL, _Py_NULL), +#if 0x030A0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(orig_argv, WSTR_LIST, "orig_argv"), +#endif + PYTHONCAPI_COMPAT_SPEC(parse_argv, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(pathconfig_warnings, BOOL, _Py_NULL), +#if 0x030C0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(perf_profiling, UINT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(program_name, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(run_command, WSTR_OPT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(run_filename, WSTR_OPT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(run_module, WSTR_OPT, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(safe_path, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(show_ref_count, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(site_import, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(skip_source_first_line, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(stdio_encoding, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(stdio_errors, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(tracemalloc, UINT, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(use_frozen_modules, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(use_hash_seed, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(user_site_directory, BOOL, _Py_NULL), +#if 0x030A0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(warn_default_encoding, BOOL, _Py_NULL), +#endif + }; + +#undef PYTHONCAPI_COMPAT_SPEC + + const PyConfigSpec *spec; + int found = 0; + for (size_t i=0; i < sizeof(config_spec) / sizeof(config_spec[0]); i++) { + spec = &config_spec[i]; + if (strcmp(spec->name, name) == 0) { + found = 1; + break; + } + } + if (found) { + if (spec->sys_attr != NULL) { + PyObject *value = PySys_GetObject(spec->sys_attr); + if (value == NULL) { + PyErr_Format(PyExc_RuntimeError, "lost sys.%s", spec->sys_attr); + return NULL; + } + return Py_NewRef(value); + } + + 
PyAPI_FUNC(const PyConfig*) _Py_GetConfig(void); + + const PyConfig *config = _Py_GetConfig(); + void *member = (char *)config + spec->offset; + switch (spec->type) { + case _PyConfig_MEMBER_INT: + case _PyConfig_MEMBER_UINT: + { + int value = *(int *)member; + return PyLong_FromLong(value); + } + case _PyConfig_MEMBER_BOOL: + { + int value = *(int *)member; + return PyBool_FromLong(value != 0); + } + case _PyConfig_MEMBER_ULONG: + { + unsigned long value = *(unsigned long *)member; + return PyLong_FromUnsignedLong(value); + } + case _PyConfig_MEMBER_WSTR: + case _PyConfig_MEMBER_WSTR_OPT: + { + wchar_t *wstr = *(wchar_t **)member; + if (wstr != NULL) { + return PyUnicode_FromWideChar(wstr, -1); + } + else { + return Py_NewRef(Py_None); + } + } + case _PyConfig_MEMBER_WSTR_LIST: + { + const PyWideStringList *list = (const PyWideStringList *)member; + PyObject *tuple = PyTuple_New(list->length); + if (tuple == NULL) { + return NULL; + } + + for (Py_ssize_t i = 0; i < list->length; i++) { + PyObject *item = PyUnicode_FromWideChar(list->items[i], -1); + if (item == NULL) { + Py_DECREF(tuple); + return NULL; + } + PyTuple_SET_ITEM(tuple, i, item); + } + return tuple; + } + default: + Py_UNREACHABLE(); + } + } + + PyErr_Format(PyExc_ValueError, "unknown config option name: %s", name); + return NULL; +} + +static inline int +PyConfig_GetInt(const char *name, int *value) +{ + PyObject *obj = PyConfig_Get(name); + if (obj == NULL) { + return -1; + } + + if (!PyLong_Check(obj)) { + Py_DECREF(obj); + PyErr_Format(PyExc_TypeError, "config option %s is not an int", name); + return -1; + } + + int as_int = PyLong_AsInt(obj); + Py_DECREF(obj); + if (as_int == -1 && PyErr_Occurred()) { + PyErr_Format(PyExc_OverflowError, + "config option %s value does not fit into a C int", name); + return -1; + } + + *value = as_int; + return 0; +} +#endif // PY_VERSION_HEX > 0x03090000 && !defined(PYPY_VERSION) + +// gh-133144 added PyUnstable_Object_IsUniquelyReferenced() to Python 3.14.0b1. +// Adapted from _PyObject_IsUniquelyReferenced() implementation. +#if PY_VERSION_HEX < 0x030E00B0 +static inline int PyUnstable_Object_IsUniquelyReferenced(PyObject *obj) +{ +#if !defined(Py_GIL_DISABLED) + return Py_REFCNT(obj) == 1; +#else + // NOTE: the entire ob_ref_shared field must be zero, including flags, to + // ensure that other threads cannot concurrently create new references to + // this object. + return (_Py_IsOwnedByCurrentThread(obj) && + _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local) == 1 && + _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared) == 0); +#endif +} +#endif + +// gh-128926 added PyUnstable_TryIncRef() and PyUnstable_EnableTryIncRef() to +// Python 3.14.0a5. Adapted from _Py_TryIncref() and _PyObject_SetMaybeWeakref(). 
+#if PY_VERSION_HEX < 0x030E00A5 +static inline int PyUnstable_TryIncRef(PyObject *op) +{ +#ifndef Py_GIL_DISABLED + if (Py_REFCNT(op) > 0) { + Py_INCREF(op); + return 1; + } + return 0; +#else + // _Py_TryIncrefFast() + uint32_t local = _Py_atomic_load_uint32_relaxed(&op->ob_ref_local); + local += 1; + if (local == 0) { + // immortal + return 1; + } + if (_Py_IsOwnedByCurrentThread(op)) { + _Py_INCREF_STAT_INC(); + _Py_atomic_store_uint32_relaxed(&op->ob_ref_local, local); +#ifdef Py_REF_DEBUG + _Py_INCREF_IncRefTotal(); +#endif + return 1; + } + + // _Py_TryIncRefShared() + Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared); + for (;;) { + // If the shared refcount is zero and the object is either merged + // or may not have weak references, then we cannot incref it. + if (shared == 0 || shared == _Py_REF_MERGED) { + return 0; + } + + if (_Py_atomic_compare_exchange_ssize( + &op->ob_ref_shared, + &shared, + shared + (1 << _Py_REF_SHARED_SHIFT))) { +#ifdef Py_REF_DEBUG + _Py_INCREF_IncRefTotal(); +#endif + _Py_INCREF_STAT_INC(); + return 1; + } + } +#endif +} + +static inline void PyUnstable_EnableTryIncRef(PyObject *op) +{ +#ifdef Py_GIL_DISABLED + // _PyObject_SetMaybeWeakref() + if (_Py_IsImmortal(op)) { + return; + } + for (;;) { + Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared); + if ((shared & _Py_REF_SHARED_FLAG_MASK) != 0) { + // Nothing to do if it's in WEAKREFS, QUEUED, or MERGED states. + return; + } + if (_Py_atomic_compare_exchange_ssize( + &op->ob_ref_shared, &shared, shared | _Py_REF_MAYBE_WEAKREF)) { + return; + } + } +#else + (void)op; // unused argument +#endif +} +#endif + + +#if PY_VERSION_HEX < 0x030F0000 +static inline PyObject* +PySys_GetAttrString(const char *name) +{ +#if PY_VERSION_HEX >= 0x03000000 + PyObject *value = Py_XNewRef(PySys_GetObject(name)); +#else + PyObject *value = Py_XNewRef(PySys_GetObject((char*)name)); +#endif + if (value != NULL) { + return value; + } + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, "lost sys.%s", name); + } + return NULL; +} + +static inline PyObject* +PySys_GetAttr(PyObject *name) +{ +#if PY_VERSION_HEX >= 0x03000000 + const char *name_str = PyUnicode_AsUTF8(name); +#else + const char *name_str = PyString_AsString(name); +#endif + if (name_str == NULL) { + return NULL; + } + + return PySys_GetAttrString(name_str); +} + +static inline int +PySys_GetOptionalAttrString(const char *name, PyObject **value) +{ +#if PY_VERSION_HEX >= 0x03000000 + *value = Py_XNewRef(PySys_GetObject(name)); +#else + *value = Py_XNewRef(PySys_GetObject((char*)name)); +#endif + if (*value != NULL) { + return 1; + } + return 0; +} + +static inline int +PySys_GetOptionalAttr(PyObject *name, PyObject **value) +{ +#if PY_VERSION_HEX >= 0x03000000 + const char *name_str = PyUnicode_AsUTF8(name); +#else + const char *name_str = PyString_AsString(name); +#endif + if (name_str == NULL) { + *value = NULL; + return -1; + } + + return PySys_GetOptionalAttrString(name_str, value); +} +#endif // PY_VERSION_HEX < 0x030F00A1 + + +#if PY_VERSION_HEX < 0x030F00A1 +typedef struct PyBytesWriter { + char small_buffer[256]; + PyObject *obj; + Py_ssize_t size; +} PyBytesWriter; + +static inline Py_ssize_t +_PyBytesWriter_GetAllocated(PyBytesWriter *writer) +{ + if (writer->obj == NULL) { + return sizeof(writer->small_buffer); + } + else { + return PyBytes_GET_SIZE(writer->obj); + } +} + + +static inline int +_PyBytesWriter_Resize_impl(PyBytesWriter *writer, Py_ssize_t size, + int resize) +{ + int overallocate = resize; 
+ assert(size >= 0); + + if (size <= _PyBytesWriter_GetAllocated(writer)) { + return 0; + } + + if (overallocate) { +#ifdef MS_WINDOWS + /* On Windows, overallocate by 50% is the best factor */ + if (size <= (PY_SSIZE_T_MAX - size / 2)) { + size += size / 2; + } +#else + /* On Linux, overallocate by 25% is the best factor */ + if (size <= (PY_SSIZE_T_MAX - size / 4)) { + size += size / 4; + } +#endif + } + + if (writer->obj != NULL) { + if (_PyBytes_Resize(&writer->obj, size)) { + return -1; + } + assert(writer->obj != NULL); + } + else { + writer->obj = PyBytes_FromStringAndSize(NULL, size); + if (writer->obj == NULL) { + return -1; + } + + if (resize) { + assert((size_t)size > sizeof(writer->small_buffer)); + memcpy(PyBytes_AS_STRING(writer->obj), + writer->small_buffer, + sizeof(writer->small_buffer)); + } + } + return 0; +} + +static inline void* +PyBytesWriter_GetData(PyBytesWriter *writer) +{ + if (writer->obj == NULL) { + return writer->small_buffer; + } + else { + return PyBytes_AS_STRING(writer->obj); + } +} + +static inline Py_ssize_t +PyBytesWriter_GetSize(PyBytesWriter *writer) +{ + return writer->size; +} + +static inline void +PyBytesWriter_Discard(PyBytesWriter *writer) +{ + if (writer == NULL) { + return; + } + + Py_XDECREF(writer->obj); + PyMem_Free(writer); +} + +static inline PyBytesWriter* +PyBytesWriter_Create(Py_ssize_t size) +{ + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "size must be >= 0"); + return NULL; + } + + PyBytesWriter *writer = (PyBytesWriter*)PyMem_Malloc(sizeof(PyBytesWriter)); + if (writer == NULL) { + PyErr_NoMemory(); + return NULL; + } + + writer->obj = NULL; + writer->size = 0; + + if (size >= 1) { + if (_PyBytesWriter_Resize_impl(writer, size, 0) < 0) { + PyBytesWriter_Discard(writer); + return NULL; + } + writer->size = size; + } + return writer; +} + +static inline PyObject* +PyBytesWriter_FinishWithSize(PyBytesWriter *writer, Py_ssize_t size) +{ + PyObject *result; + if (size == 0) { + result = PyBytes_FromStringAndSize("", 0); + } + else if (writer->obj != NULL) { + if (size != PyBytes_GET_SIZE(writer->obj)) { + if (_PyBytes_Resize(&writer->obj, size)) { + goto error; + } + } + result = writer->obj; + writer->obj = NULL; + } + else { + result = PyBytes_FromStringAndSize(writer->small_buffer, size); + } + PyBytesWriter_Discard(writer); + return result; + +error: + PyBytesWriter_Discard(writer); + return NULL; +} + +static inline PyObject* +PyBytesWriter_Finish(PyBytesWriter *writer) +{ + return PyBytesWriter_FinishWithSize(writer, writer->size); +} + +static inline PyObject* +PyBytesWriter_FinishWithPointer(PyBytesWriter *writer, void *buf) +{ + Py_ssize_t size = (char*)buf - (char*)PyBytesWriter_GetData(writer); + if (size < 0 || size > _PyBytesWriter_GetAllocated(writer)) { + PyBytesWriter_Discard(writer); + PyErr_SetString(PyExc_ValueError, "invalid end pointer"); + return NULL; + } + + return PyBytesWriter_FinishWithSize(writer, size); +} + +static inline int +PyBytesWriter_Resize(PyBytesWriter *writer, Py_ssize_t size) +{ + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "size must be >= 0"); + return -1; + } + if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { + return -1; + } + writer->size = size; + return 0; +} + +static inline int +PyBytesWriter_Grow(PyBytesWriter *writer, Py_ssize_t size) +{ + if (size < 0 && writer->size + size < 0) { + PyErr_SetString(PyExc_ValueError, "invalid size"); + return -1; + } + if (size > PY_SSIZE_T_MAX - writer->size) { + PyErr_NoMemory(); + return -1; + } + size = writer->size + size; + 
+ if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { + return -1; + } + writer->size = size; + return 0; +} + +static inline void* +PyBytesWriter_GrowAndUpdatePointer(PyBytesWriter *writer, + Py_ssize_t size, void *buf) +{ + Py_ssize_t pos = (char*)buf - (char*)PyBytesWriter_GetData(writer); + if (PyBytesWriter_Grow(writer, size) < 0) { + return NULL; + } + return (char*)PyBytesWriter_GetData(writer) + pos; +} + +static inline int +PyBytesWriter_WriteBytes(PyBytesWriter *writer, + const void *bytes, Py_ssize_t size) +{ + if (size < 0) { + size_t len = strlen((const char*)bytes); + if (len > (size_t)PY_SSIZE_T_MAX) { + PyErr_NoMemory(); + return -1; + } + size = (Py_ssize_t)len; + } + + Py_ssize_t pos = writer->size; + if (PyBytesWriter_Grow(writer, size) < 0) { + return -1; + } + char *buf = (char*)PyBytesWriter_GetData(writer); + memcpy(buf + pos, bytes, (size_t)size); + return 0; +} + +static inline int +PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) + Py_GCC_ATTRIBUTE((format(printf, 2, 3))); + +static inline int +PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) +{ + va_list vargs; + va_start(vargs, format); + PyObject *str = PyBytes_FromFormatV(format, vargs); + va_end(vargs); + + if (str == NULL) { + return -1; + } + int res = PyBytesWriter_WriteBytes(writer, + PyBytes_AS_STRING(str), + PyBytes_GET_SIZE(str)); + Py_DECREF(str); + return res; +} +#endif // PY_VERSION_HEX < 0x030F00A1 + + +#if PY_VERSION_HEX < 0x030F00A1 +static inline PyObject* +PyTuple_FromArray(PyObject *const *array, Py_ssize_t size) +{ + PyObject *tuple = PyTuple_New(size); + if (tuple == NULL) { + return NULL; + } + for (Py_ssize_t i=0; i < size; i++) { + PyObject *item = array[i]; + PyTuple_SET_ITEM(tuple, i, Py_NewRef(item)); + } + return tuple; +} +#endif + + +#if PY_VERSION_HEX < 0x030F00A1 +static inline Py_hash_t +PyUnstable_Unicode_GET_CACHED_HASH(PyObject *op) +{ +#ifdef PYPY_VERSION + (void)op; // unused argument + return -1; +#elif PY_VERSION_HEX >= 0x03000000 + return ((PyASCIIObject*)op)->hash; +#else + return ((PyUnicodeObject*)op)->hash; +#endif +} +#endif + #ifdef __cplusplus } From c5593e75b31286dc5358a451ae6e8c61e47921d6 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 6 Nov 2025 21:39:44 +0000 Subject: [PATCH 154/651] Fix flaky memory profiler test (#167168) Fixes #167037 Do not check the exact number of frames. 
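The relaxed check, roughly (a minimal sketch based on the diff below, using the test's existing `collect_frames` helper; frame counts vary across backends and compile configurations, so only a lower bound is asserted rather than an exact value):

```python
fx_frames = self.collect_frames(augmented_snapshot)
# Exact frame counts depend on the backend/compile stack, so avoid strict equality.
self.assertGreater(len(fx_frames), 2)
```
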
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167168 Approved by: https://github.com/angelayi --- test/test_cuda.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index dfbcdc1b40401..1d3ff12b4b6ea 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -7508,6 +7508,8 @@ def forward(self, x): device = "cuda" mod = MLPModule(device) with tempfile.TemporaryDirectory() as tmpdir: + # reset cache to start fresh + torch.cuda.memory.empty_cache() torch.cuda.memory._record_memory_history() compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) result = compiled(torch.randn(10, 10, device=device)) @@ -7518,10 +7520,7 @@ def forward(self, x): torch.cuda.empty_cache() fx_frames = self.collect_frames(augmented_snapshot) - if TEST_WITH_ROCM: - self.assertGreater(len(fx_frames), 0) - else: - self.assertEqual(len(fx_frames), 12) + self.assertGreater(len(fx_frames), 2) for frame in fx_frames: # Every FX frame should have both node_op and node_name From a45a17f65ed21232bc702e59c66fcad6be69ff73 Mon Sep 17 00:00:00 2001 From: Yun Wu Date: Thu, 6 Nov 2025 22:33:22 +0000 Subject: [PATCH 155/651] Fix boxcox to return same result for same input in one batch (#166986) Summary: The SIMD path is using SLEEF version of pow which is slightly different from std::pow. The fix is to use the same vectorized code (with partial load and store) for the trailing data as well to ensure consistency between results. Deploy: Need to make a hotfix in waas to monitor release signals, since this diff can cause testing failures in veloski and waas release correctness tests. Test Plan: Sandcastle. Differential Revision: D86218207 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166986 Approved by: https://github.com/swolchok --- caffe2/perfkernels/batch_box_cox_vec.h | 34 +++++++++++++++++--------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/caffe2/perfkernels/batch_box_cox_vec.h b/caffe2/perfkernels/batch_box_cox_vec.h index ed2e83062d107..08e4f84fe4327 100644 --- a/caffe2/perfkernels/batch_box_cox_vec.h +++ b/caffe2/perfkernels/batch_box_cox_vec.h @@ -73,6 +73,19 @@ void box_cox_zero_lambda( } } +template +at::vec::Vectorized box_cox_nonzero_lambda_impl( + at::vec::Vectorized data, + at::vec::Vectorized lambda1, + at::vec::Vectorized lambda2, + at::vec::Vectorized k_eps) { + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); + auto pow = max.pow(lambda1); + return at::vec::fmsub(pow, lambda_over_1, lambda_over_1); +} + template void box_cox_nonzero_lambda( int64_t D, @@ -88,21 +101,18 @@ void box_cox_nonzero_lambda( auto k_eps_vec = Vec(k_eps); for(; j + VLEN < D; j += VLEN) { auto data = Vec::loadu(data_ptr + j); - auto lambda2 = Vec::loadu(lambda2_ptr + j); - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps_vec); auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); - auto pow = max.pow(lambda1); - auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); + auto lambda2 = Vec::loadu(lambda2_ptr + j); + auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec); res.store(out + j); } - for ( ;j < D; ++j) { - auto sum = data_ptr[j] + lambda2_ptr[j]; - auto max = std::max(sum, k_eps); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); - auto pow = std::pow(max, lambda1_ptr[j]); - out[j] = pow * lambda_over_1 - lambda_over_1; + if (j < 
D) { + auto remaining = D - j; + auto data = Vec::loadu(data_ptr + j, remaining); + auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining); + auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining); + auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec); + res.store(out + j, remaining); } } #else From 9b4ac45d2fd7e4bde216efa18d1e8e4bce33e4ae Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Nov 2025 22:34:46 +0000 Subject: [PATCH 156/651] Revert "[Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165)" This reverts commit eefa16342c9f322b56c7c0cd6d309c3ed8f0b882. Reverted https://github.com/pytorch/pytorch/pull/166165 on behalf of https://github.com/jeanschmidt due to Breaking internal tests D86216934 ([comment](https://github.com/pytorch/pytorch/pull/166165#issuecomment-3499645688)) --- test/inductor/test_padding.py | 7 +-- test/inductor/test_torchinductor.py | 4 +- torch/_inductor/fx_passes/post_grad.py | 8 ++-- torch/_inductor/utils.py | 64 -------------------------- 4 files changed, 6 insertions(+), 77 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 5e599110d29d6..c67bde87a369b 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -500,13 +500,8 @@ def test_LinearAndSoftmax_codegen(self, bias=True): forward_wrapper = wrapper_codes[0] # make sure the load for softmax is aligned - if bias: - # addmm -> mm + bias and bias is fused with softmax - softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)" - else: - softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)" self.assertTrue( - softmax_load_str in forward_wrapper, + "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper, f"forward_wrapper: {forward_wrapper}", ) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d0ff5799ac417..fe9fa5a5e3a4c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15310,7 +15310,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_addmm_native_layer_norm", + "triton_poi_fused_native_layer_norm_relu", (torch.randn(4, 4, device=GPU_TYPE),), ), ] @@ -15323,7 +15323,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_LayerNorm_Linear_ReLU", + "triton_poi_fused_LayerNorm_ReLU", (torch.randn(4, 4, device=GPU_TYPE),), ), ] diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 91b4e10bf7238..9808c6944e13c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -52,8 +52,8 @@ decode_device, get_all_devices, get_gpu_type, - has_uses_tagged_as, is_gpu, + is_pointwise_use, OPTIMUS_EXCLUDE_POST_GRAD, ) from ..virtualized import V @@ -1511,10 +1511,8 @@ def should_prefer_unfused_addmm(match): if not is_gpu(inp.meta["val"].device.type): return False - return has_uses_tagged_as( - match.output_node(), - (torch.Tag.pointwise, torch.Tag.reduction), - ) + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) @register_graph_pattern( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index cd0f3643d37f7..1a43e938d7146 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -549,70 +549,6 @@ def is_pointwise_use( return torch.Tag.pointwise in target.tags or is_pointwise_fn(target) -class LogicalConnective(enum.Enum): - OR = enum.auto() - AND = enum.auto() - - -def has_uses( - target: Node, - use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda 
_: False, - use_aggregate_type: LogicalConnective = LogicalConnective.OR, -) -> bool: - """ - Given a target, explore the uses of `target` by applying `use_selector_fn` - on them, and then aggregate these booleans with the `use_aggregate_type` - logical connective. - - Uses in view ops will follow the views uses. - """ - - def get_use_aggregate_fn( - use_aggregate_type: LogicalConnective, - ) -> Callable[[Iterator[Any]], bool]: - match use_aggregate_type: - case LogicalConnective.AND: - return all - case LogicalConnective.OR: - return any - case _: - return any - - use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type) - - def has_uses_impl(use: Node) -> bool: - if use.op != "call_function": - return False - if not ( - isinstance(use.target, torch._ops.OpOverload) - or use.target is operator.getitem - ): - return False - - target = cast(torch._ops.OpOverload, use.target) - # Process getitem and view - if target is operator.getitem or is_view(target): - return use_aggregate_fn(has_uses_impl(user) for user in use.users) - - return use_selector_fn(target) - - return use_aggregate_fn(has_uses_impl(user) for user in target.users) - - -def has_uses_tagged_as( - target: Node, - use_tags: Collection[torch.Tag], - use_aggregate_type: LogicalConnective = LogicalConnective.OR, -) -> bool: - """ - Is there a use with given tags? - """ - - return has_uses( - target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type - ) - - def gen_gm_and_inputs( target: Any, args: list[Any], kwargs: dict[str, Any] ) -> tuple[GraphModule, list[torch.Tensor]]: From 2073af579035ec97ad6811597accd24f13a9a7a7 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 23:08:08 -0800 Subject: [PATCH 157/651] [user-streams] Refactor user object index in streams (#167175) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167175 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_streams.py | 12 ++++++------ torch/_dynamo/variables/builder.py | 12 ++++++------ torch/_dynamo/variables/streams.py | 4 ++-- torch/_dynamo/variables/user_defined.py | 1 - 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 1b81597977d77..51e5ed7747504 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -74,13 +74,13 @@ def fn(x, y, s1, s2): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': None} + # Annotation: {'stream': 0} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': None} + # Annotation: {'stream': 1} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - # Annotation: {'stream': None} + # Annotation: {'stream': 1} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None return (add_3,) @@ -229,13 +229,13 @@ def fn(x, y, s0, s1, s2): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': None} + # Annotation: {'stream': 1} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': None} + # Annotation: {'stream': 2} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - # Annotation: {'stream': None} + # Annotation: {'stream': 1} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None return (add_1, add_2) """, diff --git 
a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 81baaa236b0a8..e436a07bd0dcb 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1061,9 +1061,7 @@ def build_key_value(i, k, v): ) set_example_value(stream_proxy.node, value) var = StreamVariable( - stream_proxy, - value, - source=self.source, + stream_proxy, value, source=self.source, user_object_index=index ) return self.tx.output.side_effects.track_object_existing(value, var) elif isinstance(value, (torch._C._SDPAParams)): @@ -3006,14 +3004,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe return SymNodeVariable(proxy, example_value, **options) elif ( isinstance(example_value, torch.Stream) - and proxy.node.target - in (get_external_object_by_index, torch.accelerator.current_stream) + and proxy.node.target == get_external_object_by_index ) or proxy.node.target in [ device_interface.current_stream for _, device_interface in get_registered_device_interfaces() ]: set_example_value(proxy.node, example_value) - return StreamVariable(proxy, example_value, **options) + index = None + if proxy.node.target == get_external_object_by_index: + index = proxy.node.args[0] + return StreamVariable(proxy, example_value, index, **options) elif ( inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, torch.Event) diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 65b4add4232f6..6aa6e43a2a00e 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -204,11 +204,11 @@ def __init__( self, proxy: Proxy, value: torch.Stream, + user_object_index: Optional[int] = None, **kwargs: Any, ) -> None: # Index into the user object table # used to pass arbitrary objects to the graph - user_object_index = kwargs.pop("user_obj_index", None) if proxy is not None and "example_value" in proxy.node.meta: assert proxy.node.meta["example_value"] == value @@ -300,7 +300,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.append_output(codegen.create_load_const(self.user_object_index)) codegen.extend_output(create_call_function(1, False)) else: - # TODO mlazos: evaluate if we still need this + # This will support the legacy behavior prefix = f"_stream_{self.device}" name = codegen.tx.output.install_global_by_id(prefix, self.value) codegen.append_output(codegen.create_load_global(name, add=True)) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 9dd154dacbb9e..cea5be48a6b30 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -838,7 +838,6 @@ def deque_signature(iterable=None, maxlen=None): proxy=tx.output.create_proxy( "call_function", get_external_object_by_index, (ind,), {} ), - user_obj_index=ind, ) else: tensor_variable = wrap_fx_proxy( From 0b0610941222ea5449e20d7d2507422beac64baa Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 23:08:08 -0800 Subject: [PATCH 158/651] [user-streams] Fix bug in object bytecode construction (#167176) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167176 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167175 --- torch/_dynamo/output_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 50a2667c12a25..f393b4a269d89 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1542,7 
+1542,7 @@ def compile_subgraph( ) ) tmp_vars = [] - for constructor in reversed(index_to_bytecode_constructor.values()): + for constructor in index_to_bytecode_constructor.values(): constructor(codegen) var_name = ( self.new_var() From 106d34c80a84b903f151b180ed771bd26ec50679 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 23:08:09 -0800 Subject: [PATCH 159/651] [user-streams] add requires cuda decorator (#167180) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167180 Approved by: https://github.com/donigian, https://github.com/Lucaskabela, https://github.com/Skylion007 ghstack dependencies: #167175, #167176 --- test/dynamo/test_streams.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 51e5ed7747504..105d195e6ac3e 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -380,6 +380,7 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): """, ) + @requires_cuda def test_stream_backward(self) -> None: def fn(x, y): s2 = torch.Stream() From 4b9ba0fb261f459789ca680ac0d7338b1072dde5 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 23:08:09 -0800 Subject: [PATCH 160/651] [user-streams] Add requires cuda to all test cases (#167195) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167195 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167175, #167176, #167180 --- test/dynamo/test_streams.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 105d195e6ac3e..b736a5750e3a6 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -196,6 +196,7 @@ def fn(x, s0, s1): s_exp = fn(*inp) self.assertEqual(s_act, s_exp) + @requires_cuda def test_nested_stream_enter_exit(self): def fn(x, y, s0, s1, s2): with s1: @@ -249,6 +250,7 @@ def test_stream_enter_exit_graph_break(self): def test_nested_stream_enter_exit_graph_break(self): pass + @requires_cuda def test_local_stream_enter_exit(self): def fn(x, y): s2 = torch.Stream() @@ -289,6 +291,7 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): """, ) + @requires_cuda def test_local_stream_nested_enter_exit(self): def fn(x, y): s2 = torch.Stream() @@ -331,6 +334,7 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): """, ) + @requires_cuda def test_stream_with_mutation(self): def fn(x, y): s2 = torch.Stream() From 2923b02c6ed0e5fbab0d98728a2c65d66c420cd9 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 5 Nov 2025 15:02:38 -0800 Subject: [PATCH 161/651] [DTensor] add explicit mode (ExplicitRedistributionContext) (#166593) usage: ``` dx = distribute_tensor(x, device_mesh, [Shard(0)]) dA = distribute_tensor(A, device_mesh, [Shard(0)]) with ExplicitRedistributionContext(): with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"): # Shard(0) @ Shard(0) requires a redistribution torch.matmul(dx, dA) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166593 Approved by: https://github.com/ezyang --- test/distributed/tensor/test_utils.py | 98 ++++++++++++++++++++++++++- torch/distributed/tensor/_dispatch.py | 9 ++- torch/distributed/tensor/_utils.py | 27 ++++++++ 3 files changed, 132 insertions(+), 2 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 01f150f090b73..09a6ca817a75b 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -1,11 +1,18 @@ 
# Owner(s): ["oncall: distributed"] import itertools +from contextlib import nullcontext from typing import Any import torch +import torch.distributed as dist +from torch.distributed._local_tensor import ( + local_tensor_mode, + LocalTensor, + LocalTensorMode, +) from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor import distribute_tensor, DTensor +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor._utils import ( _compute_local_shape_and_global_offset, @@ -14,6 +21,7 @@ compute_global_tensor_shape, compute_local_shape_and_global_offset, compute_local_tensor_info, + ExplicitRedistributionContext, ) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import ( @@ -851,5 +859,93 @@ def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self): self.assertEqual(global_tensor, dtensor_2d.full_tensor()) +class LocalTensorTestBase(TestCase): + def assertEqual(self, lhs, rhs, **kwargs): + mode = local_tensor_mode() + with nullcontext() if mode is None else mode.disable(): + if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor): + assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor) + super().assertEqual(lhs._ranks, rhs._ranks) + for r in lhs._ranks: + super().assertEqual( + lhs._local_tensors[r], + rhs._local_tensors[r], + lambda m: f"rank {r}: {m}", + ) + elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor): + lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs) + for r in lhs._ranks: + super().assertEqual( + lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}" + ) + else: + return super().assertEqual(lhs, rhs, **kwargs) + + @property + def world_size(self): + raise NotImplementedError("override world-size in your subclass") + + def build_device_mesh(self) -> DeviceMesh: + return init_device_mesh("cpu", (self.world_size,)) + + def setUp(self): + super().setUp() + torch.distributed.init_process_group( + # TODO: test other ranks too + "fake", + rank=0, + world_size=self.world_size, + ) + + def tearDown(self): + super().tearDown() + try: + dist.destroy_process_group() + except AssertionError: + pass + + +class TestExplicitRedistribute(LocalTensorTestBase): + @property + def world_size(self): + return 4 + + def test_explicit_matmul(self): + with LocalTensorMode(self.world_size): + device_mesh = self.build_device_mesh() + dim = 128 + x = torch.randn(8, dim, requires_grad=True) + A = torch.randn(dim, dim, requires_grad=True) + + # Prepare DTensors + dx = distribute_tensor(x, device_mesh, [Shard(0)]) + dA = distribute_tensor(A, device_mesh, [Shard(0)]) + + # implicit redistribute works as usual by default + with CommDebugMode() as comm_mode: + torch.matmul(dx, dA) + self.assertEqual(comm_mode.get_total_counts(), 1) + + # explicit redistribute works too + with ExplicitRedistributionContext(): + with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"): + torch.matmul(dx, dA) + + # explicit redistribute allows manual redistribute + with ExplicitRedistributionContext(): + dA_repl = dA.redistribute(device_mesh, [Replicate()]) + torch.matmul(dx, dA_repl) + + dx = distribute_tensor(x, device_mesh, [Shard(0)]) + dA = distribute_tensor(A, device_mesh, [Replicate()]) + with ExplicitRedistributionContext(): + dY = torch.matmul(dx, dA_repl) + loss = dY.sum() + + # we now see the error during backwards + with 
self.assertRaisesRegex(RuntimeError, "Implicit redistribution"): + loss.backward() + + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 1800edbfdb344..27c9dd550d726 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -20,7 +20,10 @@ convolution_backward_handler, convolution_handler, ) -from torch.distributed.tensor._utils import try_find_mesh_from_args +from torch.distributed.tensor._utils import ( + ExplicitRedistributionContext, + try_find_mesh_from_args, +) from torch.distributed.tensor.placement_types import Partial, Placement, Replicate from torch.utils._debug_mode import get_active_debug_mode from torch.utils._python_dispatch import return_and_correct_aliasing @@ -199,6 +202,10 @@ def dispatch( if participating: # computation that happens in the current rank of the mesh, normal case if output_sharding.needs_redistribute: + if ExplicitRedistributionContext.is_active(): + raise RuntimeError( + f"Implicit redistribution occurred while ExplicitRedistributionContext was active for {op_info.schema}" + ) # If sharding propagation decision needs redistribute, perform redistribute # on args first, which could potentially modify args (i.e. allgather certain arg) assert output_sharding.redistribute_schema is not None diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index d192ddf7c35b3..66040e8f24a2e 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -18,6 +18,33 @@ from torch.utils._typing_utils import not_none +class ExplicitRedistributionContext: + """ + Within this context manager, DTensor will refuse to perform implicit redistribution, + instead raising an error. Manual calls to ``redistribute()`` are required wherever a redistribution + must occur to avoid erroring. This can be used to ensure that the user is aware of all redistribution. + + Note: it is easier to use this mode on just the forward pass of a typical DTensor program, as the backwards pass + may contain implicit redistribution calls that are not visible to the user and difficult to replace with manual + calls. Redistribution during backward can be made explicit by writing `autograd.Function`s that are no-op + during forward and perform a manual redistribution during backwards. 
+ """ + + _explicit_redistribute_mode = False + + @classmethod + def is_active(cls) -> bool: + return cls._explicit_redistribute_mode + + def __enter__(self): + self.prev = ExplicitRedistributionContext._explicit_redistribute_mode + ExplicitRedistributionContext._explicit_redistribute_mode = True + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + ExplicitRedistributionContext._explicit_redistribute_mode = self.prev + + def _explicit_order_placements( mesh_shape: ShapeType, placements: Sequence[Placement] ) -> Sequence[tuple[int, Placement]]: From f47cadf75d5fd5b78d9944d3f3a2b16892e90a73 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 7 Nov 2025 00:15:37 +0000 Subject: [PATCH 162/651] [BE][Typing][Dynamo] Type torch/_dynamo/variables/lists.py (#167156) Provides type coverage to torch/_dynamo/variables/dicts.py Coverage report: `mypy torch/_dynamo/variables/lists.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 1759 lines and 102 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/167156 Approved by: https://github.com/Skylion007, https://github.com/rtimpe --- torch/_dynamo/symbolic_convert.py | 6 +- torch/_dynamo/variables/ctx_manager.py | 6 +- torch/_dynamo/variables/iter.py | 12 +- torch/_dynamo/variables/lists.py | 346 ++++++++++++++----------- 4 files changed, 214 insertions(+), 156 deletions(-) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 3943f90b0020a..83e3edf5d8d6d 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3228,7 +3228,7 @@ def BUILD_TUPLE(self, inst: Instruction) -> None: def BUILD_SLICE(self, inst: Instruction) -> None: items = self.popn(inst.argval) - self.push(SliceVariable(items, tx=self)) + self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type] def BUILD_LIST(self, inst: Instruction) -> None: items = self.popn(inst.argval) @@ -3607,7 +3607,7 @@ def LIST_EXTEND(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, ListVariable) assert obj.is_mutable() - obj.call_method(self, "extend", [v], {}) + obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type] def LIST_TO_TUPLE(self, inst: Instruction) -> None: self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type] @@ -3673,7 +3673,7 @@ def MATCH_SEQUENCE(self, inst: Instruction) -> None: def MATCH_KEYS(self, inst: Instruction) -> None: tos = self.stack[-1] assert isinstance(tos, TupleVariable) - keys = tos.unpack_var_sequence(self) + keys = tos.unpack_var_sequence(self) # type: ignore[arg-type] tos1 = self.stack[-2] assert isinstance(tos1, ConstDictVariable) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 3f52c19ff0a90..318d0e91a0700 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -1513,7 +1513,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. 
- self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined] + self.ctx.reconstruct_type(codegen) # type: ignore[union-attr] if codegen.tx.output.partial_convert: if sys.version_info >= (3, 11): codegen.append_output(create_instruction("PUSH_NULL")) @@ -1522,10 +1522,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # We rely on classes subtyping `GenericContextWrappingVariable` # to implement these fns and have these attributes codegen.extend_output( - [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type] + [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[union-attr] ) codegen.extend_output( - create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type] + create_call_function(len(self.ctx.target_values), False) # type: ignore[union-attr] ) codegen.append_output(create_setup_with(self.target)) codegen.append_output(create_instruction("POP_TOP")) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index be765cbbc8bf9..bdb37da3ccce1 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -82,7 +82,8 @@ def call_function( for item in itertools.product(*seqs, repeat=r) ] return variables.ListIteratorVariable( - items, mutation_type=ValueMutationNew() + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), ) elif ( self.value is itertools.combinations @@ -98,7 +99,8 @@ def call_function( for item in itertools.combinations(iterable, r): items.append(variables.TupleVariable(list(item))) return variables.ListIteratorVariable( - items, mutation_type=ValueMutationNew() + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), ) elif self.value is itertools.groupby: if any(kw != "key" for kw in kwargs.keys()): @@ -181,7 +183,8 @@ def keyfunc(x: VariableTracker) -> Any: from_exc=e, ) return variables.ListIteratorVariable( - result, mutation_type=ValueMutationNew() + result, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), ) elif self.value is itertools.repeat: if len(args) < 2: @@ -212,7 +215,8 @@ def keyfunc(x: VariableTracker) -> Any: ) ] return variables.ListIteratorVariable( - items, mutation_type=ValueMutationNew() + items, # type: ignore[arg-type] + mutation_type=ValueMutationNew(), ) else: return super().call_function(tx, args, kwargs) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 11a199e99eadc..e4731697868e5 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Variable tracking implementations for list-like data structures in Dynamo. 
@@ -20,7 +18,7 @@ class that handles its unique behaviors while integrating with Dynamo's import inspect import operator import sys -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, Sequence, TYPE_CHECKING import torch import torch.fx @@ -60,11 +58,11 @@ class that handles its unique behaviors while integrating with Dynamo's class BaseListVariable(VariableTracker): @staticmethod - def cls_for_instance(obj): + def cls_for_instance(obj: Any) -> type["BaseListVariable"]: return BaseListVariable.cls_for(type(obj)) @staticmethod - def cls_for(obj): + def cls_for(obj: Any) -> type: return { iter: ListIteratorVariable, list: ListVariable, @@ -80,34 +78,38 @@ def cls_for(obj): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) assert isinstance(items, list) assert all(isinstance(x, VariableTracker) for x in items) self.items: list[VariableTracker] = items - def _as_proxy(self): + def _as_proxy(self) -> list[Any]: return [x.as_proxy() for x in self.items] - def modified(self, items, **kwargs): + def modified( + self, items: list[VariableTracker], **kwargs: Any + ) -> "BaseListVariable": return type(self)(items, **kwargs) @property - def value(self): + def value(self) -> Any: return self.as_python_constant() - def debug_repr_helper(self, prefix, suffix): + def debug_repr_helper(self, prefix: str, suffix: str) -> str: return prefix + ", ".join(i.debug_repr() for i in self.items) + suffix - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.python_type()([x.as_python_constant() for x in self.items]) - def as_proxy(self): + def as_proxy(self) -> Any: assert self.python_type() is not SizeVariable return self.python_type()(self._as_proxy()) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: from .tensor import SymNodeVariable if isinstance(arg, SymNodeVariable): @@ -134,16 +136,16 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): IndexError, tx, args=["list index out of range"] ) - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return list(self.items) def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": from .tensor import TensorVariable @@ -224,15 +226,15 @@ def call_method( if type(self) is not type(args[0]): tp_name = self.python_type_name() other = args[0].python_type_name() - msg = ConstantVariable.create( + msg_vt = ConstantVariable.create( f'can only concatenate {tp_name} (not "{other}") to {tp_name}' ) - raise_observed_exception(TypeError, tx, args=[msg]) + raise_observed_exception(TypeError, tx, args=[msg_vt]) if name == "__add__": - return type(self)(self.items + args[0].items, source=self.source) + return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined] else: - self.items += args[0].items + self.items += args[0].items # type: ignore[attr-defined] return self elif name in ("__mul__", "__imul__"): if kwargs or len(args) != 1: @@ -244,10 +246,10 @@ def call_method( ) if not (args[0].is_python_constant() and args[0].python_type() is int): - msg = 
ConstantVariable.create( + msg_vt = ConstantVariable.create( f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'" ) - raise_observed_exception(TypeError, tx, args=[msg]) + raise_observed_exception(TypeError, tx, args=[msg_vt]) val = args[0].as_python_constant() @@ -301,7 +303,7 @@ def call_method( class RangeVariable(BaseListVariable): - def __init__(self, items, **kwargs) -> None: + def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None: items_to_map = items start = variables.ConstantVariable.create(0) stop = None @@ -316,7 +318,7 @@ def __init__(self, items, **kwargs) -> None: else: raise AssertionError - def maybe_as_int(x): + def maybe_as_int(x: VariableTracker) -> VariableTracker: return ( ConstantVariable(int(x.value)) if isinstance(x, ConstantVariable) else x ) @@ -329,22 +331,22 @@ def maybe_as_int(x): assert stop is not None super().__init__([start, stop, step], **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: return self.debug_repr_helper("range(", ")") - def python_type(self): + def python_type(self) -> type: return range - def start(self): + def start(self) -> Any: return self.items[0].as_python_constant() - def stop(self): + def stop(self) -> Any: return self.items[1].as_python_constant() - def step(self): + def step(self) -> Any: return self.items[2].as_python_constant() - def range_length(self): + def range_length(self) -> int: lo = self.start() hi = self.stop() step = self.step() @@ -357,7 +359,7 @@ def range_length(self): else: return 0 - def _get_slice_indices(self, length, slice): + def _get_slice_indices(self, length: int, slice: slice) -> list[int]: step_is_negative = 0 if slice.step is None: @@ -406,7 +408,7 @@ def _get_slice_indices(self, length, slice): return [start, stop, step] - def apply_index(self, index): + def apply_index(self, index: int) -> VariableTracker: length = self.range_length() if index < 0: index = length + index @@ -421,12 +423,12 @@ def apply_index(self, index): return variables.ConstantVariable.create(self.start() + (index * self.step())) - def apply_slice(self, slice): + def apply_slice(self, slice: slice) -> "RangeVariable": (slice_start, slice_stop, slice_step) = self._get_slice_indices( self.range_length(), slice ) - def compute_item(index): + def compute_item(index: int) -> int: return self.start() + (index * self.step()) sub_step = self.step() * slice_step @@ -442,10 +444,12 @@ def compute_item(index): ) return result - def as_python_constant(self): + def as_python_constant(self) -> range: return range(*[x.as_python_constant() for x in self.items]) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c index = arg.as_python_constant() @@ -457,28 +461,30 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): msg = ConstantVariable("range indices must be integers or slices") raise_observed_exception(TypeError, tx, args=[msg]) - def as_proxy(self): + def as_proxy(self) -> range: return self.python_type()(*self._as_proxy()) - def unpack_var_sequence(self, tx=None): + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] = None + ) -> list[VariableTracker]: return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] def reconstruct(self, codegen: "PyCodegen") -> None: assert "range" not in codegen.tx.f_globals 
codegen.add_push_null( - lambda: codegen.append_output(codegen.create_load_python_module(range)) + lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type] ) codegen.foreach(self.items) codegen.extend_output(create_call_function(3, False)) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is range: return variables.ConstantVariable.create(name in range.__dict__) return super().call_obj_hasattr(tx, name) - def range_equals(self, other: "RangeVariable"): + def range_equals(self, other: "RangeVariable") -> bool: r0, r1 = self, other if ( self.range_length() != r1.range_length() @@ -487,12 +493,12 @@ def range_equals(self, other: "RangeVariable"): ): return False - if len(r0) == 1: + if self.range_length() == 1: return True return r0.step() == r1.step() - def range_count(self, x: VariableTracker): + def range_count(self, x: VariableTracker) -> int: # Based on CPython # https://github.com/guilhermeleobas/cpython/blob/baefaa6cba1d69efd2f930cdc56bca682c54b139/Objects/rangeobject.c#L442-L486 x = x.as_python_constant() @@ -511,7 +517,13 @@ def range_count(self, x: VariableTracker): return int(re) return 0 - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__iter__": if not all(var.is_python_constant() for var in self.items): # Can't represent a `range_iterator` without well defined bounds @@ -545,7 +557,10 @@ def call_method(self, tx, name, args, kwargs): if pt is not range: return ConstantVariable.create(NotImplemented) - cmp = self.range_equals(other) + if isinstance(other, RangeVariable): + cmp = self.range_equals(other) + else: + cmp = False # Two ranges are equal if they produce the same sequence of values if name == "__eq__": @@ -554,7 +569,7 @@ def call_method(self, tx, name, args, kwargs): return ConstantVariable(not cmp) return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: fields = ["start", "stop", "step"] if name in fields: return self.items[fields.index(name)] @@ -568,11 +583,11 @@ class CommonListMethodsVariable(BaseListVariable): def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: from .tensor import SymNodeVariable if name == "append" and self.is_mutable(): @@ -676,9 +691,9 @@ def call_method( self.items[key.evaluate_expr()] = value elif isinstance(key, SliceVariable): if key.is_python_constant(): - self.items[key.as_python_constant()] = list(value.items) + self.items[key.as_python_constant()] = list(value.items) # type: ignore[attr-defined] else: - items = slice( + items_slice = slice( *[ ( s.evaluate_expr() @@ -688,7 +703,7 @@ def call_method( for s in key.items ] ) - self.items[items] = list(value.items) + self.items[items_slice] = list(value.items) # type: ignore[attr-defined] else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) @@ -733,8 +748,8 @@ def call_method( "0 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - items = list(self.items) - return self.modified(items, 
mutation_type=ValueMutationNew()) + items_lst: list[VariableTracker] = list(self.items) + return self.modified(items_lst, mutation_type=ValueMutationNew()) elif name == "reverse" and self.is_mutable(): if args or kwargs: raise_args_mismatch( @@ -763,13 +778,13 @@ def call_method( class ListVariable(CommonListMethodsVariable): - def python_type(self): + def python_type(self) -> type: return list def __repr__(self) -> str: return f"{self.__class__.__name__}(length={len(self.items)})" - def debug_repr(self): + def debug_repr(self) -> str: return self.debug_repr_helper("[", "]") def reconstruct(self, codegen: "PyCodegen") -> None: @@ -778,11 +793,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: from .tensor import SymNodeVariable if name == "__setitem__" and self.is_mutable(): @@ -805,14 +820,14 @@ def call_method( msg = ConstantVariable.create("can only assign an iterable") raise_observed_exception(TypeError, tx, args=[msg]) - key = key.as_python_constant() - if key.step == 0: + key_as_const = key.as_python_constant() + if key_as_const.step == 0: msg = ConstantVariable.create("slice step cannot be zero") raise_observed_exception(ValueError, tx, args=[msg]) - value = value.force_unpack_var_sequence(tx) + value_unpack = value.force_unpack_var_sequence(tx) try: - self.items[key] = value + self.items[key_as_const] = value_unpack except Exception as exc: raise_observed_exception( type(exc), @@ -859,7 +874,7 @@ def call_method( assert first_non_constant_key is not None try: - python_type = first_non_constant_key.python_type() + python_type = str(first_non_constant_key.python_type()) except NotImplementedError: python_type = "unknown" @@ -904,7 +919,7 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx, name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__class__": source = AttrSource(self.source, name) if self.source else None class_type = self.python_type() @@ -916,14 +931,19 @@ def var_getattr(self, tx, name): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is not list: return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr([], name)) class DequeVariable(CommonListMethodsVariable): - def __init__(self, items, maxlen=None, **kwargs) -> None: + def __init__( + self, + items: list[VariableTracker], + maxlen: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: if maxlen is None: maxlen = ConstantVariable.create(None) assert maxlen.is_python_constant(), ( @@ -935,17 +955,17 @@ def __init__(self, items, maxlen=None, **kwargs) -> None: items = items[-maxlen.as_python_constant() :] super().__init__(items, **kwargs) - def python_type(self): + def python_type(self) -> type: return collections.deque - def debug_repr(self): + def debug_repr(self) -> str: if self.maxlen.as_python_constant() is None: return self.debug_repr_helper( "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")" ) return self.debug_repr_helper("deque([", "])") - def as_python_constant(self): + def as_python_constant(self) -> collections.deque[Any]: return self.python_type()( [x.as_python_constant() for x in self.items], 
maxlen=self.maxlen.as_python_constant(), @@ -954,7 +974,7 @@ def as_python_constant(self): def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.append_output( - codegen.create_load_python_module(collections.deque) + codegen.create_load_python_module(collections.deque) # type: ignore[arg-type] ) ) codegen.foreach(self.items) @@ -962,18 +982,18 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.maxlen) codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False)) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "maxlen": return self.maxlen return super().var_getattr(tx, name) def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if ( name == "__setitem__" and self.is_mutable() @@ -1068,20 +1088,20 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is collections.deque: return variables.ConstantVariable.create(name in collections.deque.__dict__) return super().call_obj_hasattr(tx, name) class TupleVariable(BaseListVariable): - def python_type(self): + def python_type(self) -> type[tuple]: # type: ignore[type-arg] return tuple def __repr__(self) -> str: return f"{self.__class__.__name__}(length={len(self.items)})" - def debug_repr(self): + def debug_repr(self) -> str: return self.debug_repr_helper("(", ")") def reconstruct(self, codegen: "PyCodegen") -> None: @@ -1090,14 +1110,14 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx, name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__class__": source = AttrSource(self.source, name) if self.source else None class_type = self.python_type() @@ -1109,7 +1129,7 @@ def var_getattr(self, tx, name): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is not tuple: return super().call_obj_hasattr(tx, name) return variables.ConstantVariable.create(hasattr((), name)) @@ -1127,18 +1147,18 @@ def __init__( self, items: list[VariableTracker], proxy: Optional[torch.fx.Proxy] = None, - **kwargs, + **kwargs: Any, ) -> None: self.proxy = proxy super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: return self.debug_repr_helper("torch.Size([", "])") - def python_type(self): + def python_type(self) -> type: return torch.Size - def as_proxy(self): + def as_proxy(self) -> Any: if self.proxy is not None: return self.proxy @@ -1193,10 +1213,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: ] + create_call_function(1, False) codegen.extend_output(build_torch_size) - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return list(self.items) - def numel(self, tx): + def 
numel(self, tx: "InstructionTranslator") -> VariableTracker: from .builtin import BuiltinVariable from .tensor import SymNodeVariable @@ -1226,11 +1246,11 @@ def numel(self, tx): def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": if kwargs or len(args) != 1: raise_args_mismatch( @@ -1253,7 +1273,9 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker): + def get_item_dyn( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: from .tensor import SymNodeVariable if isinstance(arg, SymNodeVariable): @@ -1269,7 +1291,7 @@ def get_item_dyn(self, tx: "InstructionTranslator", arg: VariableTracker): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: return variables.ConstantVariable.create(hasattr(torch.Size, name)) @@ -1280,33 +1302,39 @@ class NamedTupleVariable(TupleVariable): *TupleVariable._nonvar_fields, } - def __init__(self, items, tuple_cls, dynamic_attributes=None, **kwargs) -> None: + def __init__( + self, + items: list[VariableTracker], + tuple_cls: type, + dynamic_attributes: Optional[dict[str, VariableTracker]] = None, + **kwargs: Any, + ) -> None: super().__init__(items, **kwargs) self.tuple_cls = tuple_cls self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {} - def is_namedtuple(self): + def is_namedtuple(self) -> bool: return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( getattr(self.tuple_cls, "_make", None) ) - def is_structseq(self): + def is_structseq(self) -> bool: return not self.is_namedtuple() - def fields(self): + def fields(self) -> tuple[str, ...]: return namedtuple_fields(self.tuple_cls) - def debug_repr(self): + def debug_repr(self) -> str: if self.is_structseq(): # StructSequenceType(iterable) return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) # NamedTupleType(*iterable) return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) - def python_type(self): + def python_type(self) -> type: return self.tuple_cls - def as_python_constant(self): + def as_python_constant(self) -> Any: if self.is_structseq(): # StructSequenceType(iterable) result = self.python_type()([x.as_python_constant() for x in self.items]) @@ -1328,7 +1356,7 @@ def as_python_constant(self): return result - def as_proxy(self): + def as_proxy(self) -> Any: assert self.python_type() is not SizeVariable if self.is_structseq(): # StructSequenceType(iterable) @@ -1342,7 +1370,10 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # StructSequenceType(iterable) # NamedTupleType(*iterable) # NamedTupleType._make(iterable) - create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make + if self.is_structseq(): + create_fn = self.tuple_cls + else: + create_fn = self.tuple_cls._make # type: ignore[attr-defined] codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_const_unchecked(create_fn) @@ -1384,8 +1415,8 @@ def _is_method_overridden(self, method_name: str) -> bool: def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: @@ -1446,7 +1477,9 @@ def 
call_method( return super().call_method(tx, name, args, kwargs) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: if isinstance(arg, SliceVariable): # slicing a namedtuple produces a tuple return TupleVariable( @@ -1455,8 +1488,8 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): ) return super().getitem_const(tx, arg) - def var_getattr(self, tx: "InstructionTranslator", name): - def check_and_create_method(): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def check_and_create_method() -> Optional[VariableTracker]: method = inspect.getattr_static(self.tuple_cls, name, None) if isinstance(method, classmethod): # We need the unbounded cls method to avoid the inline __self__ @@ -1489,8 +1522,8 @@ def check_and_create_method(): return super().var_getattr(tx, name) if name == "_fields": - source = NamedTupleFieldsSource(self.source) if self.source else None - return VariableTracker.build(tx, self.fields(), source=source) + result_source = NamedTupleFieldsSource(self.source) if self.source else None + return VariableTracker.build(tx, self.fields(), source=result_source) if name in self.dynamic_attributes: return self.dynamic_attributes[name] @@ -1505,14 +1538,19 @@ def check_and_create_method(): def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: return variables.ConstantVariable.create( name in self.dynamic_attributes or hasattr(self.tuple_cls, name) ) class SliceVariable(VariableTracker): - def __init__(self, items, tx=None, **kwargs) -> None: + def __init__( + self, + items: Sequence[VariableTracker], + tx: Optional["InstructionTranslator"] = None, + **kwargs: Any, + ) -> None: items_to_map = items start, stop, step = [variables.ConstantVariable.create(None)] * 3 @@ -1547,23 +1585,23 @@ def __init__(self, items, tx=None, **kwargs) -> None: super().__init__(**kwargs) - def debug_repr(self): - return self.debug_repr_helper("slice(", ")") + def debug_repr(self) -> str: + return "slice(" + ", ".join(i.debug_repr() for i in self.items) + ")" - def as_proxy(self): + def as_proxy(self) -> slice: return slice(*[x.as_proxy() for x in self.items]) - def python_type(self): + def python_type(self) -> type: return slice - def as_python_constant(self): + def as_python_constant(self) -> slice: return slice(*[guard_if_dyn(x) for x in self.items]) def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(self.items) codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items))) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) fields = ["start", "stop", "step"] @@ -1584,7 +1622,9 @@ class ListIteratorVariable(IteratorVariable): *IteratorVariable._nonvar_fields, } - def __init__(self, items, index: int = 0, **kwargs) -> None: + def __init__( + self, items: list[VariableTracker], index: int = 0, **kwargs: Any + ) -> None: super().__init__(**kwargs) assert isinstance(items, list) # Removing this check as it slows things down too much @@ -1598,7 +1638,7 @@ def __init__(self, items, index: int = 0, **kwargs) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" - def next_variable(self, tx): 
+ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: assert self.is_mutable() old_index = self.index if old_index >= len(self.items) or self.is_exhausted: @@ -1609,27 +1649,31 @@ def next_variable(self, tx): self.index += 1 return self.items[old_index] - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: return variables.ConstantVariable.create(hasattr(iter([]), name)) - def python_type(self): + def python_type(self) -> type: return type(iter([])) - def as_python_constant(self): + def as_python_constant(self) -> Any: if self.index > 0: raise NotImplementedError return iter([x.as_python_constant() for x in self.items]) - def has_unpack_var_sequence(self, tx): + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return True - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: if self.is_exhausted: return [] self.is_exhausted = True return list(self.items[self.index :]) - def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: return self.unpack_var_sequence(tx) def reconstruct(self, codegen: "PyCodegen") -> None: @@ -1656,27 +1700,37 @@ class RangeIteratorVariable(IteratorVariable): "iter_obj", } - def __init__(self, start: int, stop: int, step: int, len_: int, **kwargs): + def __init__( + self, start: int, stop: int, step: int, len_: int, **kwargs: Any + ) -> None: super().__init__(**kwargs) self.start = start self.stop = stop self.step = step self.len = len_ - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__next__": return self.next_variable(tx) elif name == "__iter__": return self return super().call_method(tx, name, args, kwargs) - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: if self.python_type() is range_iterator: ri = iter(range(0)) return ConstantVariable(hasattr(ri, name)) return super().call_obj_hasattr(tx, name) - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: if self.len <= 0: raise_observed_exception(StopIteration, tx) @@ -1685,12 +1739,12 @@ def next_variable(self, tx): self.start += self.step return ConstantVariable.create(current) - def python_type(self): + def python_type(self) -> type: return range_iterator - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( - lambda: codegen.append_output(codegen.create_load_python_module(range)) + lambda: codegen.append_output(codegen.create_load_python_module(range)) # type: ignore[arg-type] ) codegen.append_output(codegen.create_load_const(self.start)) codegen.append_output(codegen.create_load_const(self.stop)) From 9a86ef763201e27f031469f0866c893707e9cf38 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 7 Nov 2025 00:40:45 +0000 Subject: [PATCH 163/651] [BE][Typing][Dynamo] Type torch/_dynamo/variables/functions.py (#167103) Provides type coverage to torch/_dynamo/variables/dicts.py Coverage report: `mypy torch/_dynamo/variables/functions.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered 
to 2698 lines and 166 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/167103 Approved by: https://github.com/mlazos, https://github.com/fxdawnn --- torch/_dynamo/variables/builtin.py | 2 +- torch/_dynamo/variables/functions.py | 773 +++++++++++++++++---------- torch/_dynamo/variables/iter.py | 2 +- torch/_dynamo/variables/torch.py | 9 +- 4 files changed, 485 insertions(+), 301 deletions(-) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 0f198377605ec..579cf7bfffc3d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1991,7 +1991,7 @@ def call_iter( # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() # with an integer argument starting at 0, until __getitem__ raises IndexError ret = variables.UserFunctionVariable( - polyfills.builtins.iter_ + polyfills.builtins.iter_ # type: ignore[arg-type] ).call_function(tx, [obj, *args], {}) if args: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 0752a413fce6e..c4865bfdedbfc 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Function-related variable tracking classes for Dynamo's symbolic execution. @@ -32,13 +30,14 @@ import traceback import types from collections.abc import Callable, Sequence -from types import FunctionType +from types import CellType, FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar from typing_extensions import Never from weakref import WeakKeyDictionary import torch from torch._dynamo.exc import get_stack_above_dynamo +from torch._guards import Source from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -87,25 +86,32 @@ try: from torch.distributed.fsdp._fully_shard import _fsdp_param_group except ModuleNotFoundError: - _fsdp_param_group = None + _fsdp_param_group = None # type: ignore[assignment] if TYPE_CHECKING: from torch._dynamo.codegen import PyCodegen - from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + InstructionTranslatorBase, + ) + from torch._dynamo.variables.ctx_manager import ContextWrappingVariable from torch._higher_order_ops.triton_kernel_wrap import ( TritonGridType, TritonKernelType, ) + from .lists import BaseListVariable, ListVariable + from .tensor import TensorVariable -_F = TypeVar("_F", bound=Callable) + +_F = TypeVar("_F", bound=Callable[..., Any]) CO_VARARGS = 0x04 CO_VARKEYWORDS = 0x08 # Module-level cache keyed by the function object -_spec_cache = WeakKeyDictionary() +_spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() class FunctionSpec: @@ -127,7 +133,7 @@ def __init__(self, func: FunctionType): off += 1 if self.varargs_name else 0 self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None - def update_defaults(self, func: FunctionType): + def update_defaults(self, func: FunctionType) -> None: # Defaults can change from function call to function call. So re-update # them on every call. 
self.defaults = func.__defaults__ or () @@ -147,7 +153,13 @@ def _get_spec(func: FunctionType) -> FunctionSpec: return spec -def bind_args_cached(func, tx, fn_source, args, kwargs): +def bind_args_cached( + func: FunctionType, + tx: "InstructionTranslator", + fn_source: Optional[Source], + args: Sequence[Any], + kwargs: dict[str, Any], +) -> dict[str, VariableTracker]: spec = _get_spec(func) spec.update_defaults(func) ba = {} @@ -240,7 +252,9 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): return ba -def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): +def wrap_bound_arg( + tx: "InstructionTranslator", val: Any, source: Optional[Source] = None +) -> VariableTracker: # Source propagation is best effort since not every object we encounter has a source to begin with. if isinstance(val, VariableTracker): return val @@ -252,14 +266,18 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): return variables.LazyVariableTracker.create(val, source) -def wrap_args_kwargs(tx: "InstructionTranslator", result): +def wrap_args_kwargs(tx: "InstructionTranslator", result: dict[str, Any]) -> None: for k, v in list(result.items()): if isinstance(v, (tuple, dict)): # args/kwargs result[k] = wrap_bound_arg(tx, v) -def init_cellvars(parent, result: dict[str, VariableTracker], code): +def init_cellvars( + parent: "InstructionTranslator", + result: dict[str, VariableTracker], + code: types.CodeType, +) -> None: """ Update `result` to add mapping from local name to new cells created directly by `code`, or update SideEffects in `parent` if the a local cell is @@ -277,8 +295,14 @@ def init_cellvars(parent, result: dict[str, VariableTracker], code): def _create_nested_fn( - code, f_globals, name, defaults, closure, kwdefaults, annotations -): + code: types.CodeType, + f_globals: dict[str, Any], + name: str, + defaults: Optional[tuple[object, ...]], + closure: Optional[tuple[CellType]], + kwdefaults: Optional[dict[str, Any]], + annotations: Optional[dict[str, Any]], +) -> types.FunctionType: from types import FunctionType func = FunctionType(code, f_globals, name, defaults, closure) @@ -291,7 +315,7 @@ def _create_nested_fn( # TypeError: __annotations__ must be set to a dict object assert annotations is None or isinstance(annotations, dict) - func.__annotations__ = annotations + func.__annotations__ = annotations # type: ignore[assignment] return func @@ -307,7 +331,9 @@ def _create_nested_fn( } -def fn_var_getattr(tx, fn, source, name): +def fn_var_getattr( + tx: "InstructionTranslator", fn: object, source: Optional[Source], name: str +) -> VariableTracker: source = source and AttrSource(source, name) if source and name == "__annotations__": @@ -316,6 +342,7 @@ def fn_var_getattr(tx, fn, source, name): # graph is even rarer. So skip guards. 
source = SkipGuardSource(source) + subobj = None try: subobj = inspect.getattr_static(fn, name) except AttributeError: @@ -332,19 +359,19 @@ def fn_var_getattr(tx, fn, source, name): class BaseUserFunctionVariable(VariableTracker): - def get_filename(self): - return self.get_code().co_filename + def get_filename(self) -> str: + return self.get_code().co_filename # type: ignore[attr-defined] - def get_name(self): - return self.get_code().co_name + def get_name(self) -> str: + return self.get_code().co_name # type: ignore[attr-defined] def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined] def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -352,16 +379,16 @@ def call_obj_hasattr( result = False try: - result = hasattr(self.get_function(), name) + result = hasattr(self.get_function(), name) # type: ignore[attr-defined] except NotImplementedError: if name == "__name__" and isinstance(self, NestedUserFunctionVariable): result = True return variables.ConstantVariable.create(result) - def inspect_parameter_names(self): - return list(inspect.signature(self.get_function()).parameters) + def inspect_parameter_names(self) -> list[str]: + return list(inspect.signature(self.get_function()).parameters) # type: ignore[attr-defined] - def closure_vars(self, tx): + def closure_vars(self, tx: "InstructionTranslator") -> dict[str, VariableTracker]: return {} @@ -375,11 +402,16 @@ class UserFunctionVariable(BaseUserFunctionVariable): } @classmethod - def create_with_source(cls, value, source): + def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) - def __init__(self, fn, is_constant=False, **kwargs) -> None: + def __init__( + self, + fn: types.FunctionType | torch.jit.ScriptFunction, # type: ignore[type-arg] + is_constant: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) if getattr(fn, "_dynamo_marked_constant", False): # This method should be treated as a constant for the purposes of compilation @@ -403,40 +435,45 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None: # VariableBuilder, which handles the wrapping of _torchdynamo_inline. 
# unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) - self.fn: types.FunctionType = fn + self.fn = fn - def as_python_constant(self): + def as_python_constant(self) -> Any: if istype(self, UserFunctionVariable): return self.fn # subclasses (such as methods) usually aren't a constant return super().as_python_constant() - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [] - def get_function(self): + def get_function(self) -> types.FunctionType: return self.fn - def get_code(self): + def get_code(self) -> types.CodeType: return self.fn.__code__ - def python_type(self): + def python_type(self) -> type: return types.FunctionType - def has_self(self): + def has_self(self) -> bool: return getattr(self.fn, "__self__", None) is not None - def get_globals(self): + def get_globals(self) -> dict[str, Any]: return self.fn.__globals__ - def get_source(self): + def get_source(self) -> Source: source = self.source if source and isinstance(self, variables.UserMethodVariable): - source = self.source_fn - return source + source = self.source_fn # type: ignore[assignment] + return source # type: ignore[return-value] - def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: """ Assume `args` and `kwargs` are VariableTracker arguments for a call to this function, create new bindings for initial locals. @@ -450,7 +487,7 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: root_tx = parent.output.root_tx source = self.get_source() - result = bind_args_cached(fn, root_tx, source, args, kwargs) + result = bind_args_cached(fn, root_tx, source, args, kwargs) # type: ignore[arg-type] init_cellvars(parent, result, fn.__code__) closure = self.fn.__closure__ or () @@ -491,7 +528,7 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: return result - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) source = self.get_source() @@ -506,9 +543,9 @@ def call_obj_hasattr( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # Handle patch_dynamo_config call if self.fn is torch._dynamo.patch_dynamo_config: try: @@ -548,7 +585,7 @@ def call_function( msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" unimplemented_v2( gb_type="TypeError from user code", - context=f"call_function({self.value}, {args}, {kwargs})", + context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined] explanation=msg, hints=[ *graph_break_hints.USER_ERROR, @@ -567,7 +604,7 @@ def call_function( "`torch.compile` region", ], ) - + # pyrefly: ignore[missing-attribute] fn = fn_var.fn return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True) @@ -593,7 +630,7 @@ def call_function( try: from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState except Exception: - FSDPState = None + FSDPState = None # type: ignore[assignment, misc] if FSDPState is not None and self.fn in [ 
FSDPState._pre_forward, FSDPState._post_forward, @@ -604,13 +641,15 @@ def call_function( class BuiltinMethodVariable(BaseUserFunctionVariable): - def __init__(self, fn, is_constant=False, **kwargs) -> None: + def __init__( + self, fn: types.BuiltinMethodType, is_constant: bool = False, **kwargs: Any + ) -> None: super().__init__(**kwargs) assert isinstance(fn, types.BuiltinMethodType) self.fn = fn @staticmethod - def is_supported_builtin_method(obj): + def is_supported_builtin_method(obj: Any) -> bool: method_self = obj.__self__ method_name = obj.__name__ @@ -623,9 +662,9 @@ def is_supported_builtin_method(obj): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: method_self = self.fn.__self__ name = self.fn.__name__ obj_source = self.source and AttrSource(self.source, "__self__") @@ -637,39 +676,39 @@ class LocalGeneratorObjectVariable(VariableTracker): def __init__( self, code: types.CodeType, - f_globals, + f_globals: dict[str, Any], inline_tracer: Optional["InstructionTranslator"], - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.code = code self.f_globals = f_globals self.inline_tracer = inline_tracer - def get_code(self): + def get_code(self) -> types.CodeType: return self.code - def get_filename(self): + def get_filename(self) -> str: return self.get_code().co_filename - def get_name(self): + def get_name(self) -> str: return self.get_code().co_name - def get_function(self): + def get_function(self) -> Never: raise NotImplementedError - def has_self(self): + def has_self(self) -> bool: return False - def __name__(self): + def __name__(self) -> str: return self.get_name() - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}({self.get_name()})" __repr__ = __str__ - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: from torch._dynamo.side_effects import disallow_side_effects_in_generator from torch._dynamo.symbolic_convert import ( InstructionTranslator, @@ -688,25 +727,30 @@ def reconstruct(self, codegen: "PyCodegen"): self.remaining_items = self.force_unpack_var_sequence(tx) variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) - def bind_args(self, tx, args, kwargs): - return self.fn.bind_args(tx, args, kwargs) + def bind_args( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + return self.vt.bind_args(tx, args, kwargs) # type: ignore[attr-defined] - def get_globals(self): + def get_globals(self) -> dict[str, Any]: return self.f_globals - def python_type(self): + def python_type(self) -> type: return types.GeneratorType - def _get_inline_tracer(self, tx): + def _get_inline_tracer(self, tx: "InstructionTranslator") -> Any: from torch._dynamo.symbolic_convert import InliningInstructionTranslator if self.inline_tracer is None: - self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( + self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( # type: ignore[assignment] tx, self, [], {} ) return self.inline_tracer - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: tracer = self._get_inline_tracer(tx) if self._is_generator_exhausted(): @@ -727,23 +771,29 @@ def next_variable(self, 
tx): torch._dynamo.eval_frame.skip_code(self.get_code()) raise SkipFrame from e - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) - def has_unpack_var_sequence(self, tx): + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return False - def has_force_unpack_var_sequence(self, tx) -> builtins.bool: + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return True - def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: - result = [] + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] self.force_apply_to_var_sequence(tx, result.append) return result - def force_apply_to_var_sequence(self, tx, fn) -> None: + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[VariableTracker], Any] + ) -> None: while True: try: fn(self.next_variable(tx)) @@ -751,7 +801,9 @@ def force_apply_to_var_sequence(self, tx, fn) -> None: handle_observed_exception(tx) break - def _setup_exception(self, tx, exc): + def _setup_exception( + self, tx: "InstructionTranslator", exc: VariableTracker + ) -> None: tracer = self._get_inline_tracer(tx) try: tracer._raise_exception_variable(exc) @@ -760,19 +812,19 @@ def _setup_exception(self, tx, exc): # exception is raised again. tracer.exception_handler(e) - def _is_generator_just_started(self): + def _is_generator_just_started(self) -> bool: return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 - def _is_generator_exhausted(self): + def _is_generator_exhausted(self) -> bool: return getattr(self.inline_tracer, "generator_exhausted", False) def call_method( self, tx: "InstructionTranslator", name: str, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__next__": return self.next_variable(tx) elif name == "__iter__": @@ -952,7 +1004,7 @@ def call_method( raise_observed_exception(RuntimeError, tracer) return retval - super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) class ContextlibContextManagerLocalGeneratorObjectVariable( @@ -980,19 +1032,24 @@ def __init__( self, vt: VariableTracker, *, - generator_cls=LocalGeneratorObjectVariable, - **kwargs, - ): + generator_cls: type = LocalGeneratorObjectVariable, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.vt = vt self.generator_cls = generator_cls - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in self.__class__.__dict__.keys(): return getattr(self, name) return getattr(self.vt, name) - def _build_inline_tracer(self, tx, args, kwargs): + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InstructionTranslatorBase": from torch._dynamo.symbolic_convert import InliningInstructionTranslator return InliningInstructionTranslator.build_inline_tracer( @@ -1005,13 +1062,13 @@ def _build_inline_tracer(self, tx, args, kwargs): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - if not 
is_generator(self.vt.get_code()): + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not is_generator(self.vt.get_code()): # type: ignore[attr-defined] unimplemented_v2( gb_type="non-generator contextlib.contextmanager", - context=str(self.vt.get_code()), + context=str(self.vt.get_code()), # type: ignore[attr-defined] explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" ", i.e. does not use `yield`", hints=[ @@ -1020,15 +1077,15 @@ def call_function( ], ) - inline_tracer = self._build_inline_tracer(tx, args, kwargs) - code = self.vt.get_code() - f_globals = self.vt.get_globals() + inline_tracer = self._build_inline_tracer(tx, list(args), kwargs) + code = self.vt.get_code() # type: ignore[attr-defined] + f_globals = self.vt.get_globals() # type: ignore[attr-defined] # calling a generator returns a generator object return self.generator_cls( code, f_globals, - inline_tracer, + inline_tracer, # type: ignore[arg-type] source=self.source, ) @@ -1042,14 +1099,19 @@ class FunctionDecoratedByContextlibContextManagerVariable( This is only used when the function is annotated with @contextlib.contextmanager """ - def __init__(self, vt, **kwargs): + def __init__(self, vt: VariableTracker, **kwargs: Any): super().__init__( vt, generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, **kwargs, ) - def _build_inline_tracer(self, tx, args, kwargs): + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InstructionTranslatorBase": # NOTE: This only exists to not break support for context manager when # config.enable_faithful_generator_behavior = False and # config.enable_trace_contextlib = True. In case the former is false, @@ -1066,8 +1128,14 @@ def _build_inline_tracer(self, tx, args, kwargs): class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" - def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: - super().__init__(fn=fn, **kwargs) + def __init__( + self, + fn: Callable[..., Any], + obj: VariableTracker, + source_fn: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(fn=fn, **kwargs) # type: ignore[arg-type] self.obj = obj self.source_fn = source_fn # Note on source and source_fn @@ -1083,24 +1151,24 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: # operates on the unbound function, most guards should target # `source_fn` rather than the original `source`. if source_fn is None and kwargs.get("source") is not None: - self.source_fn = AttrSource(kwargs.get("source"), "__func__") + self.source_fn = AttrSource(kwargs.get("source"), "__func__") # type: ignore[assignment, arg-type] def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [self.obj] - def python_type(self): + def python_type(self) -> type[types.MethodType]: return types.MethodType def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - # NOTE this is to handle methods annotated by `nonstrict_trace`. Usually + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # NOTE this is to handle methods annotated by `nonstrict_trace`. 
# a `nonstrict_trace`-ed function will be wrapped by # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`, # but in the case of method, we manually wrap it with `UserMethodVariable` @@ -1141,36 +1209,41 @@ def call_function( or self.is_constant ): return self.obj.call_method( - tx, self.fn.__name__, args, kwargs, constant=self.is_constant + tx, self.fn.__name__, list(args), kwargs, constant=self.is_constant ) elif ( _fsdp_param_group is not None - and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state # type: ignore[attr-defined] ): return variables.TorchCtxManagerClassVariable(self.fn).call_function( tx, (self.obj, *args), kwargs ) if self.is_constant: - fn = getattr(self.obj.value, self.fn.__name__) + fn = getattr(self.obj.value, self.fn.__name__) # type: ignore[attr-defined] return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) return super().call_function(tx, args, kwargs) - def inspect_parameter_names(self): + def inspect_parameter_names(self) -> list[str]: return super().inspect_parameter_names()[1:] - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__self__": return self.obj if name == "__func__": # We might have a better way to access the function object, this # information is stored in self.source_fn, use that to construct the # variable tracker. - return VariableTracker.build(tx, self.fn, self.source_fn) + return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type] return super().var_getattr(tx, name) class WrappedUserMethodVariable(UserMethodVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: UserMethodVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("fn", None) kwargs.pop("obj", None) super().__init__(wrapped.fn, wrapped.obj, **kwargs) @@ -1180,22 +1253,27 @@ def __init__(self, wrapped, context, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) class WrappedUserFunctionVariable(UserFunctionVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: UserFunctionVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("fn", None) super().__init__(wrapped.fn, **kwargs) self.wrapped = wrapped @@ -1204,22 +1282,28 @@ def __init__(self, wrapped, context, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, 
codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) -def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs): - def convert(x): +def invoke_and_store_as_constant( + tx: "InstructionTranslator", + fn: Callable[..., Any], + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> VariableTracker: + def convert(x: VariableTracker) -> Any: if isinstance(x, variables.TensorVariable): return x.get_real_value() return x.as_python_constant() @@ -1242,17 +1326,17 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): def __init__( self, - fn_name, - code, - f_globals, - defaults, - kwdefaults, - annotations, - closure, + fn_name: VariableTracker, + code: VariableTracker, + f_globals: dict[str, Any], + defaults: Optional[VariableTracker], + kwdefaults: Optional[VariableTracker], + annotations: Optional[VariableTracker], + closure: Optional[VariableTracker], # This is present when this function is created by # `functools.wrap(wrapped_fn)(this_fn)`. - wrapped_fn=None, - **kwargs, + wrapped_fn: Optional[VariableTracker] = None, + **kwargs: Any, ) -> None: if kwargs.get("mutation_type") is None: kwargs.update(mutation_type=AttributeMutationNew()) @@ -1269,16 +1353,16 @@ def __init__( self.closure = closure self.wrapped_fn: Optional[VariableTracker] = wrapped_fn - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [] - def get_code(self): + def get_code(self) -> types.CodeType: return self.code.as_python_constant() - def python_type(self): + def python_type(self) -> type: return types.FunctionType - def get_function(self): + def get_function(self) -> types.FunctionType: if self.closure: raise NotImplementedError func = types.FunctionType( @@ -1307,19 +1391,25 @@ def call_setattr( tx: "InstructionTranslator", name_var: VariableTracker, val: VariableTracker, - ): - tx.output.side_effects.store_attr(self, name_var.value, val) + ) -> VariableTracker: + tx.output.side_effects.store_attr(self, name_var.value, val) # type: ignore[attr-defined] return ConstantVariable(None) - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__setattr__": return self.call_setattr(tx, *args) - return super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, list(args), kwargs) - def has_closure(self): + def has_closure(self) -> bool: return self.closure is not None - def const_getattr(self, tx, name): + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: if name == "__name__": return self.get_name() if name == "__code__": @@ -1329,50 +1419,57 @@ def const_getattr(self, tx, name): return d.as_python_constant() if d else None return super().const_getattr(tx, name) - def call_obj_hasattr(self, tx: "InstructionTranslator", name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: if name == "__code__": return variables.ConstantVariable.create(hasattr(self, "code")) if name == "__defaults__": return variables.ConstantVariable.create(hasattr(self, "defaults")) return super().call_obj_hasattr(tx, name) - def has_self(self): + def has_self(self) -> bool: return 
False - def get_globals(self): + def get_globals(self) -> dict[str, Any]: return self.f_globals - def bind_args(self, parent, args, kwargs): + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: code = self.get_code() func = types.FunctionType( code, self.f_globals, self.fn_name.as_python_constant(), - tuple(self.defaults.items) if self.defaults else None, + tuple(self.defaults.items) if self.defaults else None, # type: ignore[attr-defined] tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), ) if self.kwdefaults: - func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() # type: ignore[attr-defined] bound = inspect.signature(func).bind(*args, **kwargs) bound.apply_defaults() result = dict(bound.arguments.items()) - wrap_args_kwargs(parent.output.root_tx, result) + wrap_args_kwargs(parent.output.root_tx, result) # type: ignore[arg-type] init_cellvars(parent, result, code) for idx, name in enumerate(code.co_freevars): assert name not in result - cell = self.closure.items[idx] + cell = self.closure.items[idx] # type: ignore[attr-defined, union-attr] result[name] = cell return result - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from(__name__, "_create_nested_fn") ) codegen(self.code) codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)]) - codegen(ConstantVariable.create(self.code.value.co_name)) + codegen(ConstantVariable.create(self.code.value.co_name)) # type: ignore[attr-defined] if self.defaults: codegen(self.defaults) @@ -1426,7 +1523,12 @@ def reconstruct(self, codegen: "PyCodegen"): class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: Any, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("fn_name", None) kwargs.pop("code", None) kwargs.pop("f_globals", None) @@ -1451,16 +1553,16 @@ def __init__(self, wrapped, context, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) @@ -1472,16 +1574,16 @@ class SkipFunctionVariable(VariableTracker): *VariableTracker._nonvar_fields, } - def __init__(self, value, reason=None, **kwargs) -> None: + def __init__(self, value: Any, reason: Optional[str] = None, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value self.reason = reason - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value @classmethod - def create_with_source(cls, value, source): + def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": # Use closure match guard (i.e. 
guard on __code__ object instead of # function id) to avoid guarding on nested functions. if inspect.getattr_static(value, "_torchdynamo_disable", False): @@ -1510,9 +1612,9 @@ def create_with_source(cls, value, source): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if inspect.getattr_static(self.value, "_torchdynamo_disable", False): msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) unimplemented_v2( @@ -1525,7 +1627,7 @@ def call_function( ], ) elif self.value is torch._dynamo.graph_break: - graph_break_msg = kwargs.get("msg", None) + graph_break_msg = kwargs.get("msg") if graph_break_msg: graph_break_msg = graph_break_msg.as_python_constant() unimplemented_v2( @@ -1537,7 +1639,7 @@ def call_function( ], ) elif self.value is torch._dynamo.skip_frame: - skip_frame_msg = kwargs.get("msg", None) + skip_frame_msg = kwargs.get("msg") if skip_frame_msg: skip_frame_msg = skip_frame_msg.as_python_constant() raise SkipFrame( @@ -1629,10 +1731,12 @@ def call_function( hints=hints, ) - def call_obj_hasattr(self, tx: "InstructionTranslator", name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: return variables.ConstantVariable.create(hasattr(self.value, name)) - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) @@ -1640,26 +1744,31 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): class WrappedSkipFunctionVariable(SkipFunctionVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: VariableTracker, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("value", None) kwargs.pop("reason", None) - super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) # type: ignore[attr-defined] self.wrapped = wrapped self.context = context def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) @@ -1672,12 +1781,12 @@ class WrapperUserFunctionVariable(VariableTracker): __script_if_tracing_wrapper have the original attr at "__original_fn". 
""" - def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None: + def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.wrapper_obj = wrapper_obj self.attr_to_trace = attr_to_trace - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == self.attr_to_trace: val = getattr(self.wrapper_obj, self.attr_to_trace) source = self.source and AttrSource(self.source, name) @@ -1685,15 +1794,15 @@ def var_getattr(self, tx: "InstructionTranslator", name): return super().var_getattr(tx, name) - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [] def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if hasattr(self.wrapper_obj, "cache_info"): target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) module_name = getattr(target_fn, "__module__", "") or "" @@ -1719,9 +1828,9 @@ def call_function( user_stack_trace += str(user_stack_formatted) dynamo_logger.debug(user_stack_trace) - all_args = self.self_args() + args + all_args = self.self_args() + list(args) return variables.UserFunctionVariable( - polyfills.getattr_and_trace + polyfills.getattr_and_trace # type: ignore[arg-type] ).call_function( tx, [self, variables.ConstantVariable(self.attr_to_trace), *all_args], @@ -1736,15 +1845,21 @@ class WrapperUserMethodVariable(WrapperUserFunctionVariable): WrapperUserFunctionVariable in `call_function` method. """ - def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None: + def __init__( + self, + wrapper_obj: Any, + attr_to_trace: str, + self_obj: VariableTracker, + **kwargs: Any, + ) -> None: super().__init__(wrapper_obj, attr_to_trace, **kwargs) self.obj = self_obj - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [self.obj] -def _traceable_collective_remaps(): +def _traceable_collective_remaps() -> dict[Any, Any]: # We can't rely on importing from distributed, since it's not always built if torch.distributed.is_available(): from torch.distributed._functional_collectives import ( @@ -1755,7 +1870,9 @@ def _traceable_collective_remaps(): return {} -def _traceable_collectives_source(tx: "InstructionTranslator", fn): +def _traceable_collectives_source( + tx: "InstructionTranslator", fn: Callable[..., Any] +) -> AttrSource: assert torch.distributed.is_available(), "Illegal invocation." assert fn in _traceable_collective_remaps().values() @@ -1775,13 +1892,24 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable): than status-quo as we currently graph-break on all distributed.* collectives. 
""" - def __init__(self, fn, *, replacement_var, **kwargs) -> None: - super().__init__(fn, **kwargs) + def __init__( + self, + fn: Callable[..., Any], + *, + replacement_var: UserFunctionVariable, + **kwargs: Any, + ) -> None: + super().__init__(fn, **kwargs) # type: ignore[arg-type] assert isinstance(replacement_var, UserFunctionVariable) self.replacement_var = replacement_var @staticmethod - def create(tx: "InstructionTranslator", old_fn, source, **options): + def create( + tx: "InstructionTranslator", + old_fn: Callable[..., Any], + source: Source, + **options: Any, + ) -> "CollectiveFunctionRewriteVariable": new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) return CollectiveFunctionRewriteVariable( old_fn, @@ -1791,22 +1919,24 @@ def create(tx: "InstructionTranslator", old_fn, source, **options): ) @staticmethod - def can_rewrite(variable): + def can_rewrite(variable: Any) -> bool: return ( inspect.isfunction(variable) and variable in _traceable_collective_remaps() ) @staticmethod - def rewrite(tx: "InstructionTranslator", fn): + def rewrite( + tx: "InstructionTranslator", fn: Callable[..., Any] + ) -> tuple[Any, AttrSource]: new_fn = _traceable_collective_remaps()[fn] return new_fn, _traceable_collectives_source(tx, new_fn) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # call_function must check any unsupported arguments and graph-break. # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, # since that's the contract for putting a mapping in `traceable_collective_remaps` @@ -1836,7 +1966,7 @@ def call_function( ): reduce_op_var = kwargs.get("op") reduce_op = ( - reduce_op_var.value + reduce_op_var.value # type: ignore[attr-defined] if reduce_op_var is not None else signature.parameters["op"].default ) @@ -1852,12 +1982,12 @@ class FunctoolsWrapsVariable(UserFunctionVariable): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if not kwargs and len(args) == 1: - def wraps(fn): + def wraps(fn: Any) -> VariableTracker: if isinstance(fn, variables.NestedUserFunctionVariable): return fn.clone(wrapped_fn=args[0]) unimplemented_v2( @@ -1875,15 +2005,15 @@ def wraps(fn): class CollectionsNamedTupleFunction(UserFunctionVariable): - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.fn def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: constant_args = check_constant_args(args, kwargs) if constant_args: try: @@ -1898,7 +2028,9 @@ def call_function( args=list(map(ConstantVariable.create, exc.args)), ) return variables.UserDefinedClassVariable( - value, mutation_type=ValueMutationNew() + # pyrefly: ignore[unbound-name] + value, + mutation_type=ValueMutationNew(), ) unimplemented_v2( gb_type="namedtuple construction", @@ -1911,7 +2043,13 @@ def call_function( class FunctoolsPartialVariable(VariableTracker): - def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: + def __init__( + self, + func: 
VariableTracker, + args: Sequence[VariableTracker], + keywords: dict[str, VariableTracker], + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.func = func assert isinstance(args, list) @@ -1922,10 +2060,10 @@ def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: # on it is sufficient for the tracing purposes. self.fake_value = functools.partial(identity) - def python_type(self): + def python_type(self) -> type: return functools.partial - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: @@ -1940,16 +2078,16 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) ) - def get_function(self): + def get_function(self) -> Any: return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - merged_args = self.args + args + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + merged_args = self.args + list(args) merged_kwargs = {**self.keywords, **kwargs} return self.func.call_function(tx, merged_args, merged_kwargs) @@ -1961,7 +2099,7 @@ def call_obj_hasattr( hasattr(functools.partial(identity), name) ) - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: source = self.source and AttrSource(self.source, name) # Handle __slots__ if name == "func": @@ -1975,14 +2113,14 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): return variables.GetAttrVariable(self, name) raise_observed_exception(AttributeError, tx) - def as_python_constant(self): + def as_python_constant(self) -> Any: return functools.partial( self.func.as_python_constant(), *[arg.as_python_constant() for arg in self.args], **{k: v.as_python_constant() for k, v in self.keywords.items()}, ) - def guard_as_python_constant(self): + def guard_as_python_constant(self) -> Any: """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" return functools.partial( self.func.guard_as_python_constant(), @@ -2005,16 +2143,20 @@ def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: return {} @classmethod - def create_with_source(cls, value, source): + def create_with_source( + cls, value: Any, source: Source + ) -> "PolyfilledFunctionVariable": install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) - def __init__(self, fn: _F, **kwargs) -> None: + def __init__(self, fn: _F, **kwargs: Any) -> None: super().__init__(**kwargs) + # pyrefly: ignore[invalid-type-var] self.fn: _F = fn handler = self._get_polyfill_handlers().get(fn, fn) + traceable_fn = None assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" for candidate_attr in ( "__torch_dynamo_polyfill__", # registered polyfill @@ -2029,28 +2171,29 @@ def __init__(self, fn: _F, **kwargs) -> None: raise RuntimeError( f"Polyfill handler {handler} does not have a traceable function" ) - - self.wrapped_fn: _F = handler + # pyrefly: ignore[invalid-type-var] + self.wrapped_fn = handler + # pyrefly: ignore[invalid-type-var] self.traceable_fn: _F = traceable_fn @property - def polyfill_fn(self) -> _F: + def polyfill_fn(self) -> 
Callable[..., Any]: return self.traceable_fn - def can_constant_fold_through(self): + def can_constant_fold_through(self) -> bool: return getattr( self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False ) - def get_function(self): + def get_function(self) -> Any: return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -2087,7 +2230,7 @@ def call_function( ( x.value if isinstance(x, variables.ConstantVariable) - else x.sym_num + else x.sym_num # type: ignore[attr-defined] ) for x in args[0].items ] @@ -2099,11 +2242,11 @@ def call_function( def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__call__": return self.call_function(tx, args, kwargs) @@ -2113,27 +2256,33 @@ def call_method( options = {} if self.source: options["source"] = AttrSource(self.source, name) + # pyrefly: ignore[bad-specialization] polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) return polyfilled_method_variable.call_function(tx, args, kwargs) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.fn class TracebackVariable(VariableTracker): # We don't track traceback. A call to any function in this module is a no-op - def call_function(self, tx, args, kwargs): ... + def call_function( # type: ignore[empty-body] + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: ... 
class SysFunctionVariable(VariableTracker): - def __init__(self, value, **kwargs): + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value - def exc_info(self, tx): + def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": if len(tx.exn_vt_stack): exn = tx.exn_vt_stack[-1] - typ = exn.exc_type + typ = exn.exc_type # type: ignore[union-attr] tb = None items = [ VariableTracker.build(tx, typ), @@ -2146,12 +2295,17 @@ def exc_info(self, tx): variables.ConstantVariable(None), variables.ConstantVariable(None), ] - return variables.TupleVariable(items) + return variables.TupleVariable(items) # type: ignore[arg-type] - def exception(self, tx): + def exception(self, tx: "InstructionTranslator") -> VariableTracker: return self.exc_info(tx).items[1] - def call_function(self, tx, args, kwargs): + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.value is sys.exc_info: return self.exc_info(tx) assert self.value is sys.exception @@ -2170,15 +2324,15 @@ class DynamoTritonHOPifier(TritonHOPifier): def raise_unsupported(self, msg: str) -> Never: raise Unsupported(msg) - def is_callable(self, maybe_callable: Any) -> bool: + def is_callable(self, maybe_callable: VariableTracker) -> bool: return isinstance( maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) ) - def get_value(self, val: Any) -> Any: - return val.value + def get_value(self, val: VariableTracker) -> Any: + return val.value # type: ignore[attr-defined] - def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]: + def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, ...]: from .lists import BaseListVariable if isinstance(grid, BaseListVariable): @@ -2193,20 +2347,35 @@ def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]: ], ) - def call_grid(self, grid, meta, tx): - meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()} - grid = grid.call_function(tx, [meta], {}) + def call_grid( + self, grid: Any, meta: dict[str, Any], tx: "InstructionTranslator" + ) -> Any: + meta_var = {variables.ConstantVariable.create(k): v for k, v in meta.items()} + grid = grid.call_function(tx, [meta_var], {}) return grid # We use this function to wrap call_prune_configs - def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable): + def call_user_defined_fn( + self, + user_fn: Callable[..., Any], + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + tx: Optional["InstructionTranslator"], + variable: Any, + ) -> VariableTracker: from .builder import SourcelessBuilder - wrapped_user_function = SourcelessBuilder.create(tx, user_fn) + wrapped_user_function = SourcelessBuilder.create(tx, user_fn) # type: ignore[arg-type] result = wrapped_user_function.call_function(tx, args, kwargs) return result - def wrap_user_defined_obj(self, user_obj, tx, variable, name): + def wrap_user_defined_obj( + self, + user_obj: Any, + tx: Optional["InstructionTranslator"], + variable: Any, + name: str, + ) -> VariableTracker: from .builder import VariableBuilder wrapped_user_obj = VariableBuilder( @@ -2214,7 +2383,9 @@ def wrap_user_defined_obj(self, user_obj, tx, variable, name): )._wrap(user_obj) return wrapped_user_obj - def maybe_unpack_configs(self, configs, tx): + def maybe_unpack_configs( + self, configs: Any, tx: Optional["InstructionTranslator"] + ) -> list[Any]: # unpack the list of configs configs = 
configs.unpack_var_sequence(tx) @@ -2223,7 +2394,7 @@ def maybe_unpack_configs(self, configs, tx): return configs - def maybe_unpack_heuristic_result(self, result: Any) -> Any: + def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any: if not result.is_python_constant(): self.raise_unsupported( "@triton.heuristics must return constant values because configs can only contain constant values." @@ -2233,7 +2404,7 @@ def maybe_unpack_heuristic_result(self, result: Any) -> Any: # We need to override call_getitem here so that we can add the source in the case # where we call the triton kernel with a grid - def call_getitem( + def call_getitem( # type: ignore[override] self, variable: "TritonKernelVariable", args: Sequence[Any], @@ -2251,7 +2422,13 @@ def call_getitem( kernel_source=variable.source, ) - def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable: + def call_HOP( + self, + variable: "TritonKernelVariable", + grids: Any, + combined_args_raw: dict[str, Any], + tx: "InstructionTranslator", + ) -> "variables.ConstantVariable": from .constant import ConstantVariable from .dicts import ConstDictVariable @@ -2330,7 +2507,9 @@ class TritonKernelVariable(VariableTracker): kernel_idx: Optional[int] kernel_source: "AttrSource" - def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: + def __init__( + self, kernel: Any, kernel_idx: Optional[int], grid: Any, **kwargs: Any + ) -> None: self.kernel_source = kwargs.pop("kernel_source", None) super().__init__(**kwargs) dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) @@ -2338,24 +2517,24 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - return dynamo_triton_hopifier_singleton.call_triton_kernel( + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return dynamo_triton_hopifier_singleton.call_triton_kernel( # type: ignore[return-value] self, args, kwargs, tx ) def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": return dynamo_triton_hopifier_singleton.call_getitem(self, args) elif name == "run": - return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) + return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value] # Bail out to parent's implementation return super().call_method(tx, name, args, kwargs) @@ -2374,11 +2553,11 @@ class TMADescriptorExperimentalVariable(VariableTracker): def __init__( self, data_ptr: "variables.DataPtrVariable", - dims: "list[ConstantVariable]", - block_dims: "list[ConstantVariable]", - element_size: "ConstantVariable", - **kwargs, - ): + dims: list[VariableTracker], + block_dims: list[VariableTracker], + element_size: VariableTracker, + **kwargs: Any, + ) -> None: assert isinstance(data_ptr, variables.DataPtrVariable) super().__init__(**kwargs) self.data_ptr = data_ptr @@ -2386,14 +2565,14 @@ def __init__( self.block_dims = block_dims self.element_size = element_size - def to_metadata(self): + def to_metadata(self) -> Any: return create_tma_experimental_metadata( [dim.as_proxy() for dim in self.dims], [dim.as_proxy() for dim in 
self.block_dims], self.element_size.as_proxy(), ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.experimental_descriptor", @@ -2405,28 +2584,28 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.foreach(args) codegen.call_function(len(args) + 1, False) - def get_tensor(self): + def get_tensor(self) -> VariableTracker: return self.data_ptr.from_tensor class TMADescriptorStableVariable(VariableTracker): def __init__( self, - tensor: "variables.TensorVariable", - block_shape: "variables.ListVariable", - **kwargs, - ): + tensor: "TensorVariable", + block_shape: "ListVariable", + **kwargs: Any, + ) -> None: assert isinstance(tensor, variables.TensorVariable) super().__init__(**kwargs) self.tensor = tensor self.block_shape = block_shape - def to_metadata(self): + def to_metadata(self) -> Any: return create_tma_stable_metadata( self.block_shape.as_proxy(), ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.tensor_descriptor", @@ -2438,7 +2617,7 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.block_shape) codegen.call_method(2) - def get_tensor(self) -> "variables.TensorVariable": + def get_tensor(self) -> Any: return self.tensor @@ -2446,7 +2625,7 @@ class CreateTMADescriptorExperimentalVariable(VariableTracker): def __init__( self, rank: int, - **kwargs, + **kwargs: Any, ) -> None: assert rank in (1, 2) super().__init__(**kwargs) @@ -2455,9 +2634,9 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] if not isinstance(ptr, variables.DataPtrVariable): @@ -2507,13 +2686,13 @@ class CreateTMADescriptorStableVariable(VariableTracker): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] return TMADescriptorStableVariable( - tensor=tensor, - block_shape=block_shape, + tensor=tensor, # type: ignore[arg-type] + block_shape=block_shape, # type: ignore[arg-type] ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index bdb37da3ccce1..2b603a7af22a2 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -590,7 +590,7 @@ def _next() -> VariableTracker: else: res = self.fn.call_function(tx, [item], {}) pred_res = variables.UserFunctionVariable( - polyfills.predicate + polyfills.predicate # type: ignore[arg-type] ).call_function(tx, [res], {}) if pred_res.as_python_constant(): return item diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 30c1b8c2cf186..e17d27c16dca6 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -472,7 +472,12 @@ def call_function( ) elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined] name_to_arg_map = bind_args_cached( - self.value, tx, self.source, args, kwargs + # 
pyrefly: ignore[bad-argument-type] + self.value, + tx, + self.source, + args, + kwargs, ) backends = name_to_arg_map["backends"].as_python_constant() set_priority = name_to_arg_map["set_priority"].as_python_constant() @@ -1429,7 +1434,7 @@ def call_function( packed_input_vt = TupleVariable.build( tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) ) - out_vt = variables.UserFunctionVariable(tree_flatten).call_function( + out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type] tx, [packed_input_vt], {} ) assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 From 669cf21a6b13f956492088e9267107586f196849 Mon Sep 17 00:00:00 2001 From: ankushwahaRH Date: Fri, 7 Nov 2025 00:53:54 +0000 Subject: [PATCH 164/651] Added Validation for batch_norm eps value (#166756) Fixes #166405. I've fixed this by adding epsilon validation in ```torch.nn.functional.batch_norm``` to reject non-positive values before they cause undefined behavior. Also added a test case ```test_batchnorm_invalid_eps``` to verify the fix works correctly. While working on this, I noticed that ```layer_norm```, ```group_norm```, and ```instance_norm``` also don't validate epsilon and could have the same issue. Should I add validation for those in this PR as well? Pull Request resolved: https://github.com/pytorch/pytorch/pull/166756 Approved by: https://github.com/mikaylagawarecki --- test/test_mps.py | 6 ++-- torch/nn/functional.py | 3 ++ torch/testing/_internal/common_modules.py | 35 +++++++++++++++++++++++ 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/test/test_mps.py b/test/test_mps.py index cb0db4d96d334..867429432cfe0 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1762,13 +1762,13 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= continue # Running stats must be tracked in eval mode if (track_running_stats): - helper(shape, eps=0, momentum=1, channels_last=channels_last, + helper(shape, eps=1e-5, momentum=1, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) helper(shape, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) helper(shape, eps=1e-05, momentum=0.1, wts=False, training=False, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) - helper(shape, eps=0, momentum=1.0, wts=False, training=False, channels_last=channels_last, + helper(shape, eps=1e-5, momentum=1.0, wts=False, training=False, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) helper(shape, eps=1, momentum=1, wts=True, training=False, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) @@ -1776,7 +1776,7 @@ def helper(shape, eps=1, momentum=0.1, wts=False, training=False, channels_last= track_running_stats=track_running_stats, test_module=test_module) helper(shape, eps=1e-05, momentum=0.1, wts=False, training=True, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) - helper(shape, eps=0, momentum=1.0, wts=False, training=True, channels_last=channels_last, + helper(shape, eps=1e-5, momentum=1.0, wts=False, training=True, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) helper(shape, eps=1, momentum=1, wts=True, training=True, channels_last=channels_last, track_running_stats=track_running_stats, test_module=test_module) diff --git 
a/torch/nn/functional.py b/torch/nn/functional.py index bc1e873c428fb..f92d1fa0fa2dd 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2838,6 +2838,9 @@ def batch_norm( # pyrefly: ignore [bad-argument-type] _verify_batch_size(input.size()) + if eps <= 0.0: + raise ValueError(f"batch_norm eps must be positive, but got {eps}") + return torch.batch_norm( input, weight, diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 120a76eb5ef32..9571cc1209ed6 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -961,6 +961,38 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad desc='zero_batch')] +def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + if module_info.module_cls == torch.nn.BatchNorm1d: + input_shape = (2, 10) + elif module_info.module_cls == torch.nn.BatchNorm2d: + input_shape = (2, 10, 5, 5) + else: + input_shape = (2, 10, 4, 4, 4) + + return [ + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, eps=-1.0), + forward_input=FunctionInput(make_input(input_shape)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex="eps must be positive" + ), + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(10, eps=0.0), + forward_input=FunctionInput(make_input(input_shape)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=ValueError, + error_regex="eps must be positive" + ), + ] + + def module_inputs_torch_nn_ConvNd(module_info, device, dtype, requires_grad, training, **kwargs): N = kwargs['N'] lazy = kwargs.get('lazy', False) @@ -3430,6 +3462,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ModuleInfo(torch.nn.BatchNorm1d, train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm1d, + module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d, skips=( # tracking here rather than in the list in test_aotdispatch.py as eval mode passes # RuntimeError: tried to get Double out of SymInt @@ -3448,6 +3481,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ModuleInfo(torch.nn.BatchNorm2d, train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm2d, + module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d, skips=( # See https://github.com/pytorch/pytorch/issues/134580 DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format', active_if=operator.itemgetter('training')), @@ -3468,6 +3502,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ModuleInfo(torch.nn.BatchNorm3d, train_and_eval_differ=True, module_inputs_func=module_inputs_torch_nn_BatchNorm3d, + module_error_inputs_func=module_error_inputs_torch_nn_BatchNorm1d_2d_3d, skips=( # not supported on MPS backend DecorateInfo(skipMPS), From cd6d06a22ba3a90f6592da9ffecead7dae54b5f6 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 01:06:34 +0000 Subject: [PATCH 165/651] Revert "[BE][Typing][Dynamo] Type torch/_dynamo/variables/functions.py (#167103)" This reverts commit 9a86ef763201e27f031469f0866c893707e9cf38. 
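Note on the eps validation added in #166756 above: a minimal sketch of the expected behavior, assuming a build that contains the new check in `torch.nn.functional.batch_norm` (the exact message text is taken from the diff above).

```python
# Sketch of the behavior introduced by the eps check in torch.nn.functional.batch_norm.
import torch
import torch.nn.functional as F

x = torch.randn(2, 10)
running_mean = torch.zeros(10)
running_var = torch.ones(10)

# A positive eps is accepted as before.
out = F.batch_norm(x, running_mean, running_var, training=True, eps=1e-5)

# A non-positive eps now raises ValueError before reaching torch.batch_norm.
try:
    F.batch_norm(x, running_mean, running_var, training=True, eps=0.0)
except ValueError as e:
    print(e)  # batch_norm eps must be positive, but got 0.0
```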
Reverted https://github.com/pytorch/pytorch/pull/167103 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/167103#issuecomment-3500023910)) --- torch/_dynamo/variables/builtin.py | 2 +- torch/_dynamo/variables/functions.py | 773 ++++++++++----------------- torch/_dynamo/variables/iter.py | 2 +- torch/_dynamo/variables/torch.py | 9 +- 4 files changed, 301 insertions(+), 485 deletions(-) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 579cf7bfffc3d..0f198377605ec 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1991,7 +1991,7 @@ def call_iter( # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() # with an integer argument starting at 0, until __getitem__ raises IndexError ret = variables.UserFunctionVariable( - polyfills.builtins.iter_ # type: ignore[arg-type] + polyfills.builtins.iter_ ).call_function(tx, [obj, *args], {}) if args: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index c4865bfdedbfc..0752a413fce6e 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + """ Function-related variable tracking classes for Dynamo's symbolic execution. @@ -30,14 +32,13 @@ import traceback import types from collections.abc import Callable, Sequence -from types import CellType, FunctionType +from types import FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar from typing_extensions import Never from weakref import WeakKeyDictionary import torch from torch._dynamo.exc import get_stack_above_dynamo -from torch._guards import Source from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -86,32 +87,25 @@ try: from torch.distributed.fsdp._fully_shard import _fsdp_param_group except ModuleNotFoundError: - _fsdp_param_group = None # type: ignore[assignment] + _fsdp_param_group = None if TYPE_CHECKING: from torch._dynamo.codegen import PyCodegen - from torch._dynamo.symbolic_convert import ( - InstructionTranslator, - InstructionTranslatorBase, - ) - from torch._dynamo.variables.ctx_manager import ContextWrappingVariable + from torch._dynamo.symbolic_convert import InstructionTranslator from torch._higher_order_ops.triton_kernel_wrap import ( TritonGridType, TritonKernelType, ) - from .lists import BaseListVariable, ListVariable - from .tensor import TensorVariable - -_F = TypeVar("_F", bound=Callable[..., Any]) +_F = TypeVar("_F", bound=Callable) CO_VARARGS = 0x04 CO_VARKEYWORDS = 0x08 # Module-level cache keyed by the function object -_spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() +_spec_cache = WeakKeyDictionary() class FunctionSpec: @@ -133,7 +127,7 @@ def __init__(self, func: FunctionType): off += 1 if self.varargs_name else 0 self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None - def update_defaults(self, func: FunctionType) -> None: + def update_defaults(self, func: FunctionType): # Defaults can change from function call to function call. So re-update # them on every call. 
self.defaults = func.__defaults__ or () @@ -153,13 +147,7 @@ def _get_spec(func: FunctionType) -> FunctionSpec: return spec -def bind_args_cached( - func: FunctionType, - tx: "InstructionTranslator", - fn_source: Optional[Source], - args: Sequence[Any], - kwargs: dict[str, Any], -) -> dict[str, VariableTracker]: +def bind_args_cached(func, tx, fn_source, args, kwargs): spec = _get_spec(func) spec.update_defaults(func) ba = {} @@ -252,9 +240,7 @@ def bind_args_cached( return ba -def wrap_bound_arg( - tx: "InstructionTranslator", val: Any, source: Optional[Source] = None -) -> VariableTracker: +def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): # Source propagation is best effort since not every object we encounter has a source to begin with. if isinstance(val, VariableTracker): return val @@ -266,18 +252,14 @@ def wrap_bound_arg( return variables.LazyVariableTracker.create(val, source) -def wrap_args_kwargs(tx: "InstructionTranslator", result: dict[str, Any]) -> None: +def wrap_args_kwargs(tx: "InstructionTranslator", result): for k, v in list(result.items()): if isinstance(v, (tuple, dict)): # args/kwargs result[k] = wrap_bound_arg(tx, v) -def init_cellvars( - parent: "InstructionTranslator", - result: dict[str, VariableTracker], - code: types.CodeType, -) -> None: +def init_cellvars(parent, result: dict[str, VariableTracker], code): """ Update `result` to add mapping from local name to new cells created directly by `code`, or update SideEffects in `parent` if the a local cell is @@ -295,14 +277,8 @@ def init_cellvars( def _create_nested_fn( - code: types.CodeType, - f_globals: dict[str, Any], - name: str, - defaults: Optional[tuple[object, ...]], - closure: Optional[tuple[CellType]], - kwdefaults: Optional[dict[str, Any]], - annotations: Optional[dict[str, Any]], -) -> types.FunctionType: + code, f_globals, name, defaults, closure, kwdefaults, annotations +): from types import FunctionType func = FunctionType(code, f_globals, name, defaults, closure) @@ -315,7 +291,7 @@ def _create_nested_fn( # TypeError: __annotations__ must be set to a dict object assert annotations is None or isinstance(annotations, dict) - func.__annotations__ = annotations # type: ignore[assignment] + func.__annotations__ = annotations return func @@ -331,9 +307,7 @@ def _create_nested_fn( } -def fn_var_getattr( - tx: "InstructionTranslator", fn: object, source: Optional[Source], name: str -) -> VariableTracker: +def fn_var_getattr(tx, fn, source, name): source = source and AttrSource(source, name) if source and name == "__annotations__": @@ -342,7 +316,6 @@ def fn_var_getattr( # graph is even rarer. So skip guards. 
source = SkipGuardSource(source) - subobj = None try: subobj = inspect.getattr_static(fn, name) except AttributeError: @@ -359,19 +332,19 @@ def fn_var_getattr( class BaseUserFunctionVariable(VariableTracker): - def get_filename(self) -> str: - return self.get_code().co_filename # type: ignore[attr-defined] + def get_filename(self): + return self.get_code().co_filename - def get_name(self) -> str: - return self.get_code().co_name # type: ignore[attr-defined] + def get_name(self): + return self.get_code().co_name def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined] + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -379,16 +352,16 @@ def call_obj_hasattr( result = False try: - result = hasattr(self.get_function(), name) # type: ignore[attr-defined] + result = hasattr(self.get_function(), name) except NotImplementedError: if name == "__name__" and isinstance(self, NestedUserFunctionVariable): result = True return variables.ConstantVariable.create(result) - def inspect_parameter_names(self) -> list[str]: - return list(inspect.signature(self.get_function()).parameters) # type: ignore[attr-defined] + def inspect_parameter_names(self): + return list(inspect.signature(self.get_function()).parameters) - def closure_vars(self, tx: "InstructionTranslator") -> dict[str, VariableTracker]: + def closure_vars(self, tx): return {} @@ -402,16 +375,11 @@ class UserFunctionVariable(BaseUserFunctionVariable): } @classmethod - def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": + def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) - def __init__( - self, - fn: types.FunctionType | torch.jit.ScriptFunction, # type: ignore[type-arg] - is_constant: bool = False, - **kwargs: Any, - ) -> None: + def __init__(self, fn, is_constant=False, **kwargs) -> None: super().__init__(**kwargs) if getattr(fn, "_dynamo_marked_constant", False): # This method should be treated as a constant for the purposes of compilation @@ -435,45 +403,40 @@ def __init__( # VariableBuilder, which handles the wrapping of _torchdynamo_inline. 
# unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) - self.fn = fn + self.fn: types.FunctionType = fn - def as_python_constant(self) -> Any: + def as_python_constant(self): if istype(self, UserFunctionVariable): return self.fn # subclasses (such as methods) usually aren't a constant return super().as_python_constant() - def self_args(self) -> list[VariableTracker]: + def self_args(self): return [] - def get_function(self) -> types.FunctionType: + def get_function(self): return self.fn - def get_code(self) -> types.CodeType: + def get_code(self): return self.fn.__code__ - def python_type(self) -> type: + def python_type(self): return types.FunctionType - def has_self(self) -> bool: + def has_self(self): return getattr(self.fn, "__self__", None) is not None - def get_globals(self) -> dict[str, Any]: + def get_globals(self): return self.fn.__globals__ - def get_source(self) -> Source: + def get_source(self): source = self.source if source and isinstance(self, variables.UserMethodVariable): - source = self.source_fn # type: ignore[assignment] - return source # type: ignore[return-value] + source = self.source_fn + return source - def bind_args( - self, - parent: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> dict[str, VariableTracker]: + def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: """ Assume `args` and `kwargs` are VariableTracker arguments for a call to this function, create new bindings for initial locals. @@ -487,7 +450,7 @@ def bind_args( root_tx = parent.output.root_tx source = self.get_source() - result = bind_args_cached(fn, root_tx, source, args, kwargs) # type: ignore[arg-type] + result = bind_args_cached(fn, root_tx, source, args, kwargs) init_cellvars(parent, result, fn.__code__) closure = self.fn.__closure__ or () @@ -528,7 +491,7 @@ def bind_args( return result - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name: str): if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) source = self.get_source() @@ -543,9 +506,9 @@ def call_obj_hasattr( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": # Handle patch_dynamo_config call if self.fn is torch._dynamo.patch_dynamo_config: try: @@ -585,7 +548,7 @@ def call_function( msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" unimplemented_v2( gb_type="TypeError from user code", - context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined] + context=f"call_function({self.value}, {args}, {kwargs})", explanation=msg, hints=[ *graph_break_hints.USER_ERROR, @@ -604,7 +567,7 @@ def call_function( "`torch.compile` region", ], ) - # pyrefly: ignore[missing-attribute] + fn = fn_var.fn return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True) @@ -630,7 +593,7 @@ def call_function( try: from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState except Exception: - FSDPState = None # type: ignore[assignment, misc] + FSDPState = None if FSDPState is not None and self.fn in [ FSDPState._pre_forward, FSDPState._post_forward, @@ -641,15 +604,13 @@ def call_function( class 
BuiltinMethodVariable(BaseUserFunctionVariable): - def __init__( - self, fn: types.BuiltinMethodType, is_constant: bool = False, **kwargs: Any - ) -> None: + def __init__(self, fn, is_constant=False, **kwargs) -> None: super().__init__(**kwargs) assert isinstance(fn, types.BuiltinMethodType) self.fn = fn @staticmethod - def is_supported_builtin_method(obj: Any) -> bool: + def is_supported_builtin_method(obj): method_self = obj.__self__ method_name = obj.__name__ @@ -662,9 +623,9 @@ def is_supported_builtin_method(obj: Any) -> bool: def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": method_self = self.fn.__self__ name = self.fn.__name__ obj_source = self.source and AttrSource(self.source, "__self__") @@ -676,39 +637,39 @@ class LocalGeneratorObjectVariable(VariableTracker): def __init__( self, code: types.CodeType, - f_globals: dict[str, Any], + f_globals, inline_tracer: Optional["InstructionTranslator"], - **kwargs: Any, - ) -> None: + **kwargs, + ): super().__init__(**kwargs) self.code = code self.f_globals = f_globals self.inline_tracer = inline_tracer - def get_code(self) -> types.CodeType: + def get_code(self): return self.code - def get_filename(self) -> str: + def get_filename(self): return self.get_code().co_filename - def get_name(self) -> str: + def get_name(self): return self.get_code().co_name - def get_function(self) -> Never: + def get_function(self): raise NotImplementedError - def has_self(self) -> bool: + def has_self(self): return False - def __name__(self) -> str: + def __name__(self): return self.get_name() - def __str__(self) -> str: + def __str__(self): return f"{self.__class__.__name__}({self.get_name()})" __repr__ = __str__ - def reconstruct(self, codegen: "PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen"): from torch._dynamo.side_effects import disallow_side_effects_in_generator from torch._dynamo.symbolic_convert import ( InstructionTranslator, @@ -727,30 +688,25 @@ def reconstruct(self, codegen: "PyCodegen") -> None: self.remaining_items = self.force_unpack_var_sequence(tx) variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) - def bind_args( - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> dict[str, VariableTracker]: - return self.vt.bind_args(tx, args, kwargs) # type: ignore[attr-defined] + def bind_args(self, tx, args, kwargs): + return self.fn.bind_args(tx, args, kwargs) - def get_globals(self) -> dict[str, Any]: + def get_globals(self): return self.f_globals - def python_type(self) -> type: + def python_type(self): return types.GeneratorType - def _get_inline_tracer(self, tx: "InstructionTranslator") -> Any: + def _get_inline_tracer(self, tx): from torch._dynamo.symbolic_convert import InliningInstructionTranslator if self.inline_tracer is None: - self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( # type: ignore[assignment] + self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( tx, self, [], {} ) return self.inline_tracer - def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + def next_variable(self, tx): tracer = self._get_inline_tracer(tx) if self._is_generator_exhausted(): @@ -771,29 +727,23 @@ def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: 
torch._dynamo.eval_frame.skip_code(self.get_code()) raise SkipFrame from e - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + def call_obj_hasattr(self, tx, name): if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) - def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + def has_unpack_var_sequence(self, tx): return False - def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: + def has_force_unpack_var_sequence(self, tx) -> builtins.bool: return True - def force_unpack_var_sequence( - self, tx: "InstructionTranslator" - ) -> list[VariableTracker]: - result: list[VariableTracker] = [] + def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: + result = [] self.force_apply_to_var_sequence(tx, result.append) return result - def force_apply_to_var_sequence( - self, tx: "InstructionTranslator", fn: Callable[[VariableTracker], Any] - ) -> None: + def force_apply_to_var_sequence(self, tx, fn) -> None: while True: try: fn(self.next_variable(tx)) @@ -801,9 +751,7 @@ def force_apply_to_var_sequence( handle_observed_exception(tx) break - def _setup_exception( - self, tx: "InstructionTranslator", exc: VariableTracker - ) -> None: + def _setup_exception(self, tx, exc): tracer = self._get_inline_tracer(tx) try: tracer._raise_exception_variable(exc) @@ -812,19 +760,19 @@ def _setup_exception( # exception is raised again. tracer.exception_handler(e) - def _is_generator_just_started(self) -> bool: + def _is_generator_just_started(self): return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 - def _is_generator_exhausted(self) -> bool: + def _is_generator_exhausted(self): return getattr(self.inline_tracer, "generator_exhausted", False) def call_method( self, tx: "InstructionTranslator", name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if name == "__next__": return self.next_variable(tx) elif name == "__iter__": @@ -1004,7 +952,7 @@ def call_method( raise_observed_exception(RuntimeError, tracer) return retval - return super().call_method(tx, name, args, kwargs) + super().call_method(tx, name, args, kwargs) class ContextlibContextManagerLocalGeneratorObjectVariable( @@ -1032,24 +980,19 @@ def __init__( self, vt: VariableTracker, *, - generator_cls: type = LocalGeneratorObjectVariable, - **kwargs: Any, - ) -> None: + generator_cls=LocalGeneratorObjectVariable, + **kwargs, + ): super().__init__(**kwargs) self.vt = vt self.generator_cls = generator_cls - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name): if name in self.__class__.__dict__.keys(): return getattr(self, name) return getattr(self.vt, name) - def _build_inline_tracer( - self, - tx: "InstructionTranslatorBase", - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> "InstructionTranslatorBase": + def _build_inline_tracer(self, tx, args, kwargs): from torch._dynamo.symbolic_convert import InliningInstructionTranslator return InliningInstructionTranslator.build_inline_tracer( @@ -1062,13 +1005,13 @@ def _build_inline_tracer( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if not is_generator(self.vt.get_code()): # type: ignore[attr-defined] + args: "list[VariableTracker]", 
+ kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if not is_generator(self.vt.get_code()): unimplemented_v2( gb_type="non-generator contextlib.contextmanager", - context=str(self.vt.get_code()), # type: ignore[attr-defined] + context=str(self.vt.get_code()), explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" ", i.e. does not use `yield`", hints=[ @@ -1077,15 +1020,15 @@ def call_function( ], ) - inline_tracer = self._build_inline_tracer(tx, list(args), kwargs) - code = self.vt.get_code() # type: ignore[attr-defined] - f_globals = self.vt.get_globals() # type: ignore[attr-defined] + inline_tracer = self._build_inline_tracer(tx, args, kwargs) + code = self.vt.get_code() + f_globals = self.vt.get_globals() # calling a generator returns a generator object return self.generator_cls( code, f_globals, - inline_tracer, # type: ignore[arg-type] + inline_tracer, source=self.source, ) @@ -1099,19 +1042,14 @@ class FunctionDecoratedByContextlibContextManagerVariable( This is only used when the function is annotated with @contextlib.contextmanager """ - def __init__(self, vt: VariableTracker, **kwargs: Any): + def __init__(self, vt, **kwargs): super().__init__( vt, generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, **kwargs, ) - def _build_inline_tracer( - self, - tx: "InstructionTranslatorBase", - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> "InstructionTranslatorBase": + def _build_inline_tracer(self, tx, args, kwargs): # NOTE: This only exists to not break support for context manager when # config.enable_faithful_generator_behavior = False and # config.enable_trace_contextlib = True. In case the former is false, @@ -1128,14 +1066,8 @@ def _build_inline_tracer( class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" - def __init__( - self, - fn: Callable[..., Any], - obj: VariableTracker, - source_fn: Optional[Callable[..., Any]] = None, - **kwargs: Any, - ) -> None: - super().__init__(fn=fn, **kwargs) # type: ignore[arg-type] + def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: + super().__init__(fn=fn, **kwargs) self.obj = obj self.source_fn = source_fn # Note on source and source_fn @@ -1151,24 +1083,24 @@ def __init__( # operates on the unbound function, most guards should target # `source_fn` rather than the original `source`. if source_fn is None and kwargs.get("source") is not None: - self.source_fn = AttrSource(kwargs.get("source"), "__func__") # type: ignore[assignment, arg-type] + self.source_fn = AttrSource(kwargs.get("source"), "__func__") def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" - def self_args(self) -> list[VariableTracker]: + def self_args(self): return [self.obj] - def python_type(self) -> type[types.MethodType]: + def python_type(self): return types.MethodType def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - # NOTE this is to handle methods annotated by `nonstrict_trace`. + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + # NOTE this is to handle methods annotated by `nonstrict_trace`. 
Usually # a `nonstrict_trace`-ed function will be wrapped by # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`, # but in the case of method, we manually wrap it with `UserMethodVariable` @@ -1209,41 +1141,36 @@ def call_function( or self.is_constant ): return self.obj.call_method( - tx, self.fn.__name__, list(args), kwargs, constant=self.is_constant + tx, self.fn.__name__, args, kwargs, constant=self.is_constant ) elif ( _fsdp_param_group is not None - and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state # type: ignore[attr-defined] + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state ): return variables.TorchCtxManagerClassVariable(self.fn).call_function( tx, (self.obj, *args), kwargs ) if self.is_constant: - fn = getattr(self.obj.value, self.fn.__name__) # type: ignore[attr-defined] + fn = getattr(self.obj.value, self.fn.__name__) return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) return super().call_function(tx, args, kwargs) - def inspect_parameter_names(self) -> list[str]: + def inspect_parameter_names(self): return super().inspect_parameter_names()[1:] - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name: str): if name == "__self__": return self.obj if name == "__func__": # We might have a better way to access the function object, this # information is stored in self.source_fn, use that to construct the # variable tracker. - return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type] + return VariableTracker.build(tx, self.fn, self.source_fn) return super().var_getattr(tx, name) class WrappedUserMethodVariable(UserMethodVariable): - def __init__( - self, - wrapped: UserMethodVariable, - context: "ContextWrappingVariable", - **kwargs: Any, - ) -> None: + def __init__(self, wrapped, context, **kwargs) -> None: kwargs.pop("fn", None) kwargs.pop("obj", None) super().__init__(wrapped.fn, wrapped.obj, **kwargs) @@ -1253,27 +1180,22 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) class WrappedUserFunctionVariable(UserFunctionVariable): - def __init__( - self, - wrapped: UserFunctionVariable, - context: "ContextWrappingVariable", - **kwargs: Any, - ) -> None: + def __init__(self, wrapped, context, **kwargs) -> None: kwargs.pop("fn", None) super().__init__(wrapped.fn, **kwargs) self.wrapped = wrapped @@ -1282,28 +1204,22 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null(lambda: 
codegen(self.context)) # type: ignore[arg-type] + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) -def invoke_and_store_as_constant( - tx: "InstructionTranslator", - fn: Callable[..., Any], - name: str, - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], -) -> VariableTracker: - def convert(x: VariableTracker) -> Any: +def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs): + def convert(x): if isinstance(x, variables.TensorVariable): return x.get_real_value() return x.as_python_constant() @@ -1326,17 +1242,17 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): def __init__( self, - fn_name: VariableTracker, - code: VariableTracker, - f_globals: dict[str, Any], - defaults: Optional[VariableTracker], - kwdefaults: Optional[VariableTracker], - annotations: Optional[VariableTracker], - closure: Optional[VariableTracker], + fn_name, + code, + f_globals, + defaults, + kwdefaults, + annotations, + closure, # This is present when this function is created by # `functools.wrap(wrapped_fn)(this_fn)`. - wrapped_fn: Optional[VariableTracker] = None, - **kwargs: Any, + wrapped_fn=None, + **kwargs, ) -> None: if kwargs.get("mutation_type") is None: kwargs.update(mutation_type=AttributeMutationNew()) @@ -1353,16 +1269,16 @@ def __init__( self.closure = closure self.wrapped_fn: Optional[VariableTracker] = wrapped_fn - def self_args(self) -> list[VariableTracker]: + def self_args(self): return [] - def get_code(self) -> types.CodeType: + def get_code(self): return self.code.as_python_constant() - def python_type(self) -> type: + def python_type(self): return types.FunctionType - def get_function(self) -> types.FunctionType: + def get_function(self): if self.closure: raise NotImplementedError func = types.FunctionType( @@ -1391,25 +1307,19 @@ def call_setattr( tx: "InstructionTranslator", name_var: VariableTracker, val: VariableTracker, - ) -> VariableTracker: - tx.output.side_effects.store_attr(self, name_var.value, val) # type: ignore[attr-defined] + ): + tx.output.side_effects.store_attr(self, name_var.value, val) return ConstantVariable(None) - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + def call_method(self, tx, name, args, kwargs): if name == "__setattr__": return self.call_setattr(tx, *args) - return super().call_method(tx, name, list(args), kwargs) + return super().call_method(tx, name, args, kwargs) - def has_closure(self) -> bool: + def has_closure(self): return self.closure is not None - def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + def const_getattr(self, tx, name): if name == "__name__": return self.get_name() if name == "__code__": @@ -1419,57 +1329,50 @@ def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: return d.as_python_constant() if d else None return super().const_getattr(tx, name) - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + def call_obj_hasattr(self, tx: "InstructionTranslator", name): if name == "__code__": return variables.ConstantVariable.create(hasattr(self, "code")) if name == "__defaults__": return variables.ConstantVariable.create(hasattr(self, "defaults")) return super().call_obj_hasattr(tx, name) - def has_self(self) -> bool: + def has_self(self): return False - def 
get_globals(self) -> dict[str, Any]: + def get_globals(self): return self.f_globals - def bind_args( - self, - parent: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> dict[str, VariableTracker]: + def bind_args(self, parent, args, kwargs): code = self.get_code() func = types.FunctionType( code, self.f_globals, self.fn_name.as_python_constant(), - tuple(self.defaults.items) if self.defaults else None, # type: ignore[attr-defined] + tuple(self.defaults.items) if self.defaults else None, tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), ) if self.kwdefaults: - func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() # type: ignore[attr-defined] + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() bound = inspect.signature(func).bind(*args, **kwargs) bound.apply_defaults() result = dict(bound.arguments.items()) - wrap_args_kwargs(parent.output.root_tx, result) # type: ignore[arg-type] + wrap_args_kwargs(parent.output.root_tx, result) init_cellvars(parent, result, code) for idx, name in enumerate(code.co_freevars): assert name not in result - cell = self.closure.items[idx] # type: ignore[attr-defined, union-attr] + cell = self.closure.items[idx] result[name] = cell return result - def reconstruct(self, codegen: "PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(__name__, "_create_nested_fn") ) codegen(self.code) codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)]) - codegen(ConstantVariable.create(self.code.value.co_name)) # type: ignore[attr-defined] + codegen(ConstantVariable.create(self.code.value.co_name)) if self.defaults: codegen(self.defaults) @@ -1523,12 +1426,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): - def __init__( - self, - wrapped: Any, - context: "ContextWrappingVariable", - **kwargs: Any, - ) -> None: + def __init__(self, wrapped, context, **kwargs) -> None: kwargs.pop("fn_name", None) kwargs.pop("code", None) kwargs.pop("f_globals", None) @@ -1553,16 +1451,16 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) @@ -1574,16 +1472,16 @@ class SkipFunctionVariable(VariableTracker): *VariableTracker._nonvar_fields, } - def __init__(self, value: Any, reason: Optional[str] = None, **kwargs: Any) -> None: + def __init__(self, value, reason=None, **kwargs) -> None: super().__init__(**kwargs) self.value = value self.reason = reason - def as_python_constant(self) -> Any: + def as_python_constant(self): return self.value @classmethod - def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": + def create_with_source(cls, value, source): # Use closure match guard (i.e. guard on __code__ object instead of # function id) to avoid guarding on nested functions. 
if inspect.getattr_static(value, "_torchdynamo_disable", False): @@ -1612,9 +1510,9 @@ def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if inspect.getattr_static(self.value, "_torchdynamo_disable", False): msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) unimplemented_v2( @@ -1627,7 +1525,7 @@ def call_function( ], ) elif self.value is torch._dynamo.graph_break: - graph_break_msg = kwargs.get("msg") + graph_break_msg = kwargs.get("msg", None) if graph_break_msg: graph_break_msg = graph_break_msg.as_python_constant() unimplemented_v2( @@ -1639,7 +1537,7 @@ def call_function( ], ) elif self.value is torch._dynamo.skip_frame: - skip_frame_msg = kwargs.get("msg") + skip_frame_msg = kwargs.get("msg", None) if skip_frame_msg: skip_frame_msg = skip_frame_msg.as_python_constant() raise SkipFrame( @@ -1731,12 +1629,10 @@ def call_function( hints=hints, ) - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> VariableTracker: + def call_obj_hasattr(self, tx: "InstructionTranslator", name): return variables.ConstantVariable.create(hasattr(self.value, name)) - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name: str): if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) @@ -1744,31 +1640,26 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker class WrappedSkipFunctionVariable(SkipFunctionVariable): - def __init__( - self, - wrapped: VariableTracker, - context: "ContextWrappingVariable", - **kwargs: Any, - ) -> None: + def __init__(self, wrapped, context, **kwargs) -> None: kwargs.pop("value", None) kwargs.pop("reason", None) - super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) # type: ignore[attr-defined] + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) self.wrapped = wrapped self.context = context def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) @@ -1781,12 +1672,12 @@ class WrapperUserFunctionVariable(VariableTracker): __script_if_tracing_wrapper have the original attr at "__original_fn". 
""" - def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None: + def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None: super().__init__(**kwargs) self.wrapper_obj = wrapper_obj self.attr_to_trace = attr_to_trace - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name): if name == self.attr_to_trace: val = getattr(self.wrapper_obj, self.attr_to_trace) source = self.source and AttrSource(self.source, name) @@ -1794,15 +1685,15 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return super().var_getattr(tx, name) - def self_args(self) -> list[VariableTracker]: + def self_args(self): return [] def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if hasattr(self.wrapper_obj, "cache_info"): target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) module_name = getattr(target_fn, "__module__", "") or "" @@ -1828,9 +1719,9 @@ def call_function( user_stack_trace += str(user_stack_formatted) dynamo_logger.debug(user_stack_trace) - all_args = self.self_args() + list(args) + all_args = self.self_args() + args return variables.UserFunctionVariable( - polyfills.getattr_and_trace # type: ignore[arg-type] + polyfills.getattr_and_trace ).call_function( tx, [self, variables.ConstantVariable(self.attr_to_trace), *all_args], @@ -1845,21 +1736,15 @@ class WrapperUserMethodVariable(WrapperUserFunctionVariable): WrapperUserFunctionVariable in `call_function` method. """ - def __init__( - self, - wrapper_obj: Any, - attr_to_trace: str, - self_obj: VariableTracker, - **kwargs: Any, - ) -> None: + def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None: super().__init__(wrapper_obj, attr_to_trace, **kwargs) self.obj = self_obj - def self_args(self) -> list[VariableTracker]: + def self_args(self): return [self.obj] -def _traceable_collective_remaps() -> dict[Any, Any]: +def _traceable_collective_remaps(): # We can't rely on importing from distributed, since it's not always built if torch.distributed.is_available(): from torch.distributed._functional_collectives import ( @@ -1870,9 +1755,7 @@ def _traceable_collective_remaps() -> dict[Any, Any]: return {} -def _traceable_collectives_source( - tx: "InstructionTranslator", fn: Callable[..., Any] -) -> AttrSource: +def _traceable_collectives_source(tx: "InstructionTranslator", fn): assert torch.distributed.is_available(), "Illegal invocation." assert fn in _traceable_collective_remaps().values() @@ -1892,24 +1775,13 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable): than status-quo as we currently graph-break on all distributed.* collectives. 
""" - def __init__( - self, - fn: Callable[..., Any], - *, - replacement_var: UserFunctionVariable, - **kwargs: Any, - ) -> None: - super().__init__(fn, **kwargs) # type: ignore[arg-type] + def __init__(self, fn, *, replacement_var, **kwargs) -> None: + super().__init__(fn, **kwargs) assert isinstance(replacement_var, UserFunctionVariable) self.replacement_var = replacement_var @staticmethod - def create( - tx: "InstructionTranslator", - old_fn: Callable[..., Any], - source: Source, - **options: Any, - ) -> "CollectiveFunctionRewriteVariable": + def create(tx: "InstructionTranslator", old_fn, source, **options): new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) return CollectiveFunctionRewriteVariable( old_fn, @@ -1919,24 +1791,22 @@ def create( ) @staticmethod - def can_rewrite(variable: Any) -> bool: + def can_rewrite(variable): return ( inspect.isfunction(variable) and variable in _traceable_collective_remaps() ) @staticmethod - def rewrite( - tx: "InstructionTranslator", fn: Callable[..., Any] - ) -> tuple[Any, AttrSource]: + def rewrite(tx: "InstructionTranslator", fn): new_fn = _traceable_collective_remaps()[fn] return new_fn, _traceable_collectives_source(tx, new_fn) def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": # call_function must check any unsupported arguments and graph-break. # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, # since that's the contract for putting a mapping in `traceable_collective_remaps` @@ -1966,7 +1836,7 @@ def call_function( ): reduce_op_var = kwargs.get("op") reduce_op = ( - reduce_op_var.value # type: ignore[attr-defined] + reduce_op_var.value if reduce_op_var is not None else signature.parameters["op"].default ) @@ -1982,12 +1852,12 @@ class FunctoolsWrapsVariable(UserFunctionVariable): def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if not kwargs and len(args) == 1: - def wraps(fn: Any) -> VariableTracker: + def wraps(fn): if isinstance(fn, variables.NestedUserFunctionVariable): return fn.clone(wrapped_fn=args[0]) unimplemented_v2( @@ -2005,15 +1875,15 @@ def wraps(fn: Any) -> VariableTracker: class CollectionsNamedTupleFunction(UserFunctionVariable): - def as_python_constant(self) -> Any: + def as_python_constant(self): return self.fn def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": constant_args = check_constant_args(args, kwargs) if constant_args: try: @@ -2028,9 +1898,7 @@ def call_function( args=list(map(ConstantVariable.create, exc.args)), ) return variables.UserDefinedClassVariable( - # pyrefly: ignore[unbound-name] - value, - mutation_type=ValueMutationNew(), + value, mutation_type=ValueMutationNew() ) unimplemented_v2( gb_type="namedtuple construction", @@ -2043,13 +1911,7 @@ def call_function( class FunctoolsPartialVariable(VariableTracker): - def __init__( - self, - func: VariableTracker, - args: Sequence[VariableTracker], - keywords: dict[str, VariableTracker], - **kwargs: Any, - ) -> 
None: + def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: super().__init__(**kwargs) self.func = func assert isinstance(args, list) @@ -2060,10 +1922,10 @@ def __init__( # on it is sufficient for the tracing purposes. self.fake_value = functools.partial(identity) - def python_type(self) -> type: + def python_type(self): return functools.partial - def reconstruct(self, codegen: "PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: @@ -2078,16 +1940,16 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) ) - def get_function(self) -> Any: + def get_function(self): return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - merged_args = self.args + list(args) + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + merged_args = self.args + args merged_kwargs = {**self.keywords, **kwargs} return self.func.call_function(tx, merged_args, merged_kwargs) @@ -2099,7 +1961,7 @@ def call_obj_hasattr( hasattr(functools.partial(identity), name) ) - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name: str): source = self.source and AttrSource(self.source, name) # Handle __slots__ if name == "func": @@ -2113,14 +1975,14 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return variables.GetAttrVariable(self, name) raise_observed_exception(AttributeError, tx) - def as_python_constant(self) -> Any: + def as_python_constant(self): return functools.partial( self.func.as_python_constant(), *[arg.as_python_constant() for arg in self.args], **{k: v.as_python_constant() for k, v in self.keywords.items()}, ) - def guard_as_python_constant(self) -> Any: + def guard_as_python_constant(self): """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" return functools.partial( self.func.guard_as_python_constant(), @@ -2143,20 +2005,16 @@ def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: return {} @classmethod - def create_with_source( - cls, value: Any, source: Source - ) -> "PolyfilledFunctionVariable": + def create_with_source(cls, value, source): install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) - def __init__(self, fn: _F, **kwargs: Any) -> None: + def __init__(self, fn: _F, **kwargs) -> None: super().__init__(**kwargs) - # pyrefly: ignore[invalid-type-var] self.fn: _F = fn handler = self._get_polyfill_handlers().get(fn, fn) - traceable_fn = None assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" for candidate_attr in ( "__torch_dynamo_polyfill__", # registered polyfill @@ -2171,29 +2029,28 @@ def __init__(self, fn: _F, **kwargs: Any) -> None: raise RuntimeError( f"Polyfill handler {handler} does not have a traceable function" ) - # pyrefly: ignore[invalid-type-var] - self.wrapped_fn = handler - # pyrefly: ignore[invalid-type-var] + + self.wrapped_fn: _F = handler self.traceable_fn: _F = traceable_fn @property - def polyfill_fn(self) -> Callable[..., Any]: + def polyfill_fn(self) -> _F: return self.traceable_fn - def 
can_constant_fold_through(self) -> bool: + def can_constant_fold_through(self): return getattr( self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False ) - def get_function(self) -> Any: + def get_function(self): return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -2230,7 +2087,7 @@ def call_function( ( x.value if isinstance(x, variables.ConstantVariable) - else x.sym_num # type: ignore[attr-defined] + else x.sym_num ) for x in args[0].items ] @@ -2242,11 +2099,11 @@ def call_function( def call_method( self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if name == "__call__": return self.call_function(tx, args, kwargs) @@ -2256,33 +2113,27 @@ def call_method( options = {} if self.source: options["source"] = AttrSource(self.source, name) - # pyrefly: ignore[bad-specialization] polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) return polyfilled_method_variable.call_function(tx, args, kwargs) - def as_python_constant(self) -> Any: + def as_python_constant(self): return self.fn class TracebackVariable(VariableTracker): # We don't track traceback. A call to any function in this module is a no-op - def call_function( # type: ignore[empty-body] - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: ... + def call_function(self, tx, args, kwargs): ... 
class SysFunctionVariable(VariableTracker): - def __init__(self, value: Any, **kwargs: Any) -> None: + def __init__(self, value, **kwargs): super().__init__(**kwargs) self.value = value - def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": + def exc_info(self, tx): if len(tx.exn_vt_stack): exn = tx.exn_vt_stack[-1] - typ = exn.exc_type # type: ignore[union-attr] + typ = exn.exc_type tb = None items = [ VariableTracker.build(tx, typ), @@ -2295,17 +2146,12 @@ def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": variables.ConstantVariable(None), variables.ConstantVariable(None), ] - return variables.TupleVariable(items) # type: ignore[arg-type] + return variables.TupleVariable(items) - def exception(self, tx: "InstructionTranslator") -> VariableTracker: + def exception(self, tx): return self.exc_info(tx).items[1] - def call_function( - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + def call_function(self, tx, args, kwargs): if self.value is sys.exc_info: return self.exc_info(tx) assert self.value is sys.exception @@ -2324,15 +2170,15 @@ class DynamoTritonHOPifier(TritonHOPifier): def raise_unsupported(self, msg: str) -> Never: raise Unsupported(msg) - def is_callable(self, maybe_callable: VariableTracker) -> bool: + def is_callable(self, maybe_callable: Any) -> bool: return isinstance( maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) ) - def get_value(self, val: VariableTracker) -> Any: - return val.value # type: ignore[attr-defined] + def get_value(self, val: Any) -> Any: + return val.value - def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, ...]: + def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]: from .lists import BaseListVariable if isinstance(grid, BaseListVariable): @@ -2347,35 +2193,20 @@ def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, .. 
], ) - def call_grid( - self, grid: Any, meta: dict[str, Any], tx: "InstructionTranslator" - ) -> Any: - meta_var = {variables.ConstantVariable.create(k): v for k, v in meta.items()} - grid = grid.call_function(tx, [meta_var], {}) + def call_grid(self, grid, meta, tx): + meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()} + grid = grid.call_function(tx, [meta], {}) return grid # We use this function to wrap call_prune_configs - def call_user_defined_fn( - self, - user_fn: Callable[..., Any], - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - tx: Optional["InstructionTranslator"], - variable: Any, - ) -> VariableTracker: + def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable): from .builder import SourcelessBuilder - wrapped_user_function = SourcelessBuilder.create(tx, user_fn) # type: ignore[arg-type] + wrapped_user_function = SourcelessBuilder.create(tx, user_fn) result = wrapped_user_function.call_function(tx, args, kwargs) return result - def wrap_user_defined_obj( - self, - user_obj: Any, - tx: Optional["InstructionTranslator"], - variable: Any, - name: str, - ) -> VariableTracker: + def wrap_user_defined_obj(self, user_obj, tx, variable, name): from .builder import VariableBuilder wrapped_user_obj = VariableBuilder( @@ -2383,9 +2214,7 @@ def wrap_user_defined_obj( )._wrap(user_obj) return wrapped_user_obj - def maybe_unpack_configs( - self, configs: Any, tx: Optional["InstructionTranslator"] - ) -> list[Any]: + def maybe_unpack_configs(self, configs, tx): # unpack the list of configs configs = configs.unpack_var_sequence(tx) @@ -2394,7 +2223,7 @@ def maybe_unpack_configs( return configs - def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any: + def maybe_unpack_heuristic_result(self, result: Any) -> Any: if not result.is_python_constant(): self.raise_unsupported( "@triton.heuristics must return constant values because configs can only contain constant values." 
@@ -2404,7 +2233,7 @@ def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any: # We need to override call_getitem here so that we can add the source in the case # where we call the triton kernel with a grid - def call_getitem( # type: ignore[override] + def call_getitem( self, variable: "TritonKernelVariable", args: Sequence[Any], @@ -2422,13 +2251,7 @@ def call_getitem( # type: ignore[override] kernel_source=variable.source, ) - def call_HOP( - self, - variable: "TritonKernelVariable", - grids: Any, - combined_args_raw: dict[str, Any], - tx: "InstructionTranslator", - ) -> "variables.ConstantVariable": + def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable: from .constant import ConstantVariable from .dicts import ConstDictVariable @@ -2507,9 +2330,7 @@ class TritonKernelVariable(VariableTracker): kernel_idx: Optional[int] kernel_source: "AttrSource" - def __init__( - self, kernel: Any, kernel_idx: Optional[int], grid: Any, **kwargs: Any - ) -> None: + def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: self.kernel_source = kwargs.pop("kernel_source", None) super().__init__(**kwargs) dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) @@ -2517,24 +2338,24 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - return dynamo_triton_hopifier_singleton.call_triton_kernel( # type: ignore[return-value] + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + return dynamo_triton_hopifier_singleton.call_triton_kernel( self, args, kwargs, tx ) def call_method( self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": if name == "__getitem__": return dynamo_triton_hopifier_singleton.call_getitem(self, args) elif name == "run": - return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value] + return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # Bail out to parent's implementation return super().call_method(tx, name, args, kwargs) @@ -2553,11 +2374,11 @@ class TMADescriptorExperimentalVariable(VariableTracker): def __init__( self, data_ptr: "variables.DataPtrVariable", - dims: list[VariableTracker], - block_dims: list[VariableTracker], - element_size: VariableTracker, - **kwargs: Any, - ) -> None: + dims: "list[ConstantVariable]", + block_dims: "list[ConstantVariable]", + element_size: "ConstantVariable", + **kwargs, + ): assert isinstance(data_ptr, variables.DataPtrVariable) super().__init__(**kwargs) self.data_ptr = data_ptr @@ -2565,14 +2386,14 @@ def __init__( self.block_dims = block_dims self.element_size = element_size - def to_metadata(self) -> Any: + def to_metadata(self): return create_tma_experimental_metadata( [dim.as_proxy() for dim in self.dims], [dim.as_proxy() for dim in self.block_dims], self.element_size.as_proxy(), ) - def reconstruct(self, codegen: "PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.experimental_descriptor", @@ -2584,28 +2405,28 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(args) codegen.call_function(len(args) + 1, False) - def get_tensor(self) -> 
VariableTracker: + def get_tensor(self): return self.data_ptr.from_tensor class TMADescriptorStableVariable(VariableTracker): def __init__( self, - tensor: "TensorVariable", - block_shape: "ListVariable", - **kwargs: Any, - ) -> None: + tensor: "variables.TensorVariable", + block_shape: "variables.ListVariable", + **kwargs, + ): assert isinstance(tensor, variables.TensorVariable) super().__init__(**kwargs) self.tensor = tensor self.block_shape = block_shape - def to_metadata(self) -> Any: + def to_metadata(self): return create_tma_stable_metadata( self.block_shape.as_proxy(), ) - def reconstruct(self, codegen: "PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.tensor_descriptor", @@ -2617,7 +2438,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.block_shape) codegen.call_method(2) - def get_tensor(self) -> Any: + def get_tensor(self) -> "variables.TensorVariable": return self.tensor @@ -2625,7 +2446,7 @@ class CreateTMADescriptorExperimentalVariable(VariableTracker): def __init__( self, rank: int, - **kwargs: Any, + **kwargs, ) -> None: assert rank in (1, 2) super().__init__(**kwargs) @@ -2634,9 +2455,9 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] if not isinstance(ptr, variables.DataPtrVariable): @@ -2686,13 +2507,13 @@ class CreateTMADescriptorStableVariable(VariableTracker): def call_function( self, tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] return TMADescriptorStableVariable( - tensor=tensor, # type: ignore[arg-type] - block_shape=block_shape, # type: ignore[arg-type] + tensor=tensor, + block_shape=block_shape, ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 2b603a7af22a2..bdb37da3ccce1 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -590,7 +590,7 @@ def _next() -> VariableTracker: else: res = self.fn.call_function(tx, [item], {}) pred_res = variables.UserFunctionVariable( - polyfills.predicate # type: ignore[arg-type] + polyfills.predicate ).call_function(tx, [res], {}) if pred_res.as_python_constant(): return item diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e17d27c16dca6..30c1b8c2cf186 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -472,12 +472,7 @@ def call_function( ) elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined] name_to_arg_map = bind_args_cached( - # pyrefly: ignore[bad-argument-type] - self.value, - tx, - self.source, - args, - kwargs, + self.value, tx, self.source, args, kwargs ) backends = name_to_arg_map["backends"].as_python_constant() set_priority = name_to_arg_map["set_priority"].as_python_constant() @@ -1434,7 +1429,7 @@ def call_function( packed_input_vt = TupleVariable.build( tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) ) - out_vt = 
variables.UserFunctionVariable(tree_flatten).call_function(  # type: ignore[arg-type]
+        out_vt = variables.UserFunctionVariable(tree_flatten).call_function(
             tx, [packed_input_vt], {}
         )
         assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2

From 5b36e4e30fd97dac6ffcc0fc16737f9feee14cce Mon Sep 17 00:00:00 2001
From: Pearu Peterson
Date: Thu, 6 Nov 2025 22:06:54 +0200
Subject: [PATCH 166/651] Move AT_DISPATCH_V2 helper macros to headeronly and add THO_DISPATCH_V2_TMPL (#165856)

Problem: the migration of `AT_DISPATCH_V2` macros to headeronly cannot be a simple copy-paste of macro definitions from one header file to another, because the macros `AT_DISPATCH_SWITCH` and `AT_DISPATCH_CASE` may use functions that cannot be migrated to headeronly; e.g. when a selective build feature is enabled, there will be functions that are generated. On the other hand, when not using selective build, the dtype-dispatch macros are perfectly suitable for migrating to headeronly.

In this PR, the migration problem above is tackled by refactoring the `AT_DISPATCH`-related macros into headeronly macros and non-headeronly macros while preserving the current API and semantics. For instance, consider the current V2 macro definitions:

```c++
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))

#define AT_AP_VAR(N, T, ...) \
  AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))

#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
...
```

where the parts that block headeronly migration are the uses of the AT_DISPATCH_SWITCH and AT_DISPATCH_CASE macros (defined in ATen/Dispatch.h). In this PR, we introduce parametric versions of the `AT_DISPATCH_V2` and `AT_AP1` macros that have a `_TMPL` suffix, take DISPATCH_SWITCH and DISPATCH_CASE arguments, and are defined in `torch/headeronly/core/Dispatch_v2.h`:

```c++
#define THO_DISPATCH_V2_TMPL( \
    DISPATCH_SWITCH, DISPATCH_CASE, TYPE, NAME, BODY, ...) \
  DISPATCH_SWITCH( \
      TYPE, \
      NAME, \
      THO_AP_VAR_TMPL(DISPATCH_CASE, AT_WRAP(BODY), TYPE, __VA_ARGS__))

#define THO_AP_VAR_TMPL(C, N, T, ...) \
  AT_EXPAND( \
      AT_CONCAT(THO_AP, AT_NUM_ARGS(__VA_ARGS__))(C, AT_WRAP(N), __VA_ARGS__))

#define THO_AP1(C, N, _1) C(_1, N)
...
```

so that the original V2 macro definition, defined in ATen/Dispatch_v2.h, becomes:

```c++
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
  THO_DISPATCH_V2_TMPL( \
      AT_DISPATCH_SWITCH, \
      AT_DISPATCH_CASE, \
      TYPE, \
      NAME, \
      AT_WRAP(BODY), \
      __VA_ARGS__)
```

which has exactly the same API and semantics as the original definition.

Note 1: ~we have changed the definition of `AT_AP1(N, _1) ...` to `AT_AP1(C, N, _1) ...` without renaming `AT_AP1` because `AT_AP1` is a helper macro that is not a part of public API (for instance, nothing in pytorch explicitly uses `AT_AP1`).~ UPDATE: restored the original `AT_AP` macros and introduced new `THO_AP` macros.

Note 2: this PR introduces a new API macro, THO_DISPATCH_V2_TMPL, that will be available to stable ABI users, who can use it by providing custom versions of the `AT_DISPATCH_SWITCH` and `AT_DISPATCH_CASE` macros, say, with selective build features removed.
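For illustration, a minimal sketch of how a client could instantiate the `_TMPL` macros with its own prelude and error-check hooks (modeled on the `test_dispatch.cpp` test added below; the `MY_*` names are hypothetical and not part of the PyTorch API):

```c++
// Minimal sketch (hypothetical MY_* names): plug custom hooks into the
// headeronly _TMPL macros. Assumes torch/headeronly/core/Dispatch.h and
// torch/headeronly/core/Dispatch_v2.h are included (and c10/util/Exception.h
// for TORCH_CHECK_NOT_IMPLEMENTED).

// Case hook: no selective-build check before each case label.
#define MY_CHECK_SELECTIVE_BUILD(enum_type) /* empty */
#define MY_CASE_TYPE_USING_HINT(...) \
  THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(MY_CHECK_SELECTIVE_BUILD, __VA_ARGS__)
#define MY_DISPATCH_CASE(...) \
  THO_DISPATCH_CASE_TMPL(MY_CASE_TYPE_USING_HINT, __VA_ARGS__)

// Switch hooks: no kernel-dtype recording; report unhandled dtypes via
// TORCH_CHECK_NOT_IMPLEMENTED (a stable-ABI build could substitute its own check).
#define MY_RECORD_KERNEL_FUNCTION_DTYPE(NAME, ENUMTYPE) (void)NAME
#define MY_DISPATCH_SWITCH(...) \
  THO_DISPATCH_SWITCH_TMPL( \
      MY_RECORD_KERNEL_FUNCTION_DTYPE, TORCH_CHECK_NOT_IMPLEMENTED, __VA_ARGS__)

// Resulting dispatcher with the same call signature as AT_DISPATCH_V2:
#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \
  THO_DISPATCH_V2_TMPL( \
      MY_DISPATCH_SWITCH, MY_DISPATCH_CASE, TYPE, NAME, AT_WRAP(BODY), __VA_ARGS__)

// Example call site: BODY is a callable and scalar_t is defined inside it, e.g.
//   MY_DISPATCH_V2(self.scalar_type(), "my_op",
//                  AT_WRAP([&] { my_kernel<scalar_t>(self); }),
//                  AT_EXPAND(AT_ALL_TYPES), at::kHalf);
```

The diff below wires AT_DISPATCH_SWITCH and AT_DISPATCH_CASE through these same `_TMPL` entry points, so existing AT_DISPATCH_V2 call sites are unaffected.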
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165856 Approved by: https://github.com/janeyx99 --- aten/src/ATen/Dispatch.h | 46 ++----- aten/src/ATen/Dispatch_v2.h | 53 ++----- test/cpp/aoti_abi_check/CMakeLists.txt | 1 + test/cpp/aoti_abi_check/test_dispatch.cpp | 82 +++++++++++ torch/header_only_apis.txt | 17 +++ torch/headeronly/core/Dispatch.h | 51 +++++++ torch/headeronly/core/Dispatch_v2.h | 160 ++++++++++++++++++++++ 7 files changed, 336 insertions(+), 74 deletions(-) create mode 100644 test/cpp/aoti_abi_check/test_dispatch.cpp create mode 100644 torch/headeronly/core/Dispatch.h create mode 100644 torch/headeronly/core/Dispatch_v2.h diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 40ad61cbd6455..870f7172d1622 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -6,6 +6,7 @@ #include #include #include +#include #ifdef __CUDACC__ #include // For CUDA_VERSION @@ -61,12 +62,9 @@ TORCH_API void record_kernel_function_dtype(std::string name); } \ } while (0) -#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ - case enum_type: { \ - AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ - using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT; \ - return __VA_ARGS__(); \ - } +#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ + THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__) #define AT_DISPATCH_CASE(enum_type, ...) \ AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) @@ -95,14 +93,6 @@ TORCH_API void record_kernel_function_dtype(std::string name); return __VA_ARGS__(); \ } -namespace detail { - -inline at::ScalarType scalar_type(at::ScalarType s) { - return s; -} - -} // namespace detail - // The AT_DISPATCH_* family of macros provides the ability to // conveniently generate specializations of a kernel over all of the // dtypes we care about in PyTorch. We call it "dispatch" because @@ -190,27 +180,13 @@ inline at::ScalarType scalar_type(at::ScalarType s) { // but we're just being safe (and it doesn't hurt.) Note we must // use it to shut up warnings about unused store. -#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \ - [&] { \ - const auto& the_type = TYPE; \ - constexpr const char* at_dispatch_name = NAME; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - at::ScalarType _st = ::detail::scalar_type(the_type); \ - RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \ - C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \ - switch (_st) { \ - __VA_ARGS__ \ - default: \ - TORCH_CHECK_NOT_IMPLEMENTED( \ - false, \ - '"', \ - at_dispatch_name, \ - "\" not implemented for '", \ - toString(_st), \ - "'"); \ - } \ - C10_DIAGNOSTIC_POP() \ - }() +#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH_TMPL( \ + RECORD_KERNEL_FUNCTION_DTYPE, \ + TORCH_CHECK_NOT_IMPLEMENTED, \ + TYPE, \ + NAME, \ + __VA_ARGS__) #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ diff --git a/aten/src/ATen/Dispatch_v2.h b/aten/src/ATen/Dispatch_v2.h index d0b77220faef2..fbeb48d45e32a 100644 --- a/aten/src/ATen/Dispatch_v2.h +++ b/aten/src/ATen/Dispatch_v2.h @@ -1,3 +1,8 @@ +#pragma once + +#include + +// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE: #include // This is a new implementation of the AT_DISPATCH macro family from @@ -74,41 +79,19 @@ // macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly // relied on GPT4 to help me get it right. 
-// Public API macros - // See documentation above #define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__)) - -// This macro lets you pass an arbitrary expression that may contain internal -// commas to another macro without having the commas causing the expression -// to be interpreted as being multiple arguments -#define AT_WRAP(...) __VA_ARGS__ - -#define AT_FLOAT8_TYPES \ - c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \ - c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu - -#define AT_INTEGRAL_TYPES \ - c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort -#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat -#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64 -#define AT_INTEGRAL_TYPES_V2 \ - AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) -#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat -#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32 -// NB: not *actually* all types -#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES) -#define AT_ALL_TYPES_AND_COMPLEX \ - AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES) - -// Helper macros - + THO_DISPATCH_V2_TMPL( \ + AT_DISPATCH_SWITCH, \ + AT_DISPATCH_CASE, \ + TYPE, \ + NAME, \ + AT_WRAP(BODY), \ + __VA_ARGS__) + +// Unused helper macros, kept for BC: #define AT_AP_VAR(N, T, ...) \ AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__)) -#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b) -#define AT_CONCAT_AUX(a, b) a##b -#define AT_EXPAND(X) X // Ensure we never have too many scalar types for the expansion here to // support. To bump this, you must regenerate the macros below. @@ -119,12 +102,6 @@ static_assert(static_cast(c10::ScalarType::NumOptions) < 60); num_args = 60 -nums = ', '.join(str(i) for i in reversed(range(num_args+1))) -args = ', '.join(f'_{i}' for i in range(1, num_args+1)) - -print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))') -print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N') - for i in range(1, num_args+1): args = ', '.join(f'_{i}' for i in range(1, i+1)) cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)]) @@ -135,8 +112,6 @@ for i in range(1, num_args+1): // Begin generated code // clang-format off -#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) -#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) 
N #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N) #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index f1747acc31fc8..d618aac120dad 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -10,6 +10,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp diff --git a/test/cpp/aoti_abi_check/test_dispatch.cpp b/test/cpp/aoti_abi_check/test_dispatch.cpp new file mode 100644 index 0000000000000..5eb08d0f43b0c --- /dev/null +++ b/test/cpp/aoti_abi_check/test_dispatch.cpp @@ -0,0 +1,82 @@ +#include + +#include +#include + +// MY_PRIVATE_CHECK_SELECTIVE_BUILD is a prelude to case block. For +// testing, we do nothing: +#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) /* empty */ + +#define MY_PRIVATE_CASE_TYPE_USING_HINT(...) \ + THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \ + MY_PRIVATE_CHECK_SELECTIVE_BUILD, __VA_ARGS__) + +#define MY_DISPATCH_CASE(...) \ + THO_DISPATCH_CASE_TMPL(MY_PRIVATE_CASE_TYPE_USING_HINT, __VA_ARGS__) + +// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to switch +// statement. For testing, we just avoid unused variable warning: +#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \ + (void)DISPATCHNAME + +// MY_CHECK_NOT_IMPLEMENTED is called in switch default block. For +// testing, we count case mismatches: +#define MY_CHECK_NOT_IMPLEMENTED(...) default_count++ + +#define MY_DISPATCH_SWITCH(...) \ + THO_DISPATCH_SWITCH_TMPL( \ + MY_RECORD_KERNEL_FUNCTION_DTYPE, MY_CHECK_NOT_IMPLEMENTED, __VA_ARGS__) + +// MY_CASE_FUNCTION is called in a case block. For testing, we count +// case matches and ensure that scalar_t/index_t type is defined: +#define MY_CASE_FUNCTION \ + [&] { \ + count++; \ + scalar_t tmp; \ + (void)tmp; \ + } +#define MY_INDEX_CASE_FUNCTION \ + [&] { \ + count++; \ + index_t tmp; \ + (void)tmp; \ + } + +#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE, + +#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \ + THO_DISPATCH_V2_TMPL( \ + MY_DISPATCH_SWITCH, \ + MY_DISPATCH_CASE, \ + TYPE, \ + NAME, \ + AT_WRAP(BODY), \ + __VA_ARGS__) + +#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) 
\ + TEST(TestDispatchV2, NAME) { \ + using torch::headeronly::ScalarType; \ + using torch::headeronly::impl::ScalarTypeToCPPTypeT; \ + int8_t total_count = 0; \ + int8_t count = 0; \ + int8_t default_count = 0; \ + for (ScalarType t : \ + {AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \ + total_count++; \ + MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__); \ + } \ + EXPECT_EQ(count, EXPECTEDCOUNT); \ + EXPECT_EQ(default_count + count, total_count); \ + } + +TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES); +TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES); +TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES); +TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES); +TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2); +TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES); +TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES); +TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES); +TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX); + +#undef DEFINE_ITEM diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index c0cd5d9a2c689..cdc373a1b5a98 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -139,3 +139,20 @@ AT_FORALL_COMPLEX_TYPES toString << toUnderlying + +# torch/headeronly/core/Dispatch_v2.h +THO_DISPATCH_V2_TMPL +THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL +THO_DISPATCH_CASE_TMPL +THO_DISPATCH_SWITCH_TMPL +# AT_WRAP, THO_AP_VAR_TMPL, AT_CONCAT, AT_CONCAT_AUX, AT_EXPAND are tested through THO_DISPATCH_V2_TMPL +# scalar_type is tested through THO_DISPATCH_SWITCH_TMPL +AT_FLOAT8_TYPES +AT_INTEGRAL_TYPES +AT_FLOATING_TYPES +AT_BAREBONES_UNSIGNED_TYPES +AT_INTEGRAL_TYPES_V2 +AT_COMPLEX_TYPES +AT_QINT_TYPES +AT_ALL_TYPES +AT_ALL_TYPES_AND_COMPLEX diff --git a/torch/headeronly/core/Dispatch.h b/torch/headeronly/core/Dispatch.h new file mode 100644 index 0000000000000..188ac87412de1 --- /dev/null +++ b/torch/headeronly/core/Dispatch.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +// THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL is same as +// AT_PRIVATE_CASE_TYPE_USING_HINT but with a custom PRELUDE macro: +#define THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(PRELUDE, enum_type, HINT, ...) \ + case enum_type: { \ + PRELUDE(enum_type); \ + using HINT [[maybe_unused]] = \ + torch::headeronly::impl::ScalarTypeToCPPTypeT; \ + return __VA_ARGS__(); \ + } + +// THO_DISPATCH_CASE_TMPL is same as AT_DISPATCH_CASE but with a +// custom CASE_TYPE_USING_HINT macro: +#define THO_DISPATCH_CASE_TMPL(CASE_TYPE_USING_HINT, enum_type, ...) \ + CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) + +namespace detail { +inline torch::headeronly::ScalarType scalar_type( + torch::headeronly::ScalarType s) { + return s; +} +} // namespace detail + +// THO_DISPATCH_SWITCH_TMPL is same as AT_DISPATCH_SWITCH but with +// custom PRELUDE and CHECK_NOT_IMPLEMENTED macros: +#define THO_DISPATCH_SWITCH_TMPL( \ + PRELUDE, CHECK_NOT_IMPLEMENTED, TYPE, NAME, ...) 
\ + [&] { \ + const auto& the_type = TYPE; \ + constexpr const char* at_dispatch_name = NAME; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + torch::headeronly::ScalarType _st = ::detail::scalar_type(the_type); \ + PRELUDE(at_dispatch_name, _st); \ + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \ + switch (_st) { \ + __VA_ARGS__ \ + default: \ + CHECK_NOT_IMPLEMENTED( \ + false, \ + '"', \ + at_dispatch_name, \ + "\" not implemented for '", \ + torch::headeronly::toString(_st), \ + "'"); \ + } \ + C10_DIAGNOSTIC_POP() \ + }() diff --git a/torch/headeronly/core/Dispatch_v2.h b/torch/headeronly/core/Dispatch_v2.h new file mode 100644 index 0000000000000..8ddf9d912541c --- /dev/null +++ b/torch/headeronly/core/Dispatch_v2.h @@ -0,0 +1,160 @@ +#pragma once + +#include + +// This file provides THO_DISPATCH_V2_TMPL macro that is a generalized +// version of the original AT_DISPATCH_V2 (see ATen/Dispatch_v2.h for +// documentation): THO_DISPATCH_V2_TMPL extends AT_DISPATCH_V2 with +// extra DISPATCH_SWITCH and DISPATCH_CASE arguments for specifying +// custom implementations of the original AT_DISPATCH_SWITCH and +// AT_DISPATCH_CASE macros. Use the provided macros +// THO_DISPATCH_SWITCH_TMPL and THO_DISPATCH_CASE_TMPL to define the +// custom implementations of the switch and case macros, respectively. + +// Public API macros + +// THO_DISPATCH_V2_TMPL is same as AT_DISPATCH_V2 but with custom +// DISPATCH_SWITCH and DISPATCH_CASE macro arguments: +#define THO_DISPATCH_V2_TMPL( \ + DISPATCH_SWITCH, DISPATCH_CASE, TYPE, NAME, BODY, ...) \ + DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + THO_AP_VAR_TMPL(DISPATCH_CASE, AT_WRAP(BODY), TYPE, __VA_ARGS__)) + +// This macro lets you pass an arbitrary expression that may contain internal +// commas to another macro without having the commas causing the expression +// to be interpreted as being multiple arguments +#define AT_WRAP(...) __VA_ARGS__ + +#define AT_FLOAT8_TYPES \ + torch::headeronly::ScalarType::Float8_e5m2, \ + torch::headeronly::ScalarType::Float8_e5m2fnuz, \ + torch::headeronly::ScalarType::Float8_e4m3fn, \ + torch::headeronly::ScalarType::Float8_e4m3fnuz, \ + torch::headeronly::ScalarType::Float8_e8m0fnu + +#define AT_INTEGRAL_TYPES \ + torch::headeronly::ScalarType::Byte, torch::headeronly::ScalarType::Char, \ + torch::headeronly::ScalarType::Int, torch::headeronly::ScalarType::Long, \ + torch::headeronly::ScalarType::Short +#define AT_FLOATING_TYPES \ + torch::headeronly::ScalarType::Double, torch::headeronly::ScalarType::Float +#define AT_BAREBONES_UNSIGNED_TYPES \ + torch::headeronly::ScalarType::UInt16, \ + torch::headeronly::ScalarType::UInt32, \ + torch::headeronly::ScalarType::UInt64 +#define AT_INTEGRAL_TYPES_V2 \ + AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) +#define AT_COMPLEX_TYPES \ + torch::headeronly::ScalarType::ComplexDouble, \ + torch::headeronly::ScalarType::ComplexFloat +#define AT_QINT_TYPES \ + torch::headeronly::ScalarType::QInt8, torch::headeronly::ScalarType::QUInt8, \ + torch::headeronly::ScalarType::QInt32 +// NB: not *actually* all types +#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES) +#define AT_ALL_TYPES_AND_COMPLEX \ + AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES) + +// Helper macros + +// THO_AP_VAR_TMPL is same as AT_AP_VAR but with a custom +// DISPATCH_CASE macro argument: +#define THO_AP_VAR_TMPL(C, N, T, ...) 
\ + AT_EXPAND( \ + AT_CONCAT(THO_AP, AT_NUM_ARGS(__VA_ARGS__))(C, AT_WRAP(N), __VA_ARGS__)) +#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b) +#define AT_CONCAT_AUX(a, b) a##b +#define AT_EXPAND(X) X + +// Ensure we never have too many scalar types for the expansion here to +// support. To bump this, you must regenerate the macros below. +static_assert(static_cast(torch::headeronly::ScalarType::NumOptions) < 60); + +// Python code to regenerate generate code below: +#if 0 + +num_args = 60 + +nums = ', '.join(str(i) for i in reversed(range(num_args+1))) +args = ', '.join(f'_{i}' for i in range(1, num_args+1)) + +print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))') +print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N') + +for i in range(1, num_args+1): + args = ', '.join(f'_{i}' for i in range(1, i+1)) + cases = ' '.join([f'C(_{j}, N)' for j in range(1, i+1)]) + print(f'#define THO_AP{i}(C, N, {args}) {cases}') + +#endif + +// Begin generated code +// clang-format off + +#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) +#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N +#define THO_AP1(C, N, _1) C(_1, N) +#define THO_AP2(C, N, _1, _2) C(_1, N) C(_2, N) +#define THO_AP3(C, N, _1, _2, _3) C(_1, N) C(_2, N) C(_3, N) +#define THO_AP4(C, N, _1, _2, _3, _4) C(_1, N) C(_2, N) C(_3, N) C(_4, N) +#define THO_AP5(C, N, _1, _2, _3, _4, _5) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) +#define THO_AP6(C, N, _1, _2, _3, _4, _5, _6) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) +#define THO_AP7(C, N, _1, _2, _3, _4, _5, _6, _7) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) +#define THO_AP8(C, N, _1, _2, _3, _4, _5, _6, _7, _8) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) +#define THO_AP9(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) +#define THO_AP10(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) +#define THO_AP11(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) +#define THO_AP12(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) +#define THO_AP13(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) +#define THO_AP14(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) +#define THO_AP15(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) 
C(_13, N) C(_14, N) C(_15, N) +#define THO_AP16(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) +#define THO_AP17(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) +#define THO_AP18(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) +#define THO_AP19(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) +#define THO_AP20(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) +#define THO_AP21(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) +#define THO_AP22(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) +#define THO_AP23(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) +#define THO_AP24(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) +#define THO_AP25(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) +#define THO_AP26(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) 
C(_26, N) +#define THO_AP27(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) +#define THO_AP28(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) +#define THO_AP29(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) +#define THO_AP30(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) +#define THO_AP31(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) +#define THO_AP32(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) +#define THO_AP33(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) +#define THO_AP34(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, 
N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) +#define THO_AP35(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) +#define THO_AP36(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) +#define THO_AP37(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) +#define THO_AP38(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) +#define THO_AP39(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) +#define THO_AP40(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) 
C(_39, N) C(_40, N) +#define THO_AP41(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) +#define THO_AP42(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) +#define THO_AP43(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) +#define THO_AP44(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) +#define THO_AP45(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) +#define THO_AP46(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) C(_1, N) C(_2, N) C(_3, N) C(_4, 
N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) +#define THO_AP47(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) +#define THO_AP48(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) +#define THO_AP49(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) +#define THO_AP50(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) +#define THO_AP51(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, 
_38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) +#define THO_AP52(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) +#define THO_AP53(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) +#define THO_AP54(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) +#define THO_AP55(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) 
C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) +#define THO_AP56(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) +#define THO_AP57(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) +#define THO_AP58(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) +#define THO_AP59(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) 
C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) C(_59, N) +#define THO_AP60(C, N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) C(_1, N) C(_2, N) C(_3, N) C(_4, N) C(_5, N) C(_6, N) C(_7, N) C(_8, N) C(_9, N) C(_10, N) C(_11, N) C(_12, N) C(_13, N) C(_14, N) C(_15, N) C(_16, N) C(_17, N) C(_18, N) C(_19, N) C(_20, N) C(_21, N) C(_22, N) C(_23, N) C(_24, N) C(_25, N) C(_26, N) C(_27, N) C(_28, N) C(_29, N) C(_30, N) C(_31, N) C(_32, N) C(_33, N) C(_34, N) C(_35, N) C(_36, N) C(_37, N) C(_38, N) C(_39, N) C(_40, N) C(_41, N) C(_42, N) C(_43, N) C(_44, N) C(_45, N) C(_46, N) C(_47, N) C(_48, N) C(_49, N) C(_50, N) C(_51, N) C(_52, N) C(_53, N) C(_54, N) C(_55, N) C(_56, N) C(_57, N) C(_58, N) C(_59, N) C(_60, N) + +// End generated code +// clang-format on From 552c3f3e18387ca5e1d849fee2ecfaf2938ae1d3 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 6 Nov 2025 22:06:54 +0200 Subject: [PATCH 167/651] Add THO_DISPATCH_V2 macro (#166629) The THO_DISPATCH_V2 macro is same as AT_DISPATCH_V2 but usable in headeronly context or stable ABI codes. The main difference is that AT_DISPATCH_V2 supports selective build while THO_DISPATCH_V2 does not. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166629 Approved by: https://github.com/janeyx99, https://github.com/albanD ghstack dependencies: #165856 --- test/cpp/aoti_abi_check/CMakeLists.txt | 1 + test/cpp/aoti_abi_check/test_dispatch_v2.cpp | 45 ++++++++++++++++++++ torch/header_only_apis.txt | 2 + torch/headeronly/core/Dispatch.h | 22 ++++++++++ torch/headeronly/core/Dispatch_v2.h | 10 +++++ 5 files changed, 80 insertions(+) create mode 100644 test/cpp/aoti_abi_check/test_dispatch_v2.cpp diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index d618aac120dad..9a75d19b0b069 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -11,6 +11,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch_v2.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp diff --git a/test/cpp/aoti_abi_check/test_dispatch_v2.cpp b/test/cpp/aoti_abi_check/test_dispatch_v2.cpp new file mode 100644 index 0000000000000..e475e9c802e32 --- /dev/null +++ b/test/cpp/aoti_abi_check/test_dispatch_v2.cpp @@ -0,0 +1,45 @@ +#include +#include +#include + +#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE, + +#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \ + TEST(TestThoDispatchV2, NAME) { \ + using torch::headeronly::ScalarType; \ + using torch::headeronly::impl::ScalarTypeToCPPTypeT; \ + int8_t total_count = 0; \ + int8_t count = 0; \ + int8_t default_count = 0; \ + for (ScalarType t : \ + {AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \ + total_count++; \ + try { \ + THO_DISPATCH_V2( \ + t, \ + "test_tho_dispatch_v2", \ + [&] { \ + count++; \ + scalar_t tmp; \ + (void)tmp; \ + }, \ + __VA_ARGS__); \ + } catch (...) 
{ \ + default_count++; /* counts mismatches */ \ + } \ + } \ + EXPECT_EQ(count, EXPECTEDCOUNT); \ + EXPECT_EQ(default_count + count, total_count); \ + } + +TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES); +TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES); +TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES); +TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES); +TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2); +TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES); +TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES); +TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES); +TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX); + +#undef DEFINE_ITEM diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index cdc373a1b5a98..e43f3d1c10fa8 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -156,3 +156,5 @@ AT_COMPLEX_TYPES AT_QINT_TYPES AT_ALL_TYPES AT_ALL_TYPES_AND_COMPLEX +THO_DISPATCH_V2 +# THO_EMPTY, THO_DISPATCH_CASE, THO_DISPATCH_SWITCH, THO_PRIVATE_CASE_TYPE_USING_HINT are tested through THO_DISPATCH_V2 diff --git a/torch/headeronly/core/Dispatch.h b/torch/headeronly/core/Dispatch.h index 188ac87412de1..43293ef701dda 100644 --- a/torch/headeronly/core/Dispatch.h +++ b/torch/headeronly/core/Dispatch.h @@ -49,3 +49,25 @@ inline torch::headeronly::ScalarType scalar_type( } \ C10_DIAGNOSTIC_POP() \ }() + +// THO_EMPTY is a helper macro that discards its arguments. +#define THO_EMPTY(...) + +// THO_PRIVATE_CASE_TYPE_USING_HINT is same as +// AT_PRIVATE_CASE_TYPE_USING_HINT with call to macro +// AT_PRIVATE_CHECK_SELECTIVE_BUILD removed. +#define THO_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ + THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(THO_EMPTY, enum_type, HINT, __VA_ARGS__) + +// THO_DISPATCH_SWITCH is same as AT_DISPATCH_SWITCH with call to +// macro RECORD_KERNEL_FUNCTION_DTYPE removed and using +// STD_TORCH_CHECK instead of TORCH_CHECK_NOT_IMPLEMENTED. +#define THO_DISPATCH_SWITCH(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH_TMPL(THO_EMPTY, STD_TORCH_CHECK, TYPE, NAME, __VA_ARGS__) + +// THO_DISPATCH_CASE is same as AT_DISPATCH_CASE but using +// THO_PRIVATE_CASE_TYPE_USING_HINT instead of +// AT_PRIVATE_CASE_TYPE_USING_HINT. +#define THO_DISPATCH_CASE(enum_type, ...) \ + THO_DISPATCH_CASE_TMPL( \ + THO_PRIVATE_CASE_TYPE_USING_HINT, enum_type, __VA_ARGS__) diff --git a/torch/headeronly/core/Dispatch_v2.h b/torch/headeronly/core/Dispatch_v2.h index 8ddf9d912541c..13cbd2ee85e5f 100644 --- a/torch/headeronly/core/Dispatch_v2.h +++ b/torch/headeronly/core/Dispatch_v2.h @@ -1,5 +1,6 @@ #pragma once +#include #include // This file provides THO_DISPATCH_V2_TMPL macro that is a generalized @@ -22,6 +23,15 @@ NAME, \ THO_AP_VAR_TMPL(DISPATCH_CASE, AT_WRAP(BODY), TYPE, __VA_ARGS__)) +// THO_DISPATCH_V2 is same as AT_DISPATCH_V2 but using +// THO_DISPATCH_SWITCH and THO_DISPATCH_CASE instead of +// AT_DISPATCH_SWITCH and AT_DISPATCH_CASE, respectively. +#define THO_DISPATCH_V2(TYPE, NAME, BODY, ...) 
\ + THO_DISPATCH_V2_TMPL( \ + THO_DISPATCH_SWITCH, THO_DISPATCH_CASE, TYPE, NAME, BODY, __VA_ARGS__) + +// Type collection macros + // This macro lets you pass an arbitrary expression that may contain internal // commas to another macro without having the commas causing the expression // to be interpreted as being multiple arguments From e678450a69f6bf3b6f3ea7657d444ce9bba19940 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 7 Nov 2025 01:15:14 +0000 Subject: [PATCH 168/651] [cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (#167111) cuDNN dispatching heuristics rely on versions checks but currently only that compile-time version is exposed, if we want to allow users to resolve https://github.com/pytorch/pytorch/issues/166643 on their end by updating their cuDNN version locally we need to check the runtime version rather than compile-time version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167111 Approved by: https://github.com/Skylion007 --- aten/src/ATen/Context.h | 6 ++++++ aten/src/ATen/cuda/detail/CUDAHooks.cpp | 21 +++++++++++++++++++ aten/src/ATen/cuda/detail/CUDAHooks.h | 2 ++ aten/src/ATen/detail/CUDAHooksInterface.h | 8 +++++++ aten/src/ATen/native/Convolution.cpp | 4 ++-- .../native/transformers/cuda/sdp_utils.cpp | 4 ++-- torch/csrc/cuda/shared/cudnn.cpp | 7 ++----- 7 files changed, 43 insertions(+), 9 deletions(-) diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 6807e527eb75f..385ccb88c463b 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -174,6 +174,12 @@ class TORCH_API Context { static long versionCuDNN() { return detail::getCUDAHooks().versionCuDNN(); } + static long versionRuntimeCuDNN() { + return detail::getCUDAHooks().versionRuntimeCuDNN(); + } + static long versionCuDNNFrontend() { + return detail::getCUDAHooks().versionCuDNNFrontend(); + } static bool hasCuSOLVER() { return detail::getCUDAHooks().hasCuSOLVER(); } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index b7f80101d926e..594045a1b41d2 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -21,6 +21,7 @@ #if AT_CUDNN_ENABLED() #include +#include #endif #if AT_MAGMA_ENABLED() @@ -351,6 +352,26 @@ long CUDAHooks::versionCuDNN() const { #endif } +long CUDAHooks::versionRuntimeCuDNN() const { +#if AT_CUDNN_ENABLED() +#ifndef USE_STATIC_CUDNN + return cudnnGetVersion(); +#else + return CUDNN_VERSION; +#endif +#else + TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN"); +#endif +} + +long CUDAHooks::versionCuDNNFrontend() const { +#if AT_CUDNN_ENABLED() + return CUDNN_FRONTEND_VERSION; +#else + TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN"); +#endif +} + long CUDAHooks::versionMIOpen() const { #if AT_ROCM_ENABLED() return MIOPEN_VERSION_MAJOR * 10000 + diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 8d3d1db003928..8902c68d342f8 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -49,6 +49,8 @@ struct CUDAHooks : public at::CUDAHooksInterface { bool hasCUDART() const override; long versionCUDART() const override; long versionCuDNN() const override; + long versionRuntimeCuDNN() const override; + long versionCuDNNFrontend() const override; long versionMIOpen() const override; std::string showConfig() const override; double batchnormMinEpsilonCuDNN() const override; 
diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index f1f2056917472..0ab8e82a30166 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -174,6 +174,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP); } + virtual long versionRuntimeCuDNN() const { + TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP); + } + + virtual long versionCuDNNFrontend() const { + TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP); + } + virtual long versionMIOpen() const { TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP); } diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 2c3f14aab911c..ca3a4f5f3faba 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -409,7 +409,7 @@ struct ConvParams { if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) { return false; } - static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN(); // broken on cuDNN 9.8 - 9.14 if (cudnn_version >= 90800 && cudnn_version < 91500) { if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous && @@ -453,7 +453,7 @@ struct ConvParams { } // native kernel doesn't support 64-bit non-splittable case if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) { - static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1; + static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? 
detail::getCUDAHooks().versionRuntimeCuDNN() : -1; // TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) { if (cudnn_version < 0 || cudnn_version > 91000) { diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 7fce73151b00f..a6742a7cb9e78 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -478,7 +478,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { const auto s_k = params.key.sym_size(2); const auto d_qk = params.query.sym_size(3); const auto d_v = params.value.sym_size(3); - long cudnn_version = at::detail::getCUDAHooks().versionCuDNN(); + long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN(); if (cudnn_version < 8903) { if (debug) { TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher"); @@ -709,7 +709,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { return false; #endif #if defined(CUDNN_VERSION) - static auto cudnn_version = cudnnGetVersion(); + static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN(); if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) { if (debug) { TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support droppout in SDPA (9.11 - 9.13)."); diff --git a/torch/csrc/cuda/shared/cudnn.cpp b/torch/csrc/cuda/shared/cudnn.cpp index f56899107fd56..20e69779c062b 100644 --- a/torch/csrc/cuda/shared/cudnn.cpp +++ b/torch/csrc/cuda/shared/cudnn.cpp @@ -2,6 +2,7 @@ // This file should only be compiled if this condition holds, so it should be // safe. #if defined(USE_CUDNN) || defined(USE_ROCM) +#include #include #include @@ -32,11 +33,7 @@ version_tuple getRuntimeVersion() { } size_t getVersionInt() { -#ifndef USE_STATIC_CUDNN - return cudnnGetVersion(); -#else - return CUDNN_VERSION; -#endif + return at::detail::getCUDAHooks().versionRuntimeCuDNN(); } } // namespace From b228f6d180925aee5c60f26dad8e35d12bcaeb6a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 01:17:13 +0000 Subject: [PATCH 169/651] Revert "[ROCm] Enable StaticCudaLauncher for ROCm (#166492)" This reverts commit ba2e6b0b4f1718767762d7b20558d4de943be71b. 
Reverted https://github.com/pytorch/pytorch/pull/166492 on behalf of https://github.com/jeffdaily due to test/inductor/test_ck_backend.py::TestCKBackend::test_max_autotune_precompile_matmul_dynamic_max_autotune_gemm_backends_CK_autotune_in_subproc_True [GH job link](https://github.com/pytorch/pytorch/actions/runs/19147453561/job/54731084387) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/ba2e6b0b4f1718767762d7b20558d4de943be71b) ([comment](https://github.com/pytorch/pytorch/pull/166492#issuecomment-3500049276)) --- test/inductor/test_codecache.py | 9 +- test/inductor/test_static_cuda_launcher.py | 21 +++- .../_inductor/runtime/static_cuda_launcher.py | 55 ++-------- torch/_inductor/runtime/triton_heuristics.py | 11 +- torch/csrc/Module.cpp | 2 +- torch/csrc/inductor/static_cuda_launcher.cpp | 102 ++---------------- torch/csrc/inductor/static_cuda_launcher.h | 2 +- 7 files changed, 41 insertions(+), 161 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index c90d2ccec83d5..46f1ca031bf83 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -475,17 +475,14 @@ def test_remote_cache_load_function( if device == GPU_TYPE and not HAS_GPU: raise unittest.SkipTest(f"requires {GPU_TYPE}") - if ( - device == "cuda" - and torch.version.hip is None - and dtype == torch.bfloat16 - and not SM80OrLater - ): + if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: raise unittest.SkipTest("requires SM80 or later") if use_static_cuda_launcher and not (device == "cuda" and bundle_triton): raise unittest.SkipTest( "Static cuda launcher requires cuda and triton bundling" ) + if use_static_cuda_launcher and TEST_WITH_ROCM: + raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM") def fn(x, y): return (x * 2, y @ y) diff --git a/test/inductor/test_static_cuda_launcher.py b/test/inductor/test_static_cuda_launcher.py index ec9586197d085..654bfd269f761 100644 --- a/test/inductor/test_static_cuda_launcher.py +++ b/test/inductor/test_static_cuda_launcher.py @@ -12,6 +12,7 @@ from torch._inductor.runtime.triton_compat import CompiledKernel, tl, triton from torch._inductor.runtime.triton_helpers import libdevice from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -38,9 +39,8 @@ def write_cubin_to_tmp(self, kernel: CompiledKernel) -> str: # Just used by tests for now. # TODO: derive cubin_path from wherever triton stores the cubin file on disk. tmp_file = tempfile.NamedTemporaryFile(mode="wb", delete=False) - binary_key = "hsaco" if torch.version.hip else "cubin" with tmp_file: - tmp_file.write(kernel.asm[binary_key]) + tmp_file.write(kernel.asm["cubin"]) self.tmp_files.append(tmp_file) return tmp_file.name @@ -64,6 +64,7 @@ def _make_launcher( result.load_kernel(device_interface.current_device()) return result + @skipIfRocm def test_basic(self): @triton.jit def simple_kernel(arg0, arg1): @@ -90,6 +91,7 @@ def simple_kernel(arg0, arg1): # 2. triton relies on inspect.get_source to get the type annotations # so I can't even use exec() to generate the test cases. 
# So we'll just make a few kernels by hand + @skipIfRocm def test_unsigned_integers(self): @triton.jit def unsigned_integers( @@ -113,6 +115,7 @@ def unsigned_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_signed_integers(self): @triton.jit def signed_integers( @@ -136,6 +139,7 @@ def signed_integers( launcher.run(1, 1, 1, stream, new_arg0, 50, 50, 50, 50) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_basic_1arg(self): @triton.jit def simple_kernel_1_arg(arg0): @@ -160,6 +164,7 @@ def simple_kernel_1_arg(arg0): ) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_constexpr(self): # Constexprs are compiled directly into the cubin file, # so we never need to pass it to StaticCudaLauncher. @@ -188,6 +193,7 @@ def kernel_constexpr(arg0, CONSTANT: tl.constexpr): ) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_implied_constant(self): """xnumel is unused in this kernel, but isn't explicitly marked as a constexpr""" @@ -240,6 +246,7 @@ def triton_red_fused_any_isinf_0( launcher.run(1, 1, 1, stream, arg0, arg2, 128) self.assertEqual(arg1, arg2) + @skipIfRocm def test_kernel_no_args(self): # Just an easy way to test incompatible number of arguments @triton.jit @@ -252,6 +259,7 @@ def kernel_no_op(): stream = device_interface.get_raw_stream(device_interface.current_device()) launcher.run(1, 1, 1, stream) + @skipIfRocm def test_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -275,6 +283,7 @@ def simple_kernel(arg0, arg1): launcher.run(1, 1, 1, stream, new_arg0, arg1) self.assertEqual(new_arg0, arg0) + @skipIfRocm def test_too_high_shared_mem(self): @triton.jit def simple_kernel(arg0, arg1): @@ -294,6 +303,7 @@ def simple_kernel(arg0, arg1): lambda: self._make_launcher(compiled_kernel), ) + @skipIfRocm def test_kernel_empty_tensor(self): # Triton kernel generated by torch.compile of the following: # @torch.compile() @@ -354,6 +364,7 @@ def triton_poi_fused_cat_0( launcher.run(1, 1, 1, stream, arg1, arg2, buf1, arg0, xnumel) self.assertEqual(buf0, buf1) + @skipIfRocm def test_kernel_many_args(self): N = 200 # Make 200 arguments @@ -394,6 +405,7 @@ class TestStaticTritonCompileResult(TestCase): Tests static cuda launcher with torch.compile() """ + @skipIfRocm def test_basic_compile(self): @torch.compile def foo(x, y): @@ -403,6 +415,7 @@ def foo(x, y): y = torch.randn(10, device="cuda") self.assertEqual(foo(x, y), x + y) + @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch("compile_threads", 1) def test_incompatible_code(self): @@ -425,6 +438,7 @@ def foo(x): lambda: foo(x), ) + @skipIfRocm # The error gets raised on a worker, so we want to not use a separate process @torch._inductor.config.patch( {"compile_threads": 1, "static_launch_user_defined_triton_kernels": True} @@ -446,6 +460,7 @@ def foo(x): x2 = x.clone().detach_() self.assertEqual(foo(x), x2 + 5) + @skipIfRocm def test_empty_tensor(self): @torch.compile() def foo(x, y): @@ -457,6 +472,7 @@ def foo(x, y): result = foo(x, y) self.assertEqual(result, torch.cat(((x * 4), y + 10))) + @skipIfRocm def test_any(self): def fn(x): return ( @@ -476,6 +492,7 @@ def fn(x): compiled_result = compiled_fn(arg) self.assertEqual(eager_result, compiled_result) + @skipIfRocm def test_disable_static_cuda_launcher(self): @torch.compile def fn(x, y): diff --git a/torch/_inductor/runtime/static_cuda_launcher.py b/torch/_inductor/runtime/static_cuda_launcher.py index 
4eede8631e9ce..f48f351ce823a 100644 --- a/torch/_inductor/runtime/static_cuda_launcher.py +++ b/torch/_inductor/runtime/static_cuda_launcher.py @@ -38,20 +38,7 @@ def __init__(self, kernel: CompiledKernel) -> None: # pyrefly: ignore [missing-attribute] self.name = kernel.src.fn.__name__ # pyrefly: ignore [missing-attribute] - if "hsaco" in kernel.asm: - # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm["hsaco"] - self.is_rocm = True - # pyrefly: ignore [missing-attribute] - elif "cubin" in kernel.asm: - # pyrefly: ignore [missing-attribute] - self.cubin_raw = kernel.asm["cubin"] - self.is_rocm = False - else: - raise RuntimeError( - "Expected either 'hsaco' (ROCm) or 'cubin' (CUDA) in kernel.asm" - ) - + self.cubin_raw = kernel.asm.get("cubin", None) # pyrefly: ignore [missing-attribute] self.cubin_path = kernel._cubin_path @@ -258,42 +245,12 @@ def run( # thing, it should always match. # Get rid of constants before passing to cubin launcher + # Add a None if triton wants extra parameters for scratch spaces arg_tys = self.arg_tys - - if self.is_rocm: - # ROCm/HIP kernel ABI: The Triton HIP backend ALWAYS includes both - # global_scratch and profile_scratch parameters in the kernel signature, - # even when the kernel doesn't use them (i.e., when has_*_scratch is False). - # - # This differs fundamentally from CUDA, where these parameters are only - # present in the signature if the corresponding has_*_scratch flag is True. - # - # The flags indicate whether memory will be allocated/used: - # - has_global_scratch: Whether global scratch workspace is needed - # - has_profile_scratch: Whether profiling instrumentation is enabled - # - # However, regardless of flag values, we MUST always pass both parameters - # to match the HIP kernel ABI. Passing None is safe: - # - # - If scratch is not needed (has_*_scratch=False or scratch_size=0): - # The None becomes nullptr, which the kernel never dereferences - # - # - If scratch is needed (has_*_scratch=True and scratch_size>0): - # The None becomes nullptr initially, but the HIP runtime intercepts - # the kernel launch, allocates the required scratch memory based on - # kernel metadata, and replaces the nullptr with a valid pointer before - # the kernel actually executes - # - # Not passing both parameters causes segmentation faults because the kernel - # expects them at specific positions in the argument array. 
- arg_tys = arg_tys + "OO" - args = (*args, None, None) - - else: - for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: - if has_scratch: - arg_tys = arg_tys + "O" - args = (*args, None) + for has_scratch in [self.has_global_scratch, self.has_profile_scratch]: + if has_scratch: + arg_tys = arg_tys + "O" + args = (*args, None) # pyrefly: ignore [bad-argument-type] assert len(args) == len(arg_tys) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index d60cda3fae7bf..b38cdcb71fa23 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1599,8 +1599,9 @@ def can_statically_launch( return None def check_can_launch() -> StaticallyLaunchedCudaKernel: - if triton_meta.get("device_type") not in ("cuda", "hip"): - raise CannotStaticallyLaunchKernel("Non-cuda/ROCm device") + if triton_meta.get("device_type") != "cuda": + # Only cuda kernels + raise CannotStaticallyLaunchKernel("Non-cuda device") if torch._inductor.config.cpp_wrapper: # If we're running with cpp wrapper, it doesn't @@ -1626,11 +1627,10 @@ def check_can_launch() -> StaticallyLaunchedCudaKernel: "static launch does not support launch attributes" ) - binary_ext = "hsaco" if triton_meta.get("device_type") == "hip" else "cubin" cubin_location = os.path.join( triton_cache_dir(triton_meta.get("device", 0)), triton_hash_to_path_key(kernel.hash), - f"{kernel.src.fn.__name__}.{binary_ext}", + f"{kernel.src.fn.__name__}.cubin", ) if not os.path.exists(cubin_location): @@ -1662,11 +1662,10 @@ def reload_cubin_path(self): When loading from cache on disk, we want to reload cubin files from their appropriate location on disc. """ - binary_ext = "hsaco" if torch.version.hip else "cubin" cubin_location = os.path.join( triton_cache_dir(self.compile_meta.get("device", 0)), triton_hash_to_path_key(self.kernel.hash), - f"{self.kernel.name}.{binary_ext}", + f"{self.kernel.name}.cubin", ) if not os.path.exists(cubin_location): if self.kernel.cubin_raw is not None: diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 0c32e6028bc69..ad37abe3b560b 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2159,7 +2159,7 @@ PyObject* initModule() { #ifdef USE_CUDA torch::cuda::initModule(module); #endif -#if defined(USE_CUDA) +#if defined(USE_CUDA) && !defined(USE_ROCM) ASSERT_TRUE(StaticCudaLauncher_init(module)); #endif #ifdef USE_MPS diff --git a/torch/csrc/inductor/static_cuda_launcher.cpp b/torch/csrc/inductor/static_cuda_launcher.cpp index 35d11c8651323..59916b6763bfa 100644 --- a/torch/csrc/inductor/static_cuda_launcher.cpp +++ b/torch/csrc/inductor/static_cuda_launcher.cpp @@ -1,4 +1,7 @@ -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) && !defined(USE_ROCM) +// We disable this file from being hipified because there are CUDA drivers hip +// has not implemented yet. Also, we're passing in a cubin file directly, so it +// would take more work to support ROCM anyway. #include #include @@ -13,11 +16,6 @@ #include #include #include - -#if defined(USE_ROCM) -#include -#endif - /** Implements a static launcher for triton compiled CUDA kernels. 
Given a path to a cubin file, a function name, and some metadata, @@ -58,14 +56,8 @@ const at::cuda::NVRTC& nvrtc() { CUdeviceptr getPointer(PyObject* obj) { CUdeviceptr data_ptr = 0; - if (THPUtils_checkLong(obj)) { -#if defined(USE_ROCM) - data_ptr = reinterpret_cast(THPUtils_unpackUInt64(obj)); -#else data_ptr = THPUtils_unpackUInt64(obj); -#endif - return data_ptr; } if (obj == Py_None) { @@ -81,25 +73,13 @@ CUdeviceptr getPointer(PyObject* obj) { TORCH_CHECK( THPUtils_checkLong(ret), "data_ptr method of Pointer object must return 64-bit int"); - -#if defined(USE_ROCM) - data_ptr = reinterpret_cast(THPUtils_unpackUInt64(ret)); -#else data_ptr = THPUtils_unpackUInt64(ret); -#endif - if (!data_ptr) return data_ptr; CUdeviceptr dev_ptr = 0; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipPointerGetAttribute( - &dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuPointerGetAttribute( &dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, data_ptr)); -#endif - return dev_ptr; } @@ -118,15 +98,6 @@ CUfunction loadKernel( } CUmodule mod = nullptr; CUfunction func = nullptr; - -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipModuleLoad(&mod, filePath.c_str())); - AT_CUDA_DRIVER_CHECK(hipModuleGetFunction(&func, mod, funcName.c_str())); - int shared_optin = 0; - AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( - &shared_optin, hipDeviceAttributeSharedMemPerBlockOptin, device)); - -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoad(&mod, filePath.c_str())); AT_CUDA_DRIVER_CHECK( nvrtc().cuModuleGetFunction(&func, mod, funcName.c_str())); @@ -135,9 +106,6 @@ CUfunction loadKernel( &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device)); - -#endif - // Shared memory logic from triton/third-party/nvidia/backend/driver.c // If we're using more than 48 KB of shared memory, and we have // access to more than 48 KB of shared memory on the device, @@ -156,21 +124,6 @@ CUfunction loadKernel( " Reducing block sizes or `num_stages` may help."); if (sharedMemBytes > SHARED_MEM_STATIC_MAX && shared_optin > SHARED_MEM_STATIC_MAX) { -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipFuncSetCacheConfig(func, hipFuncCachePreferShared)); - int shared_total = 0, shared_static = 0; - AT_CUDA_DRIVER_CHECK(hipDeviceGetAttribute( - &shared_total, - hipDeviceAttributeMaxSharedMemoryPerMultiprocessor, - device)); - AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( - &shared_static, HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, func)); - AT_CUDA_DRIVER_CHECK(hipFuncSetAttribute( - func, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_optin - shared_static)); - -#else AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncSetCacheConfig(func, CU_FUNC_CACHE_PREFER_SHARED)); int shared_total = 0, shared_static = 0; @@ -184,7 +137,6 @@ CUfunction loadKernel( func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); -#endif } return func; } @@ -200,27 +152,6 @@ inline void launchKernel( cudaStream_t stream) { // cta_args is always 1 for inductor generated triton kernels, // so we don't need to figure out grid dimension here -#if defined(USE_ROCM) - int device = 0; - AT_CUDA_DRIVER_CHECK(hipGetDevice(&device)); - int warp_size = 0; - AT_CUDA_DRIVER_CHECK( - hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize, device)); - - AT_CUDA_DRIVER_CHECK(hipModuleLaunchKernel( - func, - gridX, - gridY, - gridZ, - warp_size * numWarps, // blockDim.x - 1, // blockDim.y - 1, // blockDim.z - sharedMemBytes, - stream, - args, - nullptr)); - -#else 
AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( func, gridX, @@ -233,7 +164,6 @@ inline void launchKernel( stream, args, nullptr)); -#endif } template @@ -339,20 +269,11 @@ PyObject* load_kernel(PyObject* self, PyObject* args) { CUdevice device = static_cast(device_ptr); // NOLINT CUfunction func = nullptr; func = loadKernel(filePath, funcName, sharedMemBytes, device); - -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK( - hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, func)); - AT_CUDA_DRIVER_CHECK(hipFuncGetAttribute( - &n_spills, HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); - -#else + // Taken from triton/nvidia/backend/driver.c AT_CUDA_DRIVER_CHECK( nvrtc().cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, func)); AT_CUDA_DRIVER_CHECK(nvrtc().cuFuncGetAttribute( &n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, func)); - -#endif n_spills /= 4; // Return a tuple of CUFunction, n_regs, n_spills return Py_BuildValue( @@ -378,6 +299,7 @@ PyObject* launch_kernel_inner( std::array argStorage = {}; std::array kernelArgs = {}; parseKernelArgs(varArgs, argTypes, argStorage.data(), kernelArgs.data()); + launchKernel( func, gridX, @@ -464,25 +386,13 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) { Py_RETURN_NONE; } CUcontext pctx = nullptr; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipCtxGetCurrent(&pctx)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); -#endif - if (!pctx) { // Ensure device context exists CUdevice device = 0; -#if defined(USE_ROCM) - AT_CUDA_DRIVER_CHECK(hipDeviceGet(&device, 0)); - AT_CUDA_DRIVER_CHECK(hipDevicePrimaryCtxRetain(&pctx, device)); - AT_CUDA_DRIVER_CHECK(hipCtxSetCurrent(pctx)); -#else AT_CUDA_DRIVER_CHECK(nvrtc().cuDeviceGet(&device, 0)); AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxRetain(&pctx, device)); AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxSetCurrent(pctx)); - -#endif } CUfunction func = reinterpret_cast(func_ptr); // NOLINT cudaStream_t cudaStream = reinterpret_cast(stream); // NOLINT diff --git a/torch/csrc/inductor/static_cuda_launcher.h b/torch/csrc/inductor/static_cuda_launcher.h index 6f3980172275b..517036b9975e6 100644 --- a/torch/csrc/inductor/static_cuda_launcher.h +++ b/torch/csrc/inductor/static_cuda_launcher.h @@ -1,5 +1,5 @@ #pragma once -#if defined(USE_CUDA) +#if defined(USE_CUDA) && !defined(USE_ROCM) #include #include From 31ac76423917ddac34c22beb6fbada2955b11f43 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 01:21:15 +0000 Subject: [PATCH 170/651] Revert "Move enrich_profiler_metadata config import out of gm.recompile() (#167114)" This reverts commit d144382dc96f109a6254c38734779e0a09fb7134. 
Reverted https://github.com/pytorch/pytorch/pull/167114 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/167114#issuecomment-3500057321)) --- test/test_fx.py | 6 +++--- torch/_dynamo/config.py | 7 +++++-- torch/fx/experimental/_config.py | 8 +------- torch/fx/graph_module.py | 10 ++++------ 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 7b075c7f73381..3ad21e64c8ce2 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4251,7 +4251,7 @@ def fn(a, b, c, d): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) + @torch._dynamo.config.patch("enrich_profiler_metadata", True) def test_profiler_stack_trace_augmentation(self): """ Test that map_recorded_events_to_aten_ops_with_stack_trace correctly @@ -4307,7 +4307,7 @@ def forward(self, x): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) + @torch._dynamo.config.patch("enrich_profiler_metadata", True) def test_profiler_multiple_modules(self): """ Test that multiple compiled modules under the same profiler session @@ -4351,7 +4351,7 @@ def forward(self, x): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) + @torch._dynamo.config.patch("enrich_profiler_metadata", True) def test_profiler_nested_graph_modules(self): """ Test that nested graph modules (e.g., graph modules calling subgraphs) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 66142b196d630..0c95408401c79 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -739,8 +739,11 @@ def default_debug_dir_root() -> str: # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None -# Deprecated! Please use the config in torch/fx/experimental/_config instead. -enrich_profiler_metadata: bool = False +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/fx/experimental/_config.py b/torch/fx/experimental/_config.py index a537978db3834..ce4296b6410c9 100644 --- a/torch/fx/experimental/_config.py +++ b/torch/fx/experimental/_config.py @@ -2,8 +2,6 @@ import sys from typing import Optional -from torch.utils._config_module import Config, install_config_module - # [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors. no_data_dependent_graph_break = ( @@ -102,11 +100,7 @@ # Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. 
skip_dtype_check_in_meta_registrations = False -# Experimental: If True, graph module will register fx metadata during recompile() -enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] - default=False, - env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", -) +from torch.utils._config_module import install_config_module install_config_module(sys.modules[__name__]) diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index ab33d7bf321c9..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -20,7 +20,6 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_importer from ._compatibility import compatibility -from .experimental import _config as fx_experimental_config from .graph import ( _BoxedCodeGen, _custom_builtins, @@ -859,15 +858,14 @@ def recompile(self) -> PythonCode: called after editing the contained ``graph``, otherwise the generated code of this ``GraphModule`` will be out of date. """ - # Do not import anything inside recompile, it might slow down the - # function and cause perf regression. Import outside of the method instead. if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec + from torch._dynamo import config as dynamo_config + python_code = self._graph.python_code( - root_module="self", - record_func=fx_experimental_config.enrich_profiler_metadata, + root_module="self", record_func=dynamo_config.enrich_profiler_metadata ) self._code = python_code.src self._lineno_map = python_code._lineno_map @@ -876,7 +874,7 @@ def recompile(self) -> PythonCode: cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - if fx_experimental_config.enrich_profiler_metadata: + if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation node_metadata: dict[int, dict[str, Any]] = {} for i, node in enumerate(self._graph.nodes): From 0e512ee9f05347de25e8d359c3d69f2f5d371398 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Fri, 7 Nov 2025 01:43:25 +0000 Subject: [PATCH 171/651] Make pyrefly installable by lintrunner on Python-3.14 (#167270) By pinning numpy to 2.3.4 for 3.14 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167270 Approved by: https://github.com/huydhn --- .lintrunner.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index cee0249ad96eb..c7e3797c9b80c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -143,7 +143,8 @@ init_command = [ 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', 'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"', - 'numpy==2.1.0 ; python_version >= "3.12"', + 'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"', + 'numpy==2.3.4 ; python_version >= "3.14"', 'expecttest==0.3.0', 'pyrefly==0.36.2', 'sympy==1.13.3', From 292bd62c71897e6fca81f77b08bd77158393ddd9 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Thu, 6 Nov 2025 13:36:00 +0000 Subject: [PATCH 172/651] Introduce TEST_ACCELERATOR and TEST_MULTIACCELERATOR to simplify UT (#167196) # Motivation This PR aims to introduce two variables (`TEST_ACCELERATOR` and `TEST_MULTIACCELERATOR`) to simplify UT generalization. Since out-of-tree backends may be imported later, these variables are defined as lazy values. 
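For illustration, here is a minimal usage sketch. The test class and test method names are hypothetical; the two flags themselves are the `LazyVal` definitions this patch adds to `torch/testing/_internal/common_utils.py`, so their truthiness is only evaluated when a test module actually checks them.

```python
# Hypothetical test module -- a sketch of how the new lazy flags are meant to be used.
import unittest

import torch
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_ACCELERATOR,
    TEST_MULTIACCELERATOR,
    TestCase,
)


class TestMyAcceleratorFeature(TestCase):
    @unittest.skipUnless(TEST_ACCELERATOR, "requires an available accelerator")
    def test_single_device(self):
        self.assertIsNotNone(torch.accelerator.current_accelerator())

    @unittest.skipUnless(TEST_MULTIACCELERATOR, "requires at least two accelerator devices")
    def test_multi_device(self):
        self.assertGreater(torch.accelerator.device_count(), 1)


if __name__ == "__main__":
    run_tests()
```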
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167196 Approved by: https://github.com/albanD --- .../pipelining/test_schedule_multiproc.py | 2 +- test/distributed/pipelining/test_stage.py | 2 +- test/test_accelerator.py | 13 +++- torch/testing/_internal/common_utils.py | 75 ++++++++++--------- 4 files changed, 51 insertions(+), 41 deletions(-) diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index 9806bb5d03874..5538e750d27eb 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -46,6 +46,7 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, + TEST_MULTIACCELERATOR, ) @@ -56,7 +57,6 @@ torch.manual_seed(0) device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" backend = dist.get_default_backend_for_device(device_type) -TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2 @dataclass diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 1e6dad4a77d77..b9ad3d5cb6771 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -24,6 +24,7 @@ parametrize, run_tests, skip_but_pass_in_sandcastle_if, + TEST_MULTIACCELERATOR, ) from torch.utils._pytree import tree_map_only @@ -34,7 +35,6 @@ device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" backend = dist.get_default_backend_for_device(device_type) -TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2 torch.manual_seed(0) diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 21731bd275b60..d44c8b0d350c9 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -5,17 +5,22 @@ import unittest import torch -from torch.testing._internal.common_utils import NoTest, run_tests, TEST_MPS, TestCase +from torch.testing._internal.common_utils import ( + NoTest, + run_tests, + TEST_ACCELERATOR, + TEST_MPS, + TEST_MULTIACCELERATOR, + TestCase, +) -if not torch.accelerator.is_available(): +if not TEST_ACCELERATOR: print("No available accelerator detected, skipping tests", file=sys.stderr) TestCase = NoTest # noqa: F811 # Skip because failing when run on cuda build with no GPU, see #150059 for example sys.exit() -TEST_MULTIACCELERATOR = torch.accelerator.device_count() > 1 - class TestAccelerator(TestCase): def test_current_accelerator(self): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 00572f9691380..06bbd329d3450 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1468,6 +1468,44 @@ def is_privateuse1_backend_available(): return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available() +def make_lazy_class(cls): + + def lazy_init(self, cb): + self._cb = cb + self._value = None + + cls.__init__ = lazy_init + + for basename in [ + "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow", + "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert", + "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index", + ]: + name = f"__{basename}__" + + def inner_wrapper(name): + use_operator = basename not in ("bool", "int") + + def wrapped(self, *args, **kwargs): + if self._cb is not None: + self._value = self._cb() + self._cb = None + if not use_operator: + return getattr(self._value, name)(*args, **kwargs) + else: + return 
getattr(operator, name)(self._value, *args, **kwargs) + return wrapped + + setattr(cls, name, inner_wrapper(name)) + + return cls + + +@make_lazy_class +class LazyVal: + pass + + IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8' TEST_NUMPY = _check_module_exists('numpy') @@ -1480,6 +1518,8 @@ def is_privateuse1_backend_available(): TEST_XPU = torch.xpu.is_available() TEST_HPU = bool(hasattr(torch, "hpu") and torch.hpu.is_available()) TEST_CUDA = torch.cuda.is_available() +TEST_ACCELERATOR = LazyVal(lambda: torch.accelerator.is_available()) # type: ignore[call-arg] +TEST_MULTIACCELERATOR = LazyVal(lambda: torch.accelerator.device_count() > 1) # type: ignore[call-arg] custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) TEST_PRIVATEUSE1 = is_privateuse1_backend_available() TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name() @@ -5601,37 +5641,7 @@ def _skip_helper(self, op, device, dtype): if not op.supports_autograd and not op.supports_forward_ad: self.skipTest("Skipped! autograd not supported.") -def make_lazy_class(cls): - - def lazy_init(self, cb): - self._cb = cb - self._value = None - - cls.__init__ = lazy_init - - for basename in [ - "add", "sub", "mul", "truediv", "floordiv", "mod", "divmod", "pow", - "lshift", "rshift", "and", "or", "xor", "neg", "pos", "abs", "invert", - "eq", "ne", "lt", "le", "gt", "ge", "bool", "int", "index", - ]: - name = f"__{basename}__" - - def inner_wrapper(name): - use_operator = basename not in ("bool", "int") - - def wrapped(self, *args, **kwargs): - if self._cb is not None: - self._value = self._cb() - self._cb = None - if not use_operator: - return getattr(self._value, name)(*args, **kwargs) - else: - return getattr(operator, name)(self._value, *args, **kwargs) - return wrapped - - setattr(cls, name, inner_wrapper(name)) - return cls # Base TestCase for NT tests; used to define common helpers, etc. 
@@ -5676,11 +5686,6 @@ def branch_nested_state(self): nested_tensor_module._tensor_symint_registry = original_tensor_symint_registry -@make_lazy_class -class LazyVal: - pass - - def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=0): from torch._dynamo.trace_rules import _as_posix_path From ae67a5a9d3ef05a586928a1b9e2a6f11ed35cc77 Mon Sep 17 00:00:00 2001 From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com> Date: Fri, 7 Nov 2025 02:42:05 +0000 Subject: [PATCH 173/651] [ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#167233) * `c10::fetch_and_cast` and `c10::cast_and_store` produce branchy code since it supports all datatypes * So, we do special handling for binary elementwise broadcast with mixed dtypes of float/bfloat16/half * This improves performance Pull Request resolved: https://github.com/pytorch/pytorch/pull/167233 Approved by: https://github.com/jeffdaily --- aten/src/ATen/native/cuda/CUDALoops.cuh | 89 +++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index c42d03b9cbf7f..b83ec3c761e9b 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -884,6 +884,69 @@ struct type_specialized_kernel_launcher { } }; +template +struct type_specialized_broadcast_kernel_launcher { + template < + typename func_t, + typename array_t, + typename dtypes_t, + typename calc_t> + static void apply( + int64_t numel, + func_t f, + array_t data, + dtypes_t dtypes, + calc_t offset_calc) { + using traits = function_traits; + using ret_t = typename traits::result_type; + using arg0_t = typename traits::template arg<0>::type; + using arg1_t = typename traits::template arg<1>::type; + if (dtypes[0] == rt_binary_specializations[arg_index][0] && + dtypes[1] == rt_binary_specializations[arg_index][1] && + dtypes[2] == rt_binary_specializations[arg_index][2]) { + using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT; + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx + grp_sz); + auto offsets2 = offset_calc.get(idx + grp_sz * 2); + auto offsets3 = offset_calc.get(idx + grp_sz * 3); + void* out0 = data[0] + offsets0[0]; + void* out1 = data[0] + offsets1[0]; + void* out2 = data[0] + offsets2[0]; + void* out3 = data[0] + offsets3[0]; + auto u = c10::load(data[1] + offsets0[1]); + auto v = c10::load(data[2] + offsets0[2]); + ret_t result0 = f(c10::convert(u), c10::convert(v)); + auto u1 = c10::load(data[1] + offsets1[1]); + auto v1 = c10::load(data[2]+ offsets1[2]); + ret_t result1 = f(c10::convert(u1), c10::convert(v1)); + auto u2 = c10::load(data[1] + offsets2[1]); + auto v2 = c10::load(data[2] + offsets2[2]); + ret_t result2 = f(c10::convert(u2), c10::convert(v2)); + auto u3 = c10::load(data[1] + offsets3[1]); + auto v3 = c10::load(data[2] + offsets3[2]); + ret_t result3 = f(c10::convert(u3), c10::convert(v3)); + *(ret_cpp_t*)out0 = c10::convert(result0); + *(ret_cpp_t*)out1 = c10::convert(result1); + *(ret_cpp_t*)out2 = c10::convert(result2); + *(ret_cpp_t*)out3 = c10::convert(result3); + } else { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + auto u = c10::load(data[1] + offsets[1]); + auto v = 
c10::load(data[2] + offsets[2]); + ret_t result = f(c10::convert(u), c10::convert(v)); + *(ret_cpp_t*)out = c10::convert(result); + } + }); + } + } +}; + } // namespace #endif @@ -1002,6 +1065,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { } auto offset_calc = ::make_offset_calculator(iter); #ifdef USE_ROCM + if (check_binary_rt_types_for_specialization(iter)) { + // constexpr to reduce the amount of kernels generated for + // broadcast elementwise with mexed dtypes and limit which functors are actually + // applied to the load and store at compile time. + using func_tuple = typename traits::ArgsTuple; + if constexpr ( + std::is_same_v && traits::arity == 2 && + check_binary_functor_types_for_specialization< + func_tuple, + float, + float, + traits::arity, + /*arg_num=*/0>::check()) { + memory::detail::static_unroll< + type_specialized_broadcast_kernel_launcher, + rt_binary_specializations.size()>::with_args( + numel, + f, + data, + dtypes, + offset_calc + ); + return; + } + } + constexpr int grp_sz = 128; launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { if (unrl) { From 0e1f76f77ecce947d884d769bf0b3d64fbff1c8f Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 6 Nov 2025 21:21:47 -0300 Subject: [PATCH 174/651] Add two new docker images with Python 3.11/3.12 (#167092) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167092 Approved by: https://github.com/malfet, https://github.com/atalman --- .ci/docker/build.sh | 12 ++++++++++++ .github/workflows/docker-builds.yml | 2 ++ 2 files changed, 14 insertions(+) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index f0b9a788758ca..7d55884fbe431 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -168,6 +168,18 @@ case "$tag" in VISION=yes TRITON=yes ;; + pytorch-linux-jammy-py3.11-clang12) + ANACONDA_PYTHON_VERSION=3.11 + CLANG_VERSION=12 + VISION=no + TRITON=no + ;; + pytorch-linux-jammy-py3.12-clang12) + ANACONDA_PYTHON_VERSION=3.12 + CLANG_VERSION=12 + VISION=no + TRITON=no + ;; pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) if [[ $tag =~ "jammy" ]]; then ANACONDA_PYTHON_VERSION=3.10 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 941a045649f3a..0aa176cd1c676 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -56,6 +56,8 @@ jobs: pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9, pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.10-clang12, + pytorch-linux-jammy-py3.11-clang12, + pytorch-linux-jammy-py3.12-clang12, pytorch-linux-jammy-py3.13-clang12, pytorch-linux-jammy-py3.14-clang12, pytorch-linux-jammy-rocm-n-py3, From 1632876edf2cd08a47349a7ba978f3e14ab5b18d Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 02:49:11 +0000 Subject: [PATCH 175/651] [3/N] Use key in dict for existence checks (#167214) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. 
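As a quick illustration of the rewrite being applied throughout (not a snippet from the diff below):

```python
# Illustrative only: the pattern this patch replaces vs. the pattern it introduces.
d = {"weight": 1, "bias": 2}

# Before: `d.keys()` builds a dict_keys view object just to answer a membership query.
if "weight" in d.keys():
    print("found")

# After: membership is tested directly on the dict -- same semantics, no extra view object.
if "weight" in d:
    print("found")
```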
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167214 Approved by: https://github.com/Lucaskabela --- torch/_decomp/decompositions_for_jvp.py | 2 +- .../passes/replace_with_hop_pass_util.py | 2 +- torch/_functorch/functional_call.py | 2 +- torch/_higher_order_ops/utils.py | 2 +- torch/_higher_order_ops/wrap.py | 4 +-- torch/_inductor/codecache.py | 6 ++--- torch/_inductor/codegen/cpp_wrapper_cpu.py | 4 +-- torch/_inductor/codegen/cuda/cuda_kernel.py | 2 +- .../codegen/cuda/cutlass_python_evt.py | 2 +- torch/_inductor/codegen/halide.py | 2 +- torch/_inductor/codegen/mps.py | 4 +-- torch/_inductor/codegen/simd.py | 2 +- torch/_inductor/compile_fx.py | 2 +- torch/_inductor/fx_passes/b2b_gemm.py | 2 +- .../_inductor/fx_passes/group_batch_fusion.py | 8 +++--- torch/_inductor/fx_passes/misc_patterns.py | 2 +- torch/_inductor/fx_passes/numeric_utils.py | 2 +- .../_inductor/fx_passes/overlap_scheduling.py | 2 +- torch/_inductor/fx_passes/split_cat.py | 27 +++++++------------ torch/_inductor/graph.py | 4 +-- torch/_inductor/lowering.py | 10 +++---- torch/_inductor/runtime/triton_heuristics.py | 6 ++--- torch/_inductor/scheduler.py | 2 +- torch/_inductor/sizevars.py | 2 +- torch/_inductor/utils.py | 2 +- torch/_library/infer_schema.py | 2 +- torch/_logging/_internal.py | 2 +- torch/_namedtensor_internals.py | 4 +-- torch/_numpy/testing/utils.py | 2 +- torch/_prims_common/wrappers.py | 4 +-- torch/_weights_only_unpickler.py | 2 +- torch/cuda/_sanitizer.py | 2 +- .../_shard/sharded_tensor/__init__.py | 2 +- .../checkpoint/_consolidate_hf_safetensors.py | 2 +- torch/distributed/checkpoint/state_dict.py | 14 +++++----- torch/distributed/device_mesh.py | 2 +- torch/distributed/distributed_c10d.py | 20 +++++++------- torch/distributed/fsdp/_optim_utils.py | 6 ++--- torch/distributed/fsdp/wrap.py | 2 +- torch/distributed/nn/api/remote_module.py | 2 +- torch/distributed/optim/named_optimizer.py | 2 +- torch/distributed/pipelining/_IR.py | 2 +- torch/distributed/pipelining/_backward.py | 2 +- torch/distributed/rpc/backend_registry.py | 2 +- torch/distributed/rpc/internal.py | 2 +- torch/distributed/tensor/parallel/style.py | 2 +- torch/export/passes/__init__.py | 2 +- torch/export/unflatten.py | 2 +- torch/masked/maskedtensor/core.py | 2 +- .../package/file_structure_representation.py | 2 +- torch/package/package_exporter.py | 2 +- torch/testing/_internal/common_device_type.py | 2 +- torch/testing/_internal/common_fsdp.py | 2 +- .../testing/_internal/common_quantization.py | 2 +- torch/testing/_internal/common_utils.py | 4 +-- .../_internal/distributed/distributed_test.py | 6 ++--- .../_internal/distributed/rpc/rpc_test.py | 2 +- .../_internal/optests/generate_tests.py | 4 +-- 58 files changed, 99 insertions(+), 118 deletions(-) diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index fb4a4d85faa20..dd3b7e7d88992 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -84,7 +84,7 @@ def _register_jit_decomposition_for_jvp(decomp, use_python=False): # Thanks copilot! 
def get_function_def(sig): param_def = [f"{param_str}" for param_str in sig.parameters.values()] - param_use = [f"{param_str}" for param_str in sig.parameters.keys()] + param_use = [f"{param_str}" for param_str in sig.parameters] return f"def wrapped_decomp({', '.join(param_def)}):\n return decomp_fn({', '.join(param_use)})\n" diff --git a/torch/_export/passes/replace_with_hop_pass_util.py b/torch/_export/passes/replace_with_hop_pass_util.py index 6ea3f1adde4f8..862244aac8837 100644 --- a/torch/_export/passes/replace_with_hop_pass_util.py +++ b/torch/_export/passes/replace_with_hop_pass_util.py @@ -71,7 +71,7 @@ def set_hoo_node_meta(call_func_node): # Rename the name of getitem nodes to the actual name of its contents # for passing verifier and better readability, also propagate metadata - for get_item_node in call_func_node.users.keys(): + for get_item_node in call_func_node.users: idx: int = get_item_node.args[1] # type: ignore[assignment] output_node = output_args[idx] get_item_node._rename(output_node.name) diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 55f45c9256962..8e2f943d3e447 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -131,7 +131,7 @@ def compute_loss(params, x, t): raise ValueError( "Expected all elements of parameter_and_buffer_dicts to be dictionaries" ) - all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()] + all_keys = [k for d in parameter_and_buffer_dicts for k in d] all_keys_counter: dict[str, int] = {} for k in all_keys: v = all_keys_counter.get(k, 0) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 160e149fd769f..fad19b1d5ffae 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -337,7 +337,7 @@ def analyze_potential_input_alias_or_mutation(name, aliases, input_mutations): raise RuntimeError( f"{name} where aliases appear. " + f"In particular, these inputs \ - {set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map.keys())} " # noqa: C401 + {set(el for el_map in aliases if len(el_map.keys()) > 0 for el in el_map)} " # noqa: C401 + "get aliased. Please ensure that this doesn't happen." 
) if len(input_mutations): diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index ba6bbe0c39b6b..0dbc378716797 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -228,10 +228,10 @@ def divide_kwargs(kwargs): checkpoint_keys.add("preserve_rng_state") checkpoint_kwargs = { - name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys + name: kwargs[name] for name in kwargs if name in checkpoint_keys } gmod_kwargs = { - name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys + name: kwargs[name] for name in kwargs if name not in checkpoint_keys } return checkpoint_kwargs, gmod_kwargs diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index f36953d2a3337..1d985d6aa35da 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2151,7 +2151,7 @@ def get_zero_consts_asm_code( ) all_cuda = all( graph.get_original_value_of_constant(name).is_cuda - for name in graph.constants.keys() + for name in graph.constants if name not in graph.folded_constants ) @@ -2192,7 +2192,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: ): serialized_weights = b"".join( _to_bytes(graph.get_original_value_of_constant(name), all_cuda) - for name in graph.constants.keys() + for name in graph.constants if name not in graph.folded_constants ) else: @@ -2206,7 +2206,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: graph.get_original_value_of_constant(name), TensorProperties(graph.constants[name]), ) - for name in graph.constants.keys() + for name in graph.constants if name not in graph.folded_constants } ) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 1b994dcf3ffa6..be87044a74e1c 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -631,7 +631,7 @@ def write_wrapper_decl(self): debug_printer_manager.codegen_model_inputs_value_print( input_args_to_print=[ input_key - for input_key in V.graph.graph_inputs.keys() + for input_key in V.graph.graph_inputs if input_key.startswith("arg") ] ) @@ -811,7 +811,7 @@ def codegen_model_constructor(self): all_cuda = all( V.graph.get_original_value_of_constant(name).is_cuda - for name in V.graph.constants.keys() + for name in V.graph.constants if name not in V.graph.folded_constants ) for idx, name in enumerate(V.graph.constants.keys()): diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index a76e77dbe75ae..97643ef00a7bd 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -312,7 +312,7 @@ def def_kernel( size_vars.extend(str(s) for s in free_symbols) self.size_args.extend(free_symbols) size_args = [f"const int {s}" for s in size_vars] - offset_args = [f"const int {name}_offset" for name in self.named_nodes.keys()] + offset_args = [f"const int {name}_offset" for name in self.named_nodes] runtime_arg_decls = ",".join( [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] ) diff --git a/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/torch/_inductor/codegen/cuda/cutlass_python_evt.py index 72108b29b3cb0..e6b7d2afe6c39 100644 --- a/torch/_inductor/codegen/cuda/cutlass_python_evt.py +++ b/torch/_inductor/codegen/cuda/cutlass_python_evt.py @@ -168,7 +168,7 @@ def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]) self.removed_buffers: OrderedSet[str] = removed_buffers self.cur_node: 
Optional[ComputedBuffer] = None self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs - for name in V.graph.constants.keys(): + for name in V.graph.constants: self.name_to_buffer[name] = V.graph.add_tensor_constant( V.graph.constants[name], name ) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 495b9c04f75fc..e47e8e6d7841d 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -906,7 +906,7 @@ def setup_dom_indexing(self): return self.dom_renames[prefix] renames = {} - for var in self.halide_vars.keys(): + for var in self.halide_vars: if not self.inside_reduction and var in self.reduction_renames: continue m = re.match(r"^h(\d+)$", var.name) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 4c668ea194409..8b72a8c97df28 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -955,7 +955,7 @@ def call_kernel( """ wrapper = V.graph.wrapper_code # Make sure sizevars has been computed - for v in self.args.sizevars.keys(): + for v in self.args.sizevars: wrapper.ensure_size_computed(v) _, call_args, _, arg_types = self.args.python_argdefs() @@ -965,7 +965,7 @@ def call_kernel( args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] args = [arg for arg in args if arg not in self.removed_buffers] - args += [str(v) for v in self.args.sizevars.keys()] + args += [str(v) for v in self.args.sizevars] arg_types = [arg_name_to_type[arg] for arg in args] # Add any dynamic ints as inputs diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index f062bf12f1778..24394bc87cf41 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -2052,7 +2052,7 @@ def _codegen_single_template( # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. - for input_name in kernel.named_input_nodes.keys(): + for input_name in kernel.named_input_nodes: subgraph_name = f"" # pyrefly: ignore [missing-attribute] partial_code.finalize_hook(subgraph_name, strict=False) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 8ff19b8721067..55951be231f5f 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -214,7 +214,7 @@ def _fx_compile_mode_default() -> FxCompileConfig: "Invalid value of %s for %s. Expected one of %s. Using default.", value, name, - ", ".join(sorted(repr(x) for x in FxCompileMode.__members__.keys())), + ", ".join(sorted(repr(x) for x in FxCompileMode.__members__)), ) # Remove from the environment so subprocesses don't ALSO complain. 
os.environ.pop(name) diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index 9faec788e9e3a..5a8dc65c08ec4 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -641,7 +641,7 @@ def all_reach_via_pointwise_with_no_other_inputs( if node is dst: visited.add(node) elif (node is src) or is_pointwise_node(node): - for user in node.users.keys(): + for user in node.users: # for nodes other than dst, bookkeep their users' input counts if user not in input_counter: input_counter[user] = len(user.all_input_nodes) diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 295c720382853..f46d4d3ba216f 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -198,7 +198,7 @@ def match(self, node: torch.fx.Node) -> tuple[str, int, int, int, bool, str] | N return None # get the user of the node if self.graph_search_options.get("fuse_nodes_with_same_users", False): - users = [user.target for user in node.users.keys()] + users = [user.target for user in node.users] else: users = "" # type: ignore[assignment] # only handle the cases where inputs are 2D tensors @@ -627,7 +627,7 @@ def match(self, node: torch.fx.Node): weight = get_arg_value(node, 1, "weight") bias = get_arg_value(node, 2, "bias") if self.graph_search_options.get("fuse_nodes_with_same_users", False): - users = [user.target for user in node.users.keys()] + users = [user.target for user in node.users] else: users = "" # type: ignore[assignment] group_key = ( @@ -742,7 +742,7 @@ def match(self, node: torch.fx.Node): weight = get_arg_value(node, 2, "weight") bias = get_arg_value(node, 3, "bias") if self.graph_search_options.get("fuse_nodes_with_same_users", False): - users = [user.target for user in node.users.keys()] + users = [user.target for user in node.users] else: users = "" # type: ignore[assignment] group_key = ( @@ -1425,7 +1425,7 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): } non_fbgemm_fusions = { fusion: config.post_grad_fusion_options[fusion] - for fusion in config.post_grad_fusion_options.keys() + for fusion in config.post_grad_fusion_options if fusion not in fbgemm_fusion_keys } fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) diff --git a/torch/_inductor/fx_passes/misc_patterns.py b/torch/_inductor/fx_passes/misc_patterns.py index 2159e8811ad9e..ff0981e72e8b2 100644 --- a/torch/_inductor/fx_passes/misc_patterns.py +++ b/torch/_inductor/fx_passes/misc_patterns.py @@ -113,7 +113,7 @@ def __call__(self, graph: torch.fx.Graph): signatures = () if signatures is None else signatures replaceable_kwargs = OrderedSet() for sig in signatures: - for param_name in sig.parameters.keys(): + for param_name in sig.parameters: if param_name in self.numpy_compat: replaceable_kwargs.update(self.numpy_compat[param_name]) diff --git a/torch/_inductor/fx_passes/numeric_utils.py b/torch/_inductor/fx_passes/numeric_utils.py index b50859448f072..d1db82f21f7ec 100644 --- a/torch/_inductor/fx_passes/numeric_utils.py +++ b/torch/_inductor/fx_passes/numeric_utils.py @@ -49,7 +49,7 @@ def compare_dict_tensors(dict_base, dict_control, precision): logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) return False is_allclose = True - for key in dict_base.keys(): + for key in dict_base: if key not in dict_control: logger.warning( "Mismatch parameter name %s does not exist after pre/post grad 
fx passes", diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index f383ab63dc261..80ef2a95139a3 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -736,7 +736,7 @@ def should_assume_bucketed(self, node: fx.Node) -> bool: if key is None: return False - for in_flight_coll in self.in_flight.keys(): + for in_flight_coll in self.in_flight: if bucket_key(in_flight_coll, mode="custom_ops_multidtype") == key: return True diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 0bad4fa7cc635..6347bda3b525c 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -404,8 +404,8 @@ def normalize_stack_default(match: Match, *args, **kwargs): def find_next_users(split_node: torch.fx.Node) -> list[torch.fx.Node]: next_users = [] - for getitem_node in split_node.users.keys(): - for getitem_user in getitem_node.users.keys(): + for getitem_node in split_node.users: + for getitem_user in getitem_node.users: if getitem_user not in next_users: next_users.append(getitem_user) return next_users @@ -623,7 +623,7 @@ def merge_splits( ) first_split_num_to_user = { user.args[1]: user - for user in first_split.users.keys() # type: ignore[union-attr] + for user in first_split.users # type: ignore[union-attr] } new_split_num = 0 @@ -637,9 +637,7 @@ def merge_splits( old_getitem.update_arg(1, new_split_num) new_split_num += 1 else: - next_split_num_to_user = { - user.args[1]: user for user in node.users.keys() - } + next_split_num_to_user = {user.args[1]: user for user in node.users} # It is not necessary all getitems from the split node are used. for next_split_num in range(len(next_split_sections)): with graph.inserting_after(new_split): @@ -1160,9 +1158,7 @@ def remove_unbind( return # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node # before we do the unbind remove, otherwise it will hit the error when we unbind part of them - getitem_indices = [ - getitem_node.args[1] for getitem_node in unbind_node.users.keys() - ] + getitem_indices = [getitem_node.args[1] for getitem_node in unbind_node.users] if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] getitem_indices ) != len(unbind_node.meta["example_value"]): @@ -1314,10 +1310,7 @@ def merge_split_squeeze( split_input.meta["example_value"], dim=dim ) for item_index, getitem_node in sorted( - [ - (getitem_node.args[1], getitem_node) - for getitem_node in split.users.keys() - ] + [(getitem_node.args[1], getitem_node) for getitem_node in split.users] ): squeeze = next(iter(getitem_node.users.keys())) new_get_item = graph.call_function( @@ -2753,14 +2746,12 @@ def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]: # cat_arg must be the split input view_shape_list = [] - for user in cat_arg.users.keys(): + for user in cat_arg.users: if user.target is torch.split: - for getitem in user.users.keys(): + for getitem in user.users: if getitem.target is operator.getitem: reshape_user = [ - user - for user in getitem.users.keys() - if user.target is torch.reshape + user for user in getitem.users if user.target is torch.reshape ] if len(reshape_user) > 0: view_shape_list = list( diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 28e7f88d33986..3f71a53d3fa65 100644 --- 
a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1936,7 +1936,7 @@ def format_new_defs() -> str: # we already know facts for. renamed_unbacked_bindings = OrderedSet( V.fake_mode.shape_env.unbacked_renamings.get(s, s) - for s in unbacked_bindings.keys() + for s in unbacked_bindings ) assert new_unbacked_defs >= renamed_unbacked_bindings, ( @@ -2481,7 +2481,7 @@ def is_unspec_arg(self, name: str) -> bool: # dynamo wraps unspec variable as 0d CPU tensor, # need to convert to scalar during codegen (triton only) return ( - name in self.graph_inputs.keys() + name in self.graph_inputs and self.graph_inputs[name].get_numel() == 1 and len(self.graph_inputs[name].get_size()) == 0 and get_device_type(self.graph_inputs[name]) == "cpu" diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index f6ad1028ca12d..2df224caf61a9 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -482,9 +482,7 @@ def wrapped(*args, **kwargs): (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn ): # explicitly assert for "out=" ops for better error messages - assert not any(x == "out" for x in kwargs.keys()), ( - "out= ops aren't yet supported" - ) + assert not any(x == "out" for x in kwargs), "out= ops aren't yet supported" args, kwargs = transform_args( args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool @@ -2705,9 +2703,7 @@ def constrain_to_fake_tensor(arg, fake_arg): ] return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr) if isinstance(arg, dict): - return { - key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg.keys() - } + return {key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg} elif isinstance(arg, (tuple, list)): return type(arg)( constrain_to_fake_tensor(a, f_a) for (a, f_a) in zip(arg, fake_arg) @@ -2732,7 +2728,7 @@ def apply_constraint(arg, fx_arg): ) return ir.ExternKernel.require_stride_order(arg, stride_order) if isinstance(arg, dict): - return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()} + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg} return arg args = tuple( diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index b38cdcb71fa23..363d62d02303d 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1425,7 +1425,7 @@ def filtered_signature() -> list[str]: # These are torch compiled triton kernels that definitely # have block size configs. 
Dynamo does not currently # trace user defined triton kernels when TRITON_INTERPRET=1 - if x not in cfg.kwargs.keys(): + if x not in cfg.kwargs: new_signature.append(x) elif i not in get_constexprs(self.fn): # use constexprs rather than just configs since user @@ -2562,7 +2562,7 @@ def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Conf } assert all( - block_type in configs[0].kwargs for block_type in tma_min_block_sizes.keys() + block_type in configs[0].kwargs for block_type in tma_min_block_sizes ) # Add a config that is guaranteed to compile @@ -3199,7 +3199,7 @@ def reduction( assert triton_meta is not None num_dynamic = 0 - for k in triton_meta["signature"].keys(): + for k in triton_meta["signature"]: if "ks" in k: num_dynamic += 1 diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index d7e3ed5a529d1..020067c83999c 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2899,7 +2899,7 @@ def __add__(self, other: DedupList[_T]) -> DedupList[_T]: list1 = name_to_users[buf1_name] list2 = name_to_users[buf2_name] combined = list1 + list2 - for key in name_to_users.keys(): + for key in name_to_users: if ( name_to_users[key] is list1 or name_to_users[key] is list2 diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 35313f472f430..77526a38aeb37 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -882,7 +882,7 @@ def _choose(x: int, y: int) -> bool: # Start building the unbacked replacements mapping using CanonicalExprFinder # The mapping is from Expr to its "canonical" Expr. self.unbacked_replacements = {} - for expr in self.equality_graph.keys(): + for expr in self.equality_graph: canonical_expr = uf.find_expr(expr) if expr != canonical_expr: self.unbacked_replacements[expr] = canonical_expr diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 1a43e938d7146..2b7a9541aa875 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1226,7 +1226,7 @@ def unload_xpu_triton_pyds() -> None: if not module_name.startswith("torch._inductor.runtime.compile_tasks."): continue m = sys.modules[module_name] - for attr_name in m.__dict__.keys(): + for attr_name in m.__dict__: if attr_name.startswith("triton_"): kernel = getattr(m, attr_name) if isinstance( diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index cb3cfd1d6029f..8c10a23dab881 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -142,7 +142,7 @@ def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]: list_type = tuple_to_list(annotation_type) example_type_str = "\n\n" # Only suggest the list type if this type is supported. - if list_type in SUPPORTED_PARAM_TYPES.keys(): + if list_type in SUPPORTED_PARAM_TYPES: example_type_str = f"For example, {list_type}.\n\n" error_fn( f"Parameter {name} has unsupported type {param.annotation}. 
" diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 04298b7cdac84..93e2d8dc29b0c 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -498,7 +498,7 @@ def _set_logs(**kwargs) -> None: if val not in logging._levelToName: raise ValueError( f"Unrecognized log level for log {alias}: {val}, valid level values " - f"are: {','.join([str(k) for k in logging._levelToName.keys()])}" + f"are: {','.join([str(k) for k in logging._levelToName])}" ) log_state.enable_log( diff --git a/torch/_namedtensor_internals.py b/torch/_namedtensor_internals.py index 16d04f181525d..b0fa6a206fac3 100644 --- a/torch/_namedtensor_internals.py +++ b/torch/_namedtensor_internals.py @@ -93,9 +93,9 @@ def update_names_with_list(tensor, names, inplace): def update_names_with_mapping(tensor, rename_map, inplace): dim_map = build_dim_map(tensor) - for old_dim in rename_map.keys(): + for old_dim in rename_map: new_dim = rename_map[old_dim] - if old_dim in dim_map.keys(): + if old_dim in dim_map: dim_map[old_dim] = new_dim else: raise RuntimeError( diff --git a/torch/_numpy/testing/utils.py b/torch/_numpy/testing/utils.py index d43f63f10388c..ffc027043b6f5 100644 --- a/torch/_numpy/testing/utils.py +++ b/torch/_numpy/testing/utils.py @@ -207,7 +207,7 @@ def assert_equal(actual, desired, err_msg="", verbose=True): if not isinstance(actual, dict): raise AssertionError(repr(type(actual))) assert_equal(len(actual), len(desired), err_msg, verbose) - for k in desired.keys(): + for k in desired: if k not in actual: raise AssertionError(repr(k)) assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 941fb6ee68e84..e369481c1044b 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -133,7 +133,7 @@ def _fn(*args, **kwargs): type_promoting_args = tuple( bound.arguments[x] for x in self.type_promoting_arg_names # type: ignore[union-attr] - if x in bound.arguments.keys() + if x in bound.arguments ) flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args) @@ -145,7 +145,7 @@ def _fn(*args, **kwargs): promoted_args = { x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype) for x in self.type_promoting_arg_names # type: ignore[union-attr] - if x in bound.arguments.keys() + if x in bound.arguments } bound.arguments.update(promoted_args) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 1ac9d2046f242..5aaa77b25697a 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -187,7 +187,7 @@ def _get_allowed_globals(): } # dtype - for t in torch.storage._dtype_to_storage_type_map().keys(): + for t in torch.storage._dtype_to_storage_type_map(): rc[str(t)] = t for t in torch.storage._new_dtypes(): rc[str(t)] = t diff --git a/torch/cuda/_sanitizer.py b/torch/cuda/_sanitizer.py index 90953d888d6c2..8f215a730923b 100644 --- a/torch/cuda/_sanitizer.py +++ b/torch/cuda/_sanitizer.py @@ -303,7 +303,7 @@ def stream_wait_for_event(self, stream: StreamId, event: EventId) -> None: def all_streams_wait_for_event(self, event: EventId) -> None: self._ensure_event_exists(event) - for stream in self.current_sync_states.keys(): + for stream in self.current_sync_states: self.stream_wait_for_event(stream, event) self._state_wait_for_other( diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index e1e9983d52628..3d3af3ed35953 
100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -437,7 +437,7 @@ def pre_load_state_dict_hook( Pre-load state dict hook to add ShardedTensor to the module. """ for submodule_name, submodule in module.named_modules(): - for attr_name in submodule.__dict__.keys(): + for attr_name in submodule.__dict__: mod_prefix = prefix + submodule_name key = mod_prefix + ("." if mod_prefix else "") + attr_name if key in state_dict: diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 9d70ab7c7400d..32d81fb1ea721 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -261,7 +261,7 @@ def _process_output_file( file_metadata = input_files_data[safetensors_file].metadata input_metadata_size = input_files_data[safetensors_file].metadata_size - if tensor_fqn not in file_metadata.keys(): + if tensor_fqn not in file_metadata: continue metadata = file_metadata[tensor_fqn] diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 54a29c0bb3588..6a31144348acb 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -521,7 +521,7 @@ def verify(key, fqn) -> bool: if info.submodule_prefixes: new_state_dict: dict[str, ValueType] = {} # TODO: make this faster. - for fqn in state_dict.keys(): + for fqn in state_dict: for prefix in info.submodule_prefixes: if not fqn.startswith(prefix): continue @@ -826,7 +826,7 @@ def _reconstruct_nested_dict( # the state_dict. if fqn in info.shared_params_mapping: in_params = False - for k in param_group.keys(): + for k in param_group: if k == _PARAMS: continue flatten_key = f"{_PG}.{fqn}.{k}" @@ -850,7 +850,7 @@ def _reconstruct_nested_dict( # Reconstruct state for this parameter state[fqn] = {} - for state_name in optim.state[param].keys(): + for state_name in optim.state[param]: flattened_state_key = f"{_STATE}.{fqn}.{state_name}" if flattened_state_key not in state_dict: @@ -868,7 +868,7 @@ def _reconstruct_nested_dict( ] first_param_fqn = cast(list[str], pg_state[-1][_PARAMS])[0] - for k in param_group.keys(): + for k in param_group: if k == _PARAMS: continue value = state_dict[f"{_PG}.{first_param_fqn}.{k}"] @@ -980,9 +980,7 @@ def _split_optim_state_dict( return_osd: OptimizerStateType = {_STATE: state, _PG: pg_state} pg_mapping: dict[int, int] = {} - if all( - isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE]).keys() - ): + if all(isinstance(k, int) for k in cast(DictValueType, optim_state_dict[_STATE])): return optim_state_dict for param_group in optim.param_groups: @@ -1139,7 +1137,7 @@ def _device(t): # dissimilar parameters in comparison to optim_state_dict. This is achieved by # incorporating differential parameters within local, which may result in optim # having additional parameters ultimately. 
- for optim_key in flatten_osd.keys(): + for optim_key in flatten_osd: if optim_key not in flatten_local_osd: if optim_key not in osd_mapping: raise AssertionError( diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index a161a4394a93d..9a3a0c8883bbc 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -644,7 +644,7 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: root_mesh = self._get_root_mesh() root_to_flatten_mapping = root_mesh._flatten_mapping - if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys(): + if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping: dim_group_name = root_to_flatten_mapping[ mesh_dim # type: ignore[index] ]._dim_group_names[0] diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 415cbacc177a8..801716e3855ac 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -700,7 +700,7 @@ def pg_config_info(self) -> list[dict[str, Any]]: """ config_info: list[dict[str, Any]] = [] default_pg_size = _get_group_size(None) - for pg in self.pg_map.keys(): + for pg in self.pg_map: ranks = self.pg_group_ranks[pg] config_info.append( { @@ -1461,9 +1461,7 @@ def _get_all_pg_configs() -> list[dict[str, Any]]: Return the pg configuration of all the process groups. """ - config_info: list[dict[str, Any]] = [ - _get_pg_config(pg) for pg in _world.pg_map.keys() - ] + config_info: list[dict[str, Any]] = [_get_pg_config(pg) for pg in _world.pg_map] return config_info @@ -1520,7 +1518,7 @@ def _add_ephemeral_timeout_for_all_pgs(timeout: timedelta) -> None: Returns: None. """ - for pg in _world.pg_map.keys(): + for pg in _world.pg_map: devices = pg._device_types if torch.device("cuda") in devices: backend = pg._get_backend(torch.device("cuda")) @@ -2180,7 +2178,7 @@ def _new_process_group_helper( # register only a single backend when all get_device_backend_map values are the same if len(set(backend_config.get_device_backend_map().values())) == 1: - for device in backend_config.get_device_backend_map().keys(): + for device in backend_config.get_device_backend_map(): pg._register_backend(torch.device(device), backend_type, backend_class) # break out of outer loop to not create any more backends @@ -2287,7 +2285,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): del _world.pg_names[pg] del _world.pg_group_ranks[pg] del _world.pg_backend_config[pg] - if pg in _world.pg_coalesce_state.keys(): + if pg in _world.pg_coalesce_state: warnings.warn( "Some coalesced collectives haven't been launched when " "ProcessGroup is destroyed. They will be cleaned.", @@ -2379,7 +2377,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None): del _world.pg_names[pg] del _world.pg_group_ranks[pg] del _world.pg_backend_config[pg] - if pg in _world.pg_coalesce_state.keys(): + if pg in _world.pg_coalesce_state: warnings.warn( "Some coalesced collectives haven't been launched when " "ProcessGroup is aborted. 
They will be cleaned.", @@ -2994,7 +2992,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): if group is None: group = _get_default_group() - if group in _world.pg_coalesce_state.keys(): + if group in _world.pg_coalesce_state: # We are in coalescing context, do not issue single operation, just append a collective representation coll = _CollOp(all_reduce, tensor, None, op, None) _world.pg_coalesce_state[group].append(coll) @@ -4112,7 +4110,7 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal group = group or _get_default_group() - if group in _world.pg_coalesce_state.keys(): + if group in _world.pg_coalesce_state: # We are in coalescing context, do not issue single operation, just append a collective representation coll = _CollOp(all_gather_into_tensor, input_tensor, output_tensor) _world.pg_coalesce_state[group].append(coll) @@ -4577,7 +4575,7 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F # Check if we are in coalescing context # If we are, do not issue single operation, just append a collective representation - if group in _world.pg_coalesce_state.keys(): + if group in _world.pg_coalesce_state: coll = _CollOp(reduce_scatter_tensor, input, output, op, None) _world.pg_coalesce_state[group].append(coll) if async_op: diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 96657eeea4106..564cfeece48ee 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -469,7 +469,7 @@ def _flatten_optim_state_dict( for fqn in fqns: if not unflat_osd_state[fqn]: continue - for state_name in unflat_osd_state[fqn].keys(): + for state_name in unflat_osd_state[fqn]: unflat_osd_state[fqn][state_name] = _broadcast_state( fsdp_state, unflat_osd_state[fqn][state_name], group=group ) @@ -1377,9 +1377,7 @@ def _convert_all_state_info( for fqn, gathered_state in output_states.items(): state_info = [s[fqn] for s in gathered_state_info] - all_tensor_states = sorted( - {n for state in state_info for n in state.tensors.keys()} - ) + all_tensor_states = sorted({n for state in state_info for n in state.tensors}) empty_ranks: set[int] = set() dtype: Optional[torch.dtype] = None # First check all the non-scalar states and get the information of diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index f0a210eca8a6b..f731854dab2eb 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -586,7 +586,7 @@ def enable_autowrap_context(kwargs: Any) -> None: ) _ConfigAutoWrap.in_autowrap_context = True # Get and save the wrapper cls for the context. - if "wrapper_cls" not in kwargs.keys(): + if "wrapper_cls" not in kwargs: raise AssertionError( "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." 
) diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 066197fad24a7..d2db28d4371de 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -493,7 +493,7 @@ def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda def _check_attribute_picklability(self): """Check if all the attribute has explicitly defined whether to be pickled (i.e., picklability).""" - for k in self.__dict__.keys(): + for k in self.__dict__: if ( k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING diff --git a/torch/distributed/optim/named_optimizer.py b/torch/distributed/optim/named_optimizer.py index b5135ae5411ef..c2384dabd9dad 100644 --- a/torch/distributed/optim/named_optimizer.py +++ b/torch/distributed/optim/named_optimizer.py @@ -203,7 +203,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: for idx, param_key in enumerate(self.ordered_param_keys): # When the conditional training is performed, not all parameters are updated in the optim. - if param_key not in state.keys(): + if param_key not in state: continue if len(state[param_key]) != len(new_state[idx]): raise ValueError( diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 62e3764abe055..120c717755c78 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -108,7 +108,7 @@ def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec) generated_spec = TrivialLossWrapper.loss_spec elif output_loss_value_spec is None: # Use default spec, i.e. search for "loss" in output values - if isinstance(output_val, dict) and "loss" in output_val.keys(): + if isinstance(output_val, dict) and "loss" in output_val: loss_node = output_val["loss"] generated_spec = {k: k == "loss" for k in output_val} else: diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index 38d30c793e89d..e34460449e1e0 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -337,7 +337,7 @@ def extract_tensors_with_grads( return assert isinstance(grad_val, dict) assert set(output_val.keys()) == set(grad_val.keys()) - for k in output_val.keys(): + for k in output_val: extract_tensors_with_grads( output_val[k], grad_val[k], extract_tensors_with_grads ) diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 16299404c6b65..3f30252bd8256 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -58,7 +58,7 @@ def backend_registered(backend_name): True if the backend has been registered with ``register_backend``, else False. """ - return backend_name in BackendType.__members__.keys() + return backend_name in BackendType.__members__ def register_backend( diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index c830fc11d8edd..faef8afddfc2c 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -122,7 +122,7 @@ def serialize(self, obj): p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index] # Install customized picklers. 
- for class_name in self._class_reducer_dict.keys(): + for class_name in self._class_reducer_dict: p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index] # save _thread_local_tensor_tables.send_tables if it is in nested call diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index 032179bafa3eb..182a3fbcafebf 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -560,7 +560,7 @@ def _prepare_input_fn(self, inputs, device_mesh): def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) prepared_kwarg_inputs = {} - for kwarg_key in kwarg_inputs.keys(): + for kwarg_key in kwarg_inputs: kwarg_val = kwarg_inputs[kwarg_key] input_layout = self.input_kwarg_layouts.get(kwarg_key) desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index 90430608cab21..9de0bea443920 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -30,7 +30,7 @@ def _get_new_device( location: Union[torch.device, str, dict[str, str]], ) -> str: if isinstance(location, dict): - if str(curr_device) in location.keys(): + if str(curr_device) in location: return location[str(curr_device)] else: return str(curr_device) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 3701ba99047fb..a3f86fabceb7b 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1331,7 +1331,7 @@ def get_actual_output_node(output): else: graph_outputs = [] # Iterate through nodes we have copied into self.graph. - for orig_node in self.node_map.keys(): + for orig_node in self.node_map: for user_node in orig_node.users: if user_node.name not in self.seen_nodes: # external user node, need to expose as an output diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 111680c1f019e..cad5621b29bd6 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -88,7 +88,7 @@ def _helper(a, map_fn): for a in args: impl_args.append(_helper(a, map_fn)) impl_kwargs = {} - for k in kwargs.keys(): + for k in kwargs: impl_kwargs[k] = _helper(a, map_fn) return impl_args, impl_kwargs diff --git a/torch/package/file_structure_representation.py b/torch/package/file_structure_representation.py index 8ef00e0159d8b..2dae130ed6007 100644 --- a/torch/package/file_structure_representation.py +++ b/torch/package/file_structure_representation.py @@ -55,7 +55,7 @@ def has_file(self, filename: str) -> bool: lineage = filename.split("/", maxsplit=1) child = lineage[0] grandchildren = lineage[1] if len(lineage) > 1 else None - if child in self.children.keys(): + if child in self.children: if grandchildren is None: return True else: diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index b25ebca23095f..cea4335f75a70 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -1157,7 +1157,7 @@ def get_rdeps(self, module_name: str) -> list[str]: Returns: A list containing the names of modules which depend on ``module_name``. 
""" - if module_name in self.dependency_graph._pred.keys(): + if module_name in self.dependency_graph._pred: return list(self.dependency_graph._pred[module_name].keys()) else: return [] diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index c31d7a54b65a1..9acc6f0f75676 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1516,7 +1516,7 @@ def __init__(self, d): assert isinstance(d, dict), ( "precisionOverride not given a dtype : precision dict!" ) - for dtype in d.keys(): + for dtype in d: assert isinstance(dtype, torch.dtype), ( f"precisionOverride given unknown dtype {dtype}" ) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 36c72f1d5c3be..74b3cdc78f2d9 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -209,7 +209,7 @@ def _broadcast_state_dict(rank, state_dict): dist.broadcast_object_list(olist) state_dict = cast(dict[str, torch.Tensor], olist[0]) # Ensure that the state is on DEVICE - for param_name in state_dict.keys(): + for param_name in state_dict: state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE) return state_dict diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index c88f7ad45c7ea..5f4fab8c48bbd 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -750,7 +750,7 @@ def is_leaf_module(module): and not isinstance(module, torch.nn.Sequential) and type(module) in propagate_qconfig_list ) - or type(module) in float_to_observed_module_class_mapping.keys() + or type(module) in float_to_observed_module_class_mapping ) and not isinstance(module, torch.ao.quantization.DeQuantStub) ): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 06bbd329d3450..8f4f8efc4108e 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2833,7 +2833,7 @@ def matches_test(target: str): # parametrized ones (TestSuite disables TestSuiteCPU) return classname.startswith(target_classname) and (target_testname in (test._testMethodName, sanitized_testname)) - if any(matches_test(x) for x in slow_tests_dict.keys()): + if any(matches_test(x) for x in slow_tests_dict): getattr(test, test._testMethodName).__dict__['slow_test'] = True if not TEST_WITH_SLOW: raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test") @@ -3000,7 +3000,7 @@ def _to_number(self, number_like, *, id): return int(number_like) # type: ignore[call-overload] else: number = super()._to_number(number_like, id=id) - if type(number) not in self._TYPE_TO_DTYPE.keys(): + if type(number) not in self._TYPE_TO_DTYPE: self._inputs_not_supported() return number diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 8cb9c929d8545..503e15af4bb3e 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -7486,7 +7486,7 @@ def forward(self, x, rank): # iterate offset//2 more times than rank 0, to test nodes # depleting inputs at different times. 
if num_early_join_ranks > 1: - for rank in mapping.keys(): + for rank in mapping: if rank > 0: mapping[rank] += offset // 2 mapping.update( @@ -7888,8 +7888,8 @@ def custom_type_validator(x): return x.t def dict_validator(x): - self.assertTrue(EXPECTED_FIELDS[0] in x.keys()) - self.assertTrue(EXPECTED_FIELDS[1] in x.keys()) + self.assertTrue(EXPECTED_FIELDS[0] in x) + self.assertTrue(EXPECTED_FIELDS[1] in x) self.assertEqual(1, len({t.device for t in x.values()})) self.assertEqual(x[EXPECTED_FIELDS[0]].device.index, self.rank) return x[EXPECTED_FIELDS[0]] + x[EXPECTED_FIELDS[1]] diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 21464e514742c..14d16281c14e2 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3285,7 +3285,7 @@ def test_debug_info(self): for key in expected: self.assertIn(key, info.keys()) - for key in info.keys(): + for key in info: self.assertIn(key, expected.keys()) @dist_init(setup_rpc=False) diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index 17f7e27d67463..398425853f09a 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -496,7 +496,7 @@ def __init__( def maybe_raise_errors_on_exit(self) -> None: # Check expected failures first - for qualname in self.seen_ops_to_errors.keys(): + for qualname in self.seen_ops_to_errors: option = self.failures_dict.get_status(qualname, self.test_name) if len(self.seen_ops_to_errors[qualname]) == 0: if should_update_failures_dict(): @@ -518,7 +518,7 @@ def maybe_raise_errors_on_exit(self) -> None: ) continue failed_ops = [] - for qualname in self.seen_ops_to_errors.keys(): + for qualname in self.seen_ops_to_errors: option = self.failures_dict.get_status(qualname, self.test_name) if option != "xsuccess": continue From a913b2bb938e6a8ccfa581fb24dc3f7fbaac7bd3 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 03:22:57 +0000 Subject: [PATCH 176/651] [2/N] Add return types of Python functions (#167203) This PR adds return types of some Python functions. Most of them return `None`. The types were added automatically by ruff ANN rules. 
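The change is purely annotation-level. A minimal sketch of the pattern (hypothetical function, not a hunk from this diff):

    # before
    def log_result(results: dict, key: str, value: float):
        results[key] = value

    # after: the function returns nothing, so it gains an explicit `-> None`
    def log_result(results: dict, key: str, value: float) -> None:
        results[key] = value

Functions that do return a value get the concrete type instead (e.g. `-> bool`, `-> int`, `-> str` in the hunks below); bodies are left untouched, so the diff has no runtime effect.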
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167203 Approved by: https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan --- torch/utils/_cpp_extension_versioner.py | 2 +- torch/utils/_debug_mode.py | 30 ++++----- torch/utils/_device.py | 2 +- torch/utils/_ordered_set.py | 2 +- torch/utils/_python_dispatch.py | 14 ++-- .../_strobelight/cli_function_profiler.py | 2 +- .../examples/cli_function_profiler_example.py | 4 +- torch/utils/_sympy/functions.py | 10 +-- torch/utils/_sympy/numbers.py | 4 +- torch/utils/_sympy/singleton_int.py | 2 +- torch/utils/_thunk.py | 2 +- torch/utils/_traceback.py | 6 +- torch/utils/_zip.py | 2 +- torch/utils/backcompat/__init__.py | 4 +- torch/utils/backend_registration.py | 22 +++---- torch/utils/benchmark/examples/compare.py | 4 +- torch/utils/benchmark/examples/fuzzer.py | 2 +- .../utils/benchmark/examples/op_benchmark.py | 6 +- .../benchmark/examples/sparse/compare.py | 6 +- .../utils/benchmark/examples/sparse/fuzzer.py | 2 +- .../benchmark/examples/sparse/op_benchmark.py | 6 +- .../examples/spectral_ops_fuzz_test.py | 2 +- torch/utils/benchmark/op_fuzzers/binary.py | 2 +- .../benchmark/op_fuzzers/sparse_binary.py | 2 +- torch/utils/benchmark/op_fuzzers/spectral.py | 2 +- torch/utils/benchmark/op_fuzzers/unary.py | 2 +- torch/utils/benchmark/utils/compare.py | 22 +++---- torch/utils/benchmark/utils/compile.py | 4 +- torch/utils/benchmark/utils/fuzzer.py | 12 ++-- torch/utils/benchmark/utils/sparse_fuzzer.py | 2 +- .../utils/valgrind_wrapper/timer_interface.py | 2 +- torch/utils/collect_env.py | 2 +- torch/utils/cpp_extension.py | 36 +++++------ torch/utils/data/_utils/pin_memory.py | 4 +- torch/utils/data/_utils/signal_handling.py | 4 +- torch/utils/data/_utils/worker.py | 12 ++-- torch/utils/data/backward_compatibility.py | 2 +- torch/utils/data/dataframes_pipes.ipynb | 2 +- torch/utils/data/datapipes/_hook_iterator.py | 6 +- torch/utils/data/datapipes/_typing.py | 10 +-- .../datapipes/dataframe/dataframe_wrapper.py | 2 +- .../data/datapipes/dataframe/datapipes.py | 12 ++-- torch/utils/data/datapipes/datapipe.py | 32 +++++----- torch/utils/data/datapipes/gen_pyi.py | 2 +- .../data/datapipes/iter/combinatorics.py | 2 +- torch/utils/data/datapipes/iter/combining.py | 28 ++++---- torch/utils/data/datapipes/iter/fileopener.py | 2 +- .../utils/data/datapipes/iter/streamreader.py | 2 +- torch/utils/data/datapipes/map/combining.py | 2 +- torch/utils/data/datapipes/utils/decoder.py | 6 +- torch/utils/data/dataset.py | 10 +-- torch/utils/data/graph.py | 2 +- torch/utils/data/graph_settings.py | 2 +- torch/utils/data/standard_pipes.ipynb | 2 +- torch/utils/data/typing.ipynb | 14 ++-- torch/utils/file_baton.py | 8 +-- torch/utils/flop_counter.py | 8 +-- torch/utils/hipify/hipify_python.py | 32 +++++----- torch/utils/hooks.py | 10 +-- torch/utils/mkldnn.py | 14 ++-- torch/utils/model_dump/__init__.py | 2 +- torch/utils/module_tracker.py | 12 ++-- torch/utils/show_pickle.py | 16 ++--- torch/utils/tensorboard/_embedding.py | 8 +-- torch/utils/tensorboard/_pytorch_graph.py | 20 +++--- torch/utils/tensorboard/summary.py | 2 +- torch/utils/tensorboard/writer.py | 64 +++++++++---------- torch/utils/throughput_benchmark.py | 10 +-- torch/utils/viz/_cycles.py | 32 +++++----- torch/utils/weak.py | 28 ++++---- torch/xpu/__init__.py | 20 +++--- torch/xpu/random.py | 10 +-- torch/xpu/streams.py | 4 +- 73 files changed, 342 insertions(+), 342 deletions(-) diff --git a/torch/utils/_cpp_extension_versioner.py b/torch/utils/_cpp_extension_versioner.py index 
2997f90d7c89d..d1391dd9aaab0 100644 --- a/torch/utils/_cpp_extension_versioner.py +++ b/torch/utils/_cpp_extension_versioner.py @@ -27,7 +27,7 @@ def hash_build_arguments(hash_value, build_arguments): class ExtensionVersioner: - def __init__(self): + def __init__(self) -> None: self.entries = {} def get_version(self, name): diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5a6ee246abf7e..5c8bc9221a957 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -60,7 +60,7 @@ def _stringify_dtensor_spec(spec) -> str: class TensorIdTracker: - def __init__(self): + def __init__(self) -> None: self.tensor_memo: dict[WeakIdRef, int] = {} self.next_tensor_id = 0 @@ -68,7 +68,7 @@ def _id(self, tensor) -> int: with torch._C._DisablePythonDispatcher(): o = WeakIdRef(tensor) - def del_memo(): + def del_memo() -> None: self.tensor_memo.pop(o, None) weakref.finalize(tensor, del_memo) @@ -157,7 +157,7 @@ def __init__( record: Optional[dict[str, Any]] = None, log: Optional[dict[str, Any]] = None, stack: bool = False, - ): + ) -> None: self.call_depth = call_depth if stack: self.stack_trace = _get_stack_trace() @@ -207,7 +207,7 @@ def __init__( kwargs: dict, call_depth: int, stack: bool = False, - ): + ) -> None: super().__init__(call_depth, stack=stack) self.op = op self.args = args @@ -282,7 +282,7 @@ def __init__( transform_info_str, call_depth, stack=False, - ): + ) -> None: super().__init__(call_depth, stack=stack) self.arg = arg self.src_placement = src_placement @@ -334,7 +334,7 @@ def __iter__(self): class _NNModuleCall(_DebugCall): """Designates entering an nn.Module's forward method""" - def __init__(self, module_name: str, call_depth: int, stack: bool = False): + def __init__(self, module_name: str, call_depth: int, stack: bool = False) -> None: super().__init__(call_depth, stack=stack) self.module_name = module_name @@ -395,7 +395,7 @@ def __init__( record_stack_trace=False, record_output=False, record_ids=False, - ): + ) -> None: super().__init__() import torch.distributed.tensor # noqa: F401 @@ -440,13 +440,13 @@ def __init__( self.reset() - def reset(self): + def reset(self) -> None: self.operators = [] self.call_depth = 0 self._tensor_memo = TensorIdTracker() self._output_info: dict[int, object] = {} - def _track_op_output(self, op_index, result): + def _track_op_output(self, op_index, result) -> None: """Assign IDs to output tensors and store in output_info""" # self._track_tensor_ids(result) self._output_info[op_index] = result @@ -455,10 +455,10 @@ def _track_op_output(self, op_index, result): # will force torch.compile to always use the “eager” backend # With this, DebugMode will not take effect on torch.compile @classmethod - def ignore_compile_internals(cls): + def ignore_compile_internals(cls) -> bool: return True - def _record_call(self, call): + def _record_call(self, call) -> None: if not self.store_original_args: call.stringify_args( self.record_tensor_attributes, @@ -466,7 +466,7 @@ def _record_call(self, call): ) self.operators.append(call) - def _record_call_output(self, call, output): + def _record_call_output(self, call, output) -> None: if not self.record_output: return call.stringify_output( @@ -562,19 +562,19 @@ def __exit__(self, *args): if self.record_stack_trace: self.anomaly_for_traces.__exit__(*args) - def module_tracker_setup(self): + def module_tracker_setup(self) -> None: from torch.distributed._tools.mod_tracker import ModTracker self.module_tracker = ModTracker() # module pre-fw hook: record module call - def 
pre_fw_hook(module, input): + def pre_fw_hook(module, input) -> None: fqn = self.module_tracker._get_mod_name(module) # type: ignore[attribute, union-attr] self.operators.append(_NNModuleCall(fqn, self.call_depth + 1)) self.call_depth += 1 # module post-fw hook: decrement call depth - def post_fw_hook(module, input, output): + def post_fw_hook(module, input, output) -> None: self.call_depth -= 1 self.module_tracker.register_user_hooks(pre_fw_hook, post_fw_hook) diff --git a/torch/utils/_device.py b/torch/utils/_device.py index 2780218e03eef..e7e44719e0c57 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -59,7 +59,7 @@ def _device_constructors(): # NB: This is directly called from C++ in torch/csrc/Device.cpp class DeviceContext(TorchFunctionMode): - def __init__(self, device): + def __init__(self, device) -> None: # pyrefly: ignore [read-only] self.device = torch.device(device) diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index eea7310222394..fdb9a914bf64e 100644 --- a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -24,7 +24,7 @@ class OrderedSet(MutableSet[T], Reversible[T]): __slots__ = ("_dict",) - def __init__(self, iterable: Optional[Iterable[T]] = None): + def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: self._dict = dict.fromkeys(iterable, None) if iterable is not None else {} @staticmethod diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 52be3280c9c39..b07b7a4ec6fe5 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -86,7 +86,7 @@ class TorchDispatchMode: # Mode authors can implement how the mode interacts with higher order operators. supports_higher_order_operators = False - def __init__(self, _dispatch_key=None): + def __init__(self, _dispatch_key=None) -> None: if _dispatch_key is not None: if not isinstance(_dispatch_key, torch._C.DispatchKey): raise AssertionError("_dispatch_key must be a torch._C.DispatchKey") @@ -98,7 +98,7 @@ def __init__(self, _dispatch_key=None): deque() ) - def _lazy_init_old_dispatch_mode_flags(self): + def _lazy_init_old_dispatch_mode_flags(self) -> None: if not hasattr(self, "old_dispatch_mode_flags"): self.old_dispatch_mode_flags: deque[bool] = deque() # type: ignore[no-redef] @@ -171,11 +171,11 @@ def push(cls, *args, **kwargs): return instance @classmethod - def is_infra_mode(cls): + def is_infra_mode(cls) -> bool: return False @classmethod - def ignore_compile_internals(cls): + def ignore_compile_internals(cls) -> bool: """Ignore operators that are compiled via torch.compile. 
If ``True``, then this TorchDispatchMode ignores operators that @@ -287,7 +287,7 @@ def _get_current_dispatch_mode_stack() -> list[TorchDispatchMode]: return [_get_dispatch_stack_at(i) for i in range(stack_len)] -def _push_mode(mode: TorchDispatchMode): +def _push_mode(mode: TorchDispatchMode) -> None: k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None if k is not None and k != torch._C.DispatchKey.PreDispatch: raise AssertionError( @@ -544,7 +544,7 @@ def transform_subclass(t, callback, outer_size=None, outer_stride=None): return sub -def _correct_storage_aliasing(func, schema_info, args, outs): +def _correct_storage_aliasing(func, schema_info, args, outs) -> None: """ Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema), and the inputs/outputs to the OpOverload, @@ -563,7 +563,7 @@ def _correct_storage_aliasing(func, schema_info, args, outs): if not isinstance(outs, (list, tuple)): raise AssertionError(f"outs must be a list or tuple, got {type(args)}") - def alias_non_inplace_storage(arg, ret): + def alias_non_inplace_storage(arg, ret) -> None: # This is hopefully a reasonable assert: # subclasses that rely on this API for output aliasing # should always return wrapper tensor subclasses for us to manually alias. diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 024cd93b35788..47cf07552b2cf 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -81,7 +81,7 @@ def __init__( sample_tags: Optional[list[str]] = None, stack_max_len: int = 127, async_stack_max_len: int = 127, - ): + ) -> None: self.stop_at_error = stop_at_error self.max_profile_duration_sec = max_profile_duration_sec self.sample_each = sample_each diff --git a/torch/utils/_strobelight/examples/cli_function_profiler_example.py b/torch/utils/_strobelight/examples/cli_function_profiler_example.py index 322cd321199a1..fb957da009279 100644 --- a/torch/utils/_strobelight/examples/cli_function_profiler_example.py +++ b/torch/utils/_strobelight/examples/cli_function_profiler_example.py @@ -14,7 +14,7 @@ def fn(x, y, z): # use decorator with default profiler or optional profile arguments. 
@strobelight(sample_each=10000, stop_at_error=False) @torch.compile() - def work(): + def work() -> None: for _ in range(10): torch._dynamo.reset() for j in range(5): @@ -27,7 +27,7 @@ def work(): profiler = StrobelightCLIFunctionProfiler(stop_at_error=False) @strobelight(profiler, sample_tags=["something", "another"]) - def work2(): + def work2() -> None: sum = 0 for _ in range(100000000): sum += 1 # noqa: SIM113 diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 297d7f4eec9a8..425344bda17ef 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -471,7 +471,7 @@ def _eval_is_nonnegative(self) -> Optional[bool]: def _eval_is_nonpositive(self) -> Optional[bool]: return True if self.args[1].is_negative else None # type: ignore[attr-defined] - def _ccode(self, printer): + def _ccode(self, printer) -> str: # pyrefly: ignore [missing-attribute] p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) # pyrefly: ignore [missing-attribute] @@ -558,7 +558,7 @@ def eval(cls, number): if isinstance(number, sympy.Number): return sympy.Integer(math.ceil(float(number))) - def _ccode(self, printer): + def _ccode(self, printer) -> str: # pyrefly: ignore [missing-attribute] number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5) return f"ceil({number})" @@ -1164,7 +1164,7 @@ def eval(cls, base, divisor): if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): return sympy.Float(int(base) / int(divisor)) - def _ccode(self, printer): + def _ccode(self, printer) -> str: # pyrefly: ignore [missing-attribute] base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) # pyrefly: ignore [missing-attribute] @@ -1331,11 +1331,11 @@ class Identity(sympy.Function): precedence = 10 - def __repr__(self): # type: ignore[override] + def __repr__(self) -> str: # type: ignore[override] # pyrefly: ignore [missing-attribute] return f"Identity({self.args[0]})" - def _sympystr(self, printer): + def _sympystr(self, printer) -> str: """Controls how sympy's StrPrinter prints this""" # pyrefly: ignore [missing-attribute] return f"({printer.doprint(self.args[0])})" diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py index f675de25ad8a7..8b08e01d8e52b 100644 --- a/torch/utils/_sympy/numbers.py +++ b/torch/utils/_sympy/numbers.py @@ -42,7 +42,7 @@ class IntInfinity(Number, metaclass=Singleton): def __new__(cls): return AtomicExpr.__new__(cls) - def _sympystr(self, printer): + def _sympystr(self, printer) -> str: return "int_oo" def _eval_subs(self, old, new): @@ -237,7 +237,7 @@ def _eval_subs(self, old, new): if self == old: return new - def _sympystr(self, printer): + def _sympystr(self, printer) -> str: return "-int_oo" """ diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py index 0bac76121f8b6..57d5615e55271 100644 --- a/torch/utils/_sympy/singleton_int.py +++ b/torch/utils/_sympy/singleton_int.py @@ -17,7 +17,7 @@ def __new__(cls, *args, coeff=None, **kwargs): # The semantics of this class should match that of NestedIntSymNodeImpl in # c10/core/NestedIntSymNodeImpl.h - def __init__(self, val, *, coeff=1): + def __init__(self, val, *, coeff=1) -> None: self._val = val self._coeff = coeff super().__init__() diff --git a/torch/utils/_thunk.py b/torch/utils/_thunk.py index 28689f2f76f18..a332babfdf4ce 100644 --- a/torch/utils/_thunk.py +++ b/torch/utils/_thunk.py @@ -17,7 +17,7 @@ class Thunk(Generic[R]): __slots__ = ["f", "r"] - def __init__(self, f: 
Callable[[], R]): + def __init__(self, f: Callable[[], R]) -> None: self.f = f self.r = None diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index 21fadb297be80..39a302ea5ca25 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -144,7 +144,7 @@ def shorten_filename(fn, *, base=None): return fn[len(prefix) + 1 :] -def format_frame(frame, *, base=None, line=False): +def format_frame(frame, *, base=None, line=False) -> str: """ Format a FrameSummary in a short way, without printing full absolute path or code. @@ -164,11 +164,11 @@ def format_traceback_short(tb): class CapturedTraceback: __slots__ = ["tb", "skip"] - def __init__(self, tb, skip=0): + def __init__(self, tb, skip=0) -> None: self.tb = tb self.skip = skip - def cleanup(self): + def cleanup(self) -> None: self.tb = None def summary(self): diff --git a/torch/utils/_zip.py b/torch/utils/_zip.py index b159b61de06aa..5dd98e43c4a77 100644 --- a/torch/utils/_zip.py +++ b/torch/utils/_zip.py @@ -36,7 +36,7 @@ def remove_prefix(text, prefix): return text -def write_to_zip(file_path, strip_file_path, zf, prepend_str=""): +def write_to_zip(file_path, strip_file_path, zf, prepend_str="") -> None: stripped_file_path = prepend_str + remove_prefix(file_path, strip_file_dir + "/") path = Path(stripped_file_path) if path.name in DENY_LIST: diff --git a/torch/utils/backcompat/__init__.py b/torch/utils/backcompat/__init__.py index a8413b656e906..f6ec989be1e07 100644 --- a/torch/utils/backcompat/__init__.py +++ b/torch/utils/backcompat/__init__.py @@ -8,11 +8,11 @@ class Warning: - def __init__(self, setter, getter): + def __init__(self, setter, getter) -> None: self.setter = setter self.getter = getter - def set_enabled(self, value): + def set_enabled(self, value) -> None: self.setter(value) def get_enabled(self): diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index 67e9a7311a09d..b31eb49a60601 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -82,7 +82,7 @@ def rename_privateuse1_backend(backend_name: str) -> None: _privateuse1_backend_name = backend_name -def _check_register_once(module, attr): +def _check_register_once(module, attr) -> None: if hasattr(module, attr): raise RuntimeError( f"The custom device module of {module} has already been registered with {attr}" @@ -448,33 +448,33 @@ def func_name(*args, **kwargs): class _DummyBackendModule: - def is_initialized(self): + def is_initialized(self) -> bool: return True - def is_available(self): + def is_available(self) -> bool: return True - def current_device(self): + def current_device(self) -> int: return 0 - def _is_in_bad_fork(self): + def _is_in_bad_fork(self) -> bool: return False - def manual_seed_all(self, seed: int): + def manual_seed_all(self, seed: int) -> None: pass - def device_count(self): + def device_count(self) -> int: return 1 class _DummyPrivateUse1Hook(torch._C._acc.PrivateUse1Hooks): - def is_available(self): + def is_available(self) -> bool: return True - def has_primary_context(self, dev_id): + def has_primary_context(self, dev_id) -> bool: return True - def is_built(self): + def is_built(self) -> bool: return True @@ -485,7 +485,7 @@ def type_(self): def _setup_privateuseone_for_python_backend( rename=None, backend_module=None, hook=None, device_guard=None -): +) -> None: """This function will prepare the PrivateUse1 dispatch key to be used as a python backend. WARNING: this API is experimental and might change without notice. 
diff --git a/torch/utils/benchmark/examples/compare.py b/torch/utils/benchmark/examples/compare.py index 5d797a5b0a2bf..1c266e7cf9a6e 100644 --- a/torch/utils/benchmark/examples/compare.py +++ b/torch/utils/benchmark/examples/compare.py @@ -20,7 +20,7 @@ class FauxTorch: writing serialized measurements, but this simplifies that model to make the example clearer. """ - def __init__(self, real_torch, extra_ns_per_element): + def __init__(self, real_torch, extra_ns_per_element) -> None: self._real_torch = real_torch self._extra_ns_per_element = extra_ns_per_element @@ -45,7 +45,7 @@ def matmul(self, *args, **kwargs): return self.extra_overhead(self._real_torch.matmul(*args, **kwargs)) -def main(): +def main() -> None: tasks = [ ("add", "add", "torch.add(x, y)"), ("add", "add (extra +0)", "torch.add(x, y + zero)"), diff --git a/torch/utils/benchmark/examples/fuzzer.py b/torch/utils/benchmark/examples/fuzzer.py index ee2c9f9c04ed1..80a4e733928d8 100644 --- a/torch/utils/benchmark/examples/fuzzer.py +++ b/torch/utils/benchmark/examples/fuzzer.py @@ -9,7 +9,7 @@ import torch.utils.benchmark as benchmark_utils -def main(): +def main() -> None: add_fuzzer = benchmark_utils.Fuzzer( parameters=[ [ diff --git a/torch/utils/benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py index 8a76331d3404f..f65599ee18a4f 100644 --- a/torch/utils/benchmark/examples/op_benchmark.py +++ b/torch/utils/benchmark/examples/op_benchmark.py @@ -16,7 +16,7 @@ _MEASURE_TIME = 1.0 -def assert_dicts_equal(dict_0, dict_1): +def assert_dicts_equal(dict_0, dict_1) -> None: """Builtin dict comparison will not compare numpy arrays. e.g. x = {"a": np.ones((2, 1))} @@ -28,7 +28,7 @@ def assert_dicts_equal(dict_0, dict_1): raise AssertionError("dict values differ for keys other than 'dtype'") -def run(n, stmt, fuzzer_cls): +def run(n, stmt, fuzzer_cls) -> None: float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n) int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n) raw_results = [] @@ -97,7 +97,7 @@ def run(n, stmt, fuzzer_cls): print(spacer) -def main(): +def main() -> None: run(n=100, stmt="torch.median(x, dim=0)", fuzzer_cls=UnaryOpFuzzer) run(n=100, stmt="torch.square(x)", fuzzer_cls=UnaryOpFuzzer) run(n=100, stmt="x + y", fuzzer_cls=BinaryOpFuzzer) diff --git a/torch/utils/benchmark/examples/sparse/compare.py b/torch/utils/benchmark/examples/sparse/compare.py index fa00fb1818cda..e61b0cc063469 100644 --- a/torch/utils/benchmark/examples/sparse/compare.py +++ b/torch/utils/benchmark/examples/sparse/compare.py @@ -19,7 +19,7 @@ class FauxTorch: writing serialized measurements, but this simplifies that model to make the example clearer. 
""" - def __init__(self, real_torch, extra_ns_per_element): + def __init__(self, real_torch, extra_ns_per_element) -> None: self._real_torch = real_torch self._extra_ns_per_element = extra_ns_per_element @@ -28,7 +28,7 @@ def sparse(self): return self.Sparse(self._real_torch, self._extra_ns_per_element) class Sparse: - def __init__(self, real_torch, extra_ns_per_element): + def __init__(self, real_torch, extra_ns_per_element) -> None: self._real_torch = real_torch self._extra_ns_per_element = extra_ns_per_element @@ -73,7 +73,7 @@ def gen_sparse(size, density, dtype, device='cpu'): indices, values = generate_coo_data(size, sparse_dim, nnz, dtype, device) return torch.sparse_coo_tensor(indices, values, size, dtype=dtype, device=device) -def main(): +def main() -> None: tasks = [ ("matmul", "x @ y", "torch.sparse.mm(x, y)"), ("matmul", "x @ y + 0", "torch.sparse.mm(x, y) + zero"), diff --git a/torch/utils/benchmark/examples/sparse/fuzzer.py b/torch/utils/benchmark/examples/sparse/fuzzer.py index a5aac22179d86..c2a5bc2f112bb 100644 --- a/torch/utils/benchmark/examples/sparse/fuzzer.py +++ b/torch/utils/benchmark/examples/sparse/fuzzer.py @@ -8,7 +8,7 @@ import torch.utils.benchmark as benchmark_utils -def main(): +def main() -> None: add_fuzzer = benchmark_utils.Fuzzer( parameters=[ [ diff --git a/torch/utils/benchmark/examples/sparse/op_benchmark.py b/torch/utils/benchmark/examples/sparse/op_benchmark.py index b574b0223d489..20a7429d551b7 100644 --- a/torch/utils/benchmark/examples/sparse/op_benchmark.py +++ b/torch/utils/benchmark/examples/sparse/op_benchmark.py @@ -14,7 +14,7 @@ _MEASURE_TIME = 1.0 -def assert_dicts_equal(dict_0, dict_1): +def assert_dicts_equal(dict_0, dict_1) -> None: """Builtin dict comparison will not compare numpy arrays. e.g. 
x = {"a": np.ones((2, 1))} @@ -25,7 +25,7 @@ def assert_dicts_equal(dict_0, dict_1): if all(np.all(v != dict_1[k]) for k, v in dict_0.items() if k != "dtype"): raise AssertionError("dict values differ for keys other than 'dtype'") -def run(n, stmt, fuzzer_cls): +def run(n, stmt, fuzzer_cls) -> None: float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n) double_iter = fuzzer_cls(seed=0, dtype=torch.float64).take(n) raw_results = [] @@ -92,7 +92,7 @@ def run(n, stmt, fuzzer_cls): print(spacer) -def main(): +def main() -> None: run(n=100, stmt="torch.sparse.sum(x, dim=0)", fuzzer_cls=UnaryOpSparseFuzzer) run(n=100, stmt="torch.sparse.softmax(x, dim=0)", fuzzer_cls=UnaryOpSparseFuzzer) run(n=100, stmt="x + y", fuzzer_cls=BinaryOpSparseFuzzer) diff --git a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py index a3c8cbe5b12c2..81a33c34bc822 100644 --- a/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py +++ b/torch/utils/benchmark/examples/spectral_ops_fuzz_test.py @@ -63,7 +63,7 @@ def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, de BENCHMARK_NAMES = [b.name for b in BENCHMARKS] DEVICE_NAMES = ['cpu', 'cuda'] -def _output_csv(file, results): +def _output_csv(file, results) -> None: file.write('benchmark,device,num_threads,numel,shape,contiguous,dim,mean (us),median (us),iqr (us)\n') for measurement in results: metadata = measurement.metadata diff --git a/torch/utils/benchmark/op_fuzzers/binary.py b/torch/utils/benchmark/op_fuzzers/binary.py index 75f394179b3e0..e53c310111bec 100644 --- a/torch/utils/benchmark/op_fuzzers/binary.py +++ b/torch/utils/benchmark/op_fuzzers/binary.py @@ -14,7 +14,7 @@ class BinaryOpFuzzer(Fuzzer): - def __init__(self, seed, dtype=torch.float32, cuda=False): + def __init__(self, seed, dtype=torch.float32, cuda=False) -> None: super().__init__( parameters=[ # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.) diff --git a/torch/utils/benchmark/op_fuzzers/sparse_binary.py b/torch/utils/benchmark/op_fuzzers/sparse_binary.py index 014361877dea1..8e6269464e0d5 100644 --- a/torch/utils/benchmark/op_fuzzers/sparse_binary.py +++ b/torch/utils/benchmark/op_fuzzers/sparse_binary.py @@ -14,7 +14,7 @@ class BinaryOpSparseFuzzer(Fuzzer): - def __init__(self, seed, dtype=torch.float32, cuda=False): + def __init__(self, seed, dtype=torch.float32, cuda=False) -> None: super().__init__( parameters=[ # Dimensionality of x and y. (e.g. 1D, 2D, or 3D.) diff --git a/torch/utils/benchmark/op_fuzzers/spectral.py b/torch/utils/benchmark/op_fuzzers/spectral.py index 2b9e92d7a2c7b..c324e338dca5d 100644 --- a/torch/utils/benchmark/op_fuzzers/spectral.py +++ b/torch/utils/benchmark/op_fuzzers/spectral.py @@ -29,7 +29,7 @@ def power_range(upper_bound, base): class SpectralOpFuzzer(benchmark.Fuzzer): def __init__(self, *, seed: int, dtype=torch.float64, - cuda: bool = False, probability_regular: float = 1.0): + cuda: bool = False, probability_regular: float = 1.0) -> None: super().__init__( parameters=[ # Dimensionality of x. (e.g. 1D, 2D, or 3D.) 
diff --git a/torch/utils/benchmark/op_fuzzers/unary.py b/torch/utils/benchmark/op_fuzzers/unary.py index e780b421f24c8..6008adfe45921 100644 --- a/torch/utils/benchmark/op_fuzzers/unary.py +++ b/torch/utils/benchmark/op_fuzzers/unary.py @@ -14,7 +14,7 @@ class UnaryOpFuzzer(Fuzzer): - def __init__(self, seed, dtype=torch.float32, cuda=False): + def __init__(self, seed, dtype=torch.float32, cuda=False) -> None: super().__init__( parameters=[ # Dimensionality of x. (e.g. 1D, 2D, or 3D.) diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 21a83926a2e82..e9a0966c6e966 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -34,7 +34,7 @@ def __init__( time_unit: str, trim_significant_figures: bool, highlight_warnings: bool, - ): + ) -> None: self._grouped_results = grouped_results self._flat_results = [*it.chain.from_iterable(grouped_results)] self._time_scale = time_scale @@ -79,7 +79,7 @@ def optional_min(seq): class _Row: def __init__(self, results, row_group, render_env, env_str_len, - row_name_str_len, time_scale, colorize, num_threads=None): + row_name_str_len, time_scale, colorize, num_threads=None) -> None: super().__init__() self._results = results self._row_group = row_group @@ -91,7 +91,7 @@ def __init__(self, results, row_group, render_env, env_str_len, self._columns: tuple[_Column, ...] = () self._num_threads = num_threads - def register_columns(self, columns: tuple[_Column, ...]): + def register_columns(self, columns: tuple[_Column, ...]) -> None: self._columns = columns def as_column_strings(self): @@ -156,7 +156,7 @@ def __init__( colorize: Colorize, trim_significant_figures: bool, highlight_warnings: bool - ): + ) -> None: if len({r.label for r in results}) != 1: raise AssertionError("All results must share the same label") @@ -283,17 +283,17 @@ class Compare: Args: results: List of Measurement to display. """ - def __init__(self, results: list[common.Measurement]): + def __init__(self, results: list[common.Measurement]) -> None: self._results: list[common.Measurement] = [] self.extend_results(results) self._trim_significant_figures = False self._colorize = Colorize.NONE self._highlight_warnings = False - def __str__(self): + def __str__(self) -> str: return "\n".join(self._render()) - def extend_results(self, results): + def extend_results(self, results) -> None: """Append results to already stored ones. All added results must be instances of ``Measurement``. @@ -305,22 +305,22 @@ def extend_results(self, results): ) self._results.extend(results) - def trim_significant_figures(self): + def trim_significant_figures(self) -> None: """Enables trimming of significant figures when building the formatted table.""" self._trim_significant_figures = True - def colorize(self, rowwise=False): + def colorize(self, rowwise=False) -> None: """Colorize formatted table. Colorize columnwise by default. 
""" self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE - def highlight_warnings(self): + def highlight_warnings(self) -> None: """Enables warning highlighting when building formatted table.""" self._highlight_warnings = True - def print(self): + def print(self) -> None: """Print formatted table""" print(str(self)) diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index 777120c811057..d8881354ddaf2 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -25,7 +25,7 @@ print("tabulate is not installed, please pip install tabulate to use this utility") if HAS_TABULATE: - def _enable_tensor_cores(): + def _enable_tensor_cores() -> None: global _warned_tensor_cores if torch.cuda.is_available(): @@ -36,7 +36,7 @@ def _enable_tensor_cores(): print("we will enable it automatically by setting `torch.set_float32_matmul_precision('high')`") _warned_tensor_cores = True - def _disable_tensor_cores(): + def _disable_tensor_cores() -> None: torch.set_float32_matmul_precision(_default_float_32_precision) def bench_loop( diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index f343722ef686d..06f37bd8f3a35 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -29,7 +29,7 @@ def __init__( maxval: Optional[Union[int, float]] = None, distribution: Optional[Union[str, dict[Any, float]]] = None, strict: bool = False, - ): + ) -> None: """ Args: name: @@ -159,10 +159,10 @@ class ParameterAlias: Chains of alias' are allowed, but may not contain cycles. """ - def __init__(self, alias_to): + def __init__(self, alias_to) -> None: self.alias_to = alias_to - def __repr__(self): + def __repr__(self) -> str: return f"ParameterAlias[alias_to: {self.alias_to}]" @@ -199,7 +199,7 @@ def __init__( dtype=torch.float32, cuda=False, tensor_constructor: Optional[Callable] = None - ): + ) -> None: """ Args: name: @@ -329,7 +329,7 @@ def resolve(values, dim): allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps, strict=True)) return size, steps, allocation_size - def satisfies_constraints(self, params): + def satisfies_constraints(self, params) -> bool: size, _, allocation_size = self._get_size_and_steps(params) # Product is computed in Python to avoid integer overflow. num_elements = prod(size) @@ -357,7 +357,7 @@ def __init__( tensors: list[Union[FuzzedTensor, list[FuzzedTensor]]], constraints: Optional[list[Callable]] = None, seed: Optional[int] = None - ): + ) -> None: """ Args: parameters: diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index cd84900c5b438..49afb5ea9ad06 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -19,7 +19,7 @@ def __init__( coalesced: Optional[str] = None, dtype=torch.float32, cuda=False - ): + ) -> None: """ Args: name: diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index 9080f82721600..ef9c1936b3570 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -311,7 +311,7 @@ class CopyIfCallgrind: See `GlobalsBridge` for why this matters. 
""" - def __init__(self, value: Any, *, setup: Optional[str] = None): + def __init__(self, value: Any, *, setup: Optional[str] = None) -> None: for method, supported_types in _GLOBALS_ALLOWED_TYPES.items(): if any(isinstance(value, t) for t in supported_types): self._value: Any = value diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index a643314f3b9cd..1f5f3f0f60575 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -899,7 +899,7 @@ def get_pretty_env_info(): return pretty_str(get_env_info()) -def main(): +def main() -> None: print("Collecting environment information...") output = get_pretty_env_info() print(output) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 235b7e104c702..fc16c38b8e3e4 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -310,7 +310,7 @@ def _get_sycl_arch_list(): # If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled # for default spir64 target and avoid device specific compilations entirely. Further, kernels # will be JIT compiled at runtime. -def _append_sycl_targets_if_missing(cflags): +def _append_sycl_targets_if_missing(cflags) -> None: if any(flag.startswith('-fsycl-targets=') for flag in cflags): # do nothing: user has manually specified sycl targets return @@ -367,7 +367,7 @@ def _accepted_compilers_for_platform() -> list[str]: # gnu-c++ and gnu-cc are the conda gcc compilers return ['clang++', 'clang'] if IS_MACOS else ['g++', 'gcc', 'gnu-c++', 'gnu-cc', 'clang++', 'clang'] -def _maybe_write(filename, new_content): +def _maybe_write(filename, new_content) -> None: r''' Equivalent to writing the content into the file but will not touch the file if it already had the right content (to avoid triggering recompile). @@ -559,7 +559,7 @@ def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> N # Specify Visual Studio C runtime library for hipcc -def _set_hipcc_runtime_lib(is_standalone, debug): +def _set_hipcc_runtime_lib(is_standalone, debug) -> None: if is_standalone: if debug: COMMON_HIP_FLAGS.append('-fms-runtime-lib=static_dbg') @@ -571,7 +571,7 @@ def _set_hipcc_runtime_lib(is_standalone, debug): else: COMMON_HIP_FLAGS.append('-fms-runtime-lib=dll') -def _append_sycl_std_if_no_std_present(cflags): +def _append_sycl_std_if_no_std_present(cflags) -> None: if not any(flag.startswith('-sycl-std=') for flag in cflags): cflags.append('-sycl-std=2020') @@ -616,7 +616,7 @@ class BuildExtension(build_ext): def with_options(cls, **options): """Return a subclass with alternative constructor that extends any original keyword arguments to the original constructor with the given options.""" class cls_with_options(cls): # type: ignore[misc, valid-type] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: kwargs.update(options) super().__init__(*args, **kwargs) @@ -742,7 +742,7 @@ def unix_cuda_flags(cflags): return cflags - def convert_to_absolute_paths_inplace(paths): + def convert_to_absolute_paths_inplace(paths) -> None: # Helper function. 
See Note [Absolute include_dirs] if paths is not None: for i in range(len(paths)): @@ -1123,7 +1123,7 @@ def _check_abi(self) -> tuple[str, TorchVersion]: raise UserWarning(msg) return compiler, version - def _add_compile_flag(self, extension, flag): + def _add_compile_flag(self, extension, flag) -> None: extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args) if isinstance(extension.extra_compile_args, dict): for args in extension.extra_compile_args.values(): @@ -1133,7 +1133,7 @@ def _add_compile_flag(self, extension, flag): # Simple hipify, replace the first occurrence of CUDA with HIP # in flags starting with "-" and containing "CUDA", but exclude -I flags - def _hipify_compile_flags(self, extension): + def _hipify_compile_flags(self, extension) -> None: if isinstance(extension.extra_compile_args, dict) and 'nvcc' in extension.extra_compile_args: modified_flags = [] for flag in extension.extra_compile_args['nvcc']: @@ -1154,7 +1154,7 @@ def _hipify_compile_flags(self, extension): modified_flags.append(flag) extension.extra_compile_args['nvcc'] = modified_flags - def _define_torch_extension_name(self, extension): + def _define_torch_extension_name(self, extension) -> None: # pybind11 doesn't support dots in the names # so in order to support extensions in the packages # like torch._C, we take the last part of the string @@ -1733,7 +1733,7 @@ def load(name, def _get_pybind11_abi_build_flags() -> list[str]: return [] -def check_compiler_is_gcc(compiler): +def check_compiler_is_gcc(compiler) -> bool: if not IS_LINUX: return False @@ -1760,7 +1760,7 @@ def check_compiler_is_gcc(compiler): def _check_and_build_extension_h_precompiler_headers( extra_cflags, extra_include_paths, - is_standalone=False): + is_standalone=False) -> None: r''' Precompiled Headers(PCH) can pre-build the same headers and reduce build time for pytorch load_inline modules. 
GCC official manual: https://gcc.gnu.org/onlinedocs/gcc-4.0.4/gcc/Precompiled-Headers.html @@ -1821,7 +1821,7 @@ def check_pch_signature_in_file(file_path, signature): # check if string present in a file return signature == content - def _create_if_not_exist(path_dir): + def _create_if_not_exist(path_dir) -> None: if not os.path.exists(path_dir): try: Path(path_dir).mkdir(parents=True, exist_ok=True) @@ -1829,13 +1829,13 @@ def _create_if_not_exist(path_dir): if exc.errno != errno.EEXIST: raise RuntimeError(f"Fail to create path {path_dir}") from exc - def write_pch_signature_to_file(file_path, pch_sign): + def write_pch_signature_to_file(file_path, pch_sign) -> None: _create_if_not_exist(os.path.dirname(file_path)) with open(file_path, "w") as f: f.write(pch_sign) f.close() - def build_precompile_header(pch_cmd): + def build_precompile_header(pch_cmd) -> None: try: subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: @@ -1876,8 +1876,8 @@ def build_precompile_header(pch_cmd): build_precompile_header(pch_cmd) write_pch_signature_to_file(head_file_signature, pch_sign) -def remove_extension_h_precompiler_headers(): - def _remove_if_file_exists(path_file): +def remove_extension_h_precompiler_headers() -> None: + def _remove_if_file_exists(path_file) -> None: if os.path.exists(path_file): os.remove(path_file) @@ -2313,7 +2313,7 @@ def _write_ninja_file_and_build_library( error_prefix=f"Error building extension '{name}'") -def is_ninja_available(): +def is_ninja_available() -> bool: """Return ``True`` if the `ninja `_ build system is available on the system, ``False`` otherwise.""" try: subprocess.check_output(['ninja', '--version']) @@ -2323,7 +2323,7 @@ def is_ninja_available(): return True -def verify_ninja_availability(): +def verify_ninja_availability() -> None: """Raise ``RuntimeError`` if `ninja `_ build system is not available on the system, does nothing otherwise.""" if not is_ninja_available(): raise RuntimeError("Ninja is required to load C++ extensions (pip install ninja to get it)") diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index 223962fc04ba9..cd9722b04e588 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -15,7 +15,7 @@ from . import MP_STATUS_CHECK_INTERVAL -def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): +def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device) -> None: # This setting is thread local, and prevents the copy in pin_memory from # consuming all CPU cores. torch.set_num_threads(1) @@ -23,7 +23,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device): torch.multiprocessing._set_thread_name("pt_data_pin") torch.accelerator.set_device_index(device_id) - def do_one_step(): + def do_one_step() -> None: try: r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) except queue.Empty: diff --git a/torch/utils/data/_utils/signal_handling.py b/torch/utils/data/_utils/signal_handling.py index 33e1dd021e975..abff09bc40819 100644 --- a/torch/utils/data/_utils/signal_handling.py +++ b/torch/utils/data/_utils/signal_handling.py @@ -51,7 +51,7 @@ handler needs to be set for all DataLoaders in a process.""" -def _set_SIGCHLD_handler(): +def _set_SIGCHLD_handler() -> None: # Windows doesn't support SIGCHLD handler if IS_WINDOWS: return @@ -67,7 +67,7 @@ def _set_SIGCHLD_handler(): # no-op. 
previous_handler = None - def handler(signum, frame): + def handler(signum, frame) -> None: # This following call uses `waitid` with WNOHANG from C side. Therefore, # Python can still get and update the process status successfully. _error_if_any_worker_fails() diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index 5e61912dc6e77..c2d9294db86d9 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -49,7 +49,7 @@ def __init__(self) -> None: self.manager_dead = False - def is_alive(self): + def is_alive(self) -> bool: if not self.manager_dead: # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx self.manager_dead = ( @@ -64,7 +64,7 @@ def __init__(self) -> None: self.manager_pid = os.getppid() self.manager_dead = False - def is_alive(self): + def is_alive(self) -> bool: if not self.manager_dead: self.manager_dead = os.getppid() != self.manager_pid return not self.manager_dead @@ -80,20 +80,20 @@ class WorkerInfo: dataset: "Dataset" __initialized = False - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: for k, v in kwargs.items(): setattr(self, k, v) self.__keys = tuple(kwargs.keys()) self.__initialized = True - def __setattr__(self, key, val): + def __setattr__(self, key, val) -> None: if self.__initialized: raise RuntimeError( f"Cannot assign attributes to {self.__class__.__name__} objects" ) return super().__setattr__(key, val) - def __repr__(self): + def __repr__(self) -> str: items = [f"{k}={getattr(self, k)}" for k in self.__keys] return f"{self.__class__.__name__}({', '.join(items)})" @@ -240,7 +240,7 @@ def _worker_loop( num_workers, persistent_workers, shared_seed, -): +) -> None: # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the # logic of this function. 
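For reference, the WorkerInfo object annotated above is what dataloader workers expose through the public get_worker_info() accessor; a small usage sketch (the dataset size and worker count are arbitrary):

from torch.utils.data import DataLoader, Dataset, get_worker_info

class RangeDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> int:
        return idx

def worker_init_fn(worker_id: int) -> None:
    # Returns the WorkerInfo(id=..., num_workers=..., seed=..., dataset=...)
    # built for this worker process.
    info = get_worker_info()
    print(f"worker {info.id}/{info.num_workers}, seed={info.seed}")

if __name__ == "__main__":
    loader = DataLoader(RangeDataset(), num_workers=2, worker_init_fn=worker_init_fn)
    for _ in loader:
        pass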
diff --git a/torch/utils/data/backward_compatibility.py b/torch/utils/data/backward_compatibility.py index e8f1c4e30ef72..5b928aea69fa7 100644 --- a/torch/utils/data/backward_compatibility.py +++ b/torch/utils/data/backward_compatibility.py @@ -7,5 +7,5 @@ "as `DataLoader` automatically applies sharding in every worker", category=FutureWarning, ) -def worker_init_fn(worker_id): +def worker_init_fn(worker_id) -> None: pass diff --git a/torch/utils/data/dataframes_pipes.ipynb b/torch/utils/data/dataframes_pipes.ipynb index bc4abeba15b33..4d65abe0e7ef9 100644 --- a/torch/utils/data/dataframes_pipes.ipynb +++ b/torch/utils/data/dataframes_pipes.ipynb @@ -27,7 +27,7 @@ "source": [ "# Example IterDataPipe\n", "class ExampleIterPipe(IterDataPipe):\n", - " def __init__(self, range = 20):\n", + " def __init__(self, range = 20) -> None:\n", " self.range = range\n", " def __iter__(self):\n", " yield from self.range\n", diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index ae42f75885c1d..2683616804749 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -52,7 +52,7 @@ def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False): return output_string -def _gen_invalid_iterdatapipe_msg(datapipe): +def _gen_invalid_iterdatapipe_msg(datapipe) -> str: return ( "This iterator has been invalidated because another iterator has been created " f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n" @@ -119,7 +119,7 @@ def _set_datapipe_valid_iterator_id(datapipe): return datapipe._valid_iterator_id -def hook_iterator(namespace): +def hook_iterator(namespace) -> None: r""" Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`. @@ -141,7 +141,7 @@ class IteratorDecorator: Those `__iter__` method commonly returns `self` but not necessarily. """ - def __init__(self, iterator, datapipe, iterator_id, has_next_method): + def __init__(self, iterator, datapipe, iterator_id, has_next_method) -> None: self.iterator = iterator self.datapipe = datapipe self.iterator_id = iterator_id diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 5392d71bce804..e198aa16caa66 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -235,10 +235,10 @@ def issubinstance(data, data_type): class _DataPipeType: r"""Save type annotation in `param`.""" - def __init__(self, param): + def __init__(self, param) -> None: self.param = param - def __repr__(self): + def __repr__(self) -> str: return _type_repr(self.param) def __eq__(self, other): @@ -300,7 +300,7 @@ def __new__(cls, name, bases, namespace, **kwargs): ) return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] - def __init__(self, name, bases, namespace, **kwargs): + def __init__(self, name, bases, namespace, **kwargs) -> None: super().__init__(name, bases, namespace, **kwargs) # type: ignore[call-overload] # TODO: Fix isinstance bug @@ -388,7 +388,7 @@ def __new__(cls, name, bases, namespace, **kwargs): reset_func = namespace["reset"] @functools.wraps(reset_func) - def conditional_reset(*args, **kwargs): + def conditional_reset(*args, **kwargs) -> None: r""" Only execute DataPipe's `reset()` method if `_SnapshotState` is `Iterating` or `NotStarted`. 
@@ -413,7 +413,7 @@ def conditional_reset(*args, **kwargs): return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] -def _dp_init_subclass(sub_cls, *args, **kwargs): +def _dp_init_subclass(sub_cls, *args, **kwargs) -> None: # Add function for datapipe instance to reinforce the type sub_cls.reinforce_type = reinforce_type diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py index 4bbd2505b4b5f..410683bcfbd70 100644 --- a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -83,7 +83,7 @@ def get_df_wrapper(): return default_wrapper -def set_df_wrapper(wrapper): +def set_df_wrapper(wrapper) -> None: global default_wrapper default_wrapper = wrapper diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index 0526b472ad194..50c5a44dfd5f3 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -19,7 +19,7 @@ @functional_datapipe("_dataframes_as_tuples") class DataFramesAsTuplesPipe(IterDataPipe): - def __init__(self, source_datapipe): + def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe def __iter__(self): @@ -30,7 +30,7 @@ def __iter__(self): @functional_datapipe("_dataframes_per_row", enable_df_api_tracing=True) class PerRowDataFramesPipe(DFIterDataPipe): - def __init__(self, source_datapipe): + def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe def __iter__(self): @@ -42,7 +42,7 @@ def __iter__(self): @functional_datapipe("_dataframes_concat", enable_df_api_tracing=True) class ConcatDataFramesPipe(DFIterDataPipe): - def __init__(self, source_datapipe, batch=3): + def __init__(self, source_datapipe, batch=3) -> None: self.source_datapipe = source_datapipe self.n_batch = batch @@ -59,7 +59,7 @@ def __iter__(self): @functional_datapipe("_dataframes_shuffle", enable_df_api_tracing=True) class ShuffleDataFramesPipe(DFIterDataPipe): - def __init__(self, source_datapipe): + def __init__(self, source_datapipe) -> None: self.source_datapipe = source_datapipe def __iter__(self): @@ -84,7 +84,7 @@ def __iter__(self): @functional_datapipe("_dataframes_filter", enable_df_api_tracing=True) class FilterDataFramesPipe(DFIterDataPipe): - def __init__(self, source_datapipe, filter_fn): + def __init__(self, source_datapipe, filter_fn) -> None: self.source_datapipe = source_datapipe self.filter_fn = filter_fn @@ -113,7 +113,7 @@ def __iter__(self): @functional_datapipe("_to_dataframes_pipe", enable_df_api_tracing=True) class ExampleAggregateAsDataFrames(DFIterDataPipe): - def __init__(self, source_datapipe, dataframe_size=10, columns=None): + def __init__(self, source_datapipe, dataframe_size=10, columns=None) -> None: self.source_datapipe = source_datapipe self.columns = columns self.dataframe_size = dataframe_size diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index f0811ac81b616..4b3913bc82369 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -153,13 +153,13 @@ def __getattr__(self, attribute_name): ) @classmethod - def register_function(cls, function_name, function): + def register_function(cls, function_name, function) -> None: cls.functions[function_name] = function @classmethod def register_datapipe_as_function( cls, function_name, 
cls_to_register, enable_df_api_tracing=False - ): + ) -> None: if function_name in cls.functions: raise Exception( # noqa: TRY002 f"Unable to add DataPipe function name {function_name} as it is already taken" @@ -203,24 +203,24 @@ def __reduce_ex__(self, *args, **kwargs): return super().__reduce_ex__(*args, **kwargs) @classmethod - def set_getstate_hook(cls, hook_fn): + def set_getstate_hook(cls, hook_fn) -> None: if IterDataPipe.getstate_hook is not None and hook_fn is not None: raise RuntimeError("Attempt to override existing getstate_hook") IterDataPipe.getstate_hook = hook_fn @classmethod - def set_reduce_ex_hook(cls, hook_fn): + def set_reduce_ex_hook(cls, hook_fn) -> None: if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None: raise RuntimeError("Attempt to override existing reduce_ex_hook") IterDataPipe.reduce_ex_hook = hook_fn - def __repr__(self): + def __repr__(self) -> str: if self.repr_hook is not None: return self.repr_hook(self) # Instead of showing , return the class name return str(self.__class__.__qualname__) - def __str__(self): + def __str__(self) -> str: if self.str_hook is not None: return self.str_hook(self) # Instead of showing , return the class name @@ -242,7 +242,7 @@ def reset(self) -> None: class DFIterDataPipe(IterDataPipe): - def _is_dfpipe(self): + def _is_dfpipe(self) -> bool: return True @@ -301,11 +301,11 @@ def __getattr__(self, attribute_name): ) @classmethod - def register_function(cls, function_name, function): + def register_function(cls, function_name, function) -> None: cls.functions[function_name] = function @classmethod - def register_datapipe_as_function(cls, function_name, cls_to_register): + def register_datapipe_as_function(cls, function_name, cls_to_register) -> None: if function_name in cls.functions: raise Exception( # noqa: TRY002 f"Unable to add DataPipe function name {function_name} as it is already taken" @@ -342,24 +342,24 @@ def __reduce_ex__(self, *args, **kwargs): return super().__reduce_ex__(*args, **kwargs) @classmethod - def set_getstate_hook(cls, hook_fn): + def set_getstate_hook(cls, hook_fn) -> None: if MapDataPipe.getstate_hook is not None and hook_fn is not None: raise RuntimeError("Attempt to override existing getstate_hook") MapDataPipe.getstate_hook = hook_fn @classmethod - def set_reduce_ex_hook(cls, hook_fn): + def set_reduce_ex_hook(cls, hook_fn) -> None: if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None: raise RuntimeError("Attempt to override existing reduce_ex_hook") MapDataPipe.reduce_ex_hook = hook_fn - def __repr__(self): + def __repr__(self) -> str: if self.repr_hook is not None: return self.repr_hook(self) # Instead of showing , return the class name return str(self.__class__.__qualname__) - def __str__(self): + def __str__(self) -> str: if self.str_hook is not None: return self.str_hook(self) # Instead of showing , return the class name @@ -371,7 +371,7 @@ def __dir__(self): class _DataPipeSerializationWrapper: - def __init__(self, datapipe): + def __init__(self, datapipe) -> None: self._datapipe = datapipe def __getstate__(self): @@ -395,7 +395,7 @@ def __setstate__(self, state): else: self._datapipe = pickle.loads(value) - def __len__(self): + def __len__(self) -> int: try: return len(self._datapipe) except Exception as e: @@ -405,7 +405,7 @@ def __len__(self): class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): - def __init__(self, datapipe: IterDataPipe[_T_co]): + def __init__(self, datapipe: IterDataPipe[_T_co]) -> None: 
super().__init__(datapipe) # pyrefly: ignore [invalid-type-var] self._datapipe_iter: Optional[Iterator[_T_co]] = None diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py index 9f16f6f4552d4..23fd20f602567 100644 --- a/torch/utils/data/datapipes/gen_pyi.py +++ b/torch/utils/data/datapipes/gen_pyi.py @@ -52,7 +52,7 @@ def gen_from_template( template_name: str, output_name: str, replacements: list[tuple[str, Any, int]], -): +) -> None: template_path = os.path.join(dir, template_name) output_path = os.path.join(dir, output_name) diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index ff76e995f0ad2..6b4f134ef917d 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -189,5 +189,5 @@ def __setstate__(self, state): self._rng = random.Random() self._rng.setstate(rng_state) - def __del__(self): + def __del__(self) -> None: self._buffer.clear() diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 6efaa8c3d8be9..4682a483170f5 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -46,7 +46,7 @@ class ConcaterIterDataPipe(IterDataPipe): datapipes: tuple[IterDataPipe] - def __init__(self, *datapipes: IterDataPipe): + def __init__(self, *datapipes: IterDataPipe) -> None: if len(datapipes) == 0: raise ValueError("Expected at least one DataPipe, but got nothing") if not all(isinstance(dp, IterDataPipe) for dp in datapipes): @@ -148,7 +148,7 @@ def __init__( num_instances: int, buffer_size: int = 1000, copy: Optional[Literal["shallow", "deep"]] = None, - ): + ) -> None: self.main_datapipe = datapipe self._datapipe_iterator: Optional[Iterator[Any]] = None self.num_instances = num_instances @@ -180,7 +180,7 @@ def __init__( self.end_ptr: Optional[int] = None # The index to stop child self._child_stop: list[bool] = [True for _ in range(num_instances)] - def __len__(self): + def __len__(self) -> int: # pyrefly: ignore [bad-argument-type] return len(self.main_datapipe) @@ -283,12 +283,12 @@ def __setstate__(self, state): self.end_ptr = None self._child_stop = [True for _ in range(self.num_instances)] - def _cleanup(self): + def _cleanup(self) -> None: while self.buffer: d = self.buffer.popleft() StreamWrapper.close_streams(d) - def __del__(self): + def __del__(self) -> None: self._cleanup() @@ -324,7 +324,7 @@ class _ChildDataPipe(IterDataPipe): _is_child_datapipe: bool = True - def __init__(self, main_datapipe: IterDataPipe, instance_id: int): + def __init__(self, main_datapipe: IterDataPipe, instance_id: int) -> None: if not isinstance(main_datapipe, _ContainerTemplate): raise AssertionError("main_datapipe must implement _ContainerTemplate") @@ -337,7 +337,7 @@ def __iter__(self): # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called return self.main_datapipe.get_next_element_by_instance(self.instance_id) - def __len__(self): + def __len__(self) -> int: return self.main_datapipe.get_length_by_instance(self.instance_id) # This method is called by `hook_iterator` in `_typing.py`. 
@@ -455,7 +455,7 @@ def __init__( classifier_fn: Callable[[_T_co], Optional[int]], drop_none: bool, buffer_size: int, - ): + ) -> None: # pyrefly: ignore [invalid-type-var] self.main_datapipe = datapipe self._datapipe_iterator: Optional[Iterator[Any]] = None @@ -582,7 +582,7 @@ def __setstate__(self, state): self._child_stop = [True for _ in range(self.num_instances)] self.main_datapipe_exhausted = False - def _cleanup(self, instance_id: Optional[int] = None): + def _cleanup(self, instance_id: Optional[int] = None) -> None: ids = ( range(self.num_instances) if instance_id is None @@ -596,7 +596,7 @@ def _cleanup(self, instance_id: Optional[int] = None): d = q.popleft() StreamWrapper.close_streams(d) - def __del__(self): + def __del__(self) -> None: self._cleanup() @@ -623,7 +623,7 @@ class MultiplexerIterDataPipe(IterDataPipe): [0, 10, 20, 1, 11, 21, 2, 12, 22] """ - def __init__(self, *datapipes): + def __init__(self, *datapipes) -> None: self.datapipes = datapipes self.buffer: list = [] # Store values to be yielded only when every iterator provides one @@ -640,7 +640,7 @@ def __iter__(self): yield from self.buffer self.buffer.clear() - def __len__(self): + def __len__(self) -> int: if all(isinstance(dp, Sized) for dp in self.datapipes): return min(len(dp) for dp in self.datapipes) * len(self.datapipes) else: @@ -667,7 +667,7 @@ def __setstate__(self, state): ) = state self.buffer = [] - def __del__(self): + def __del__(self) -> None: self.buffer.clear() @@ -695,7 +695,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]): datapipes: tuple[IterDataPipe] - def __init__(self, *datapipes: IterDataPipe): + def __init__(self, *datapipes: IterDataPipe) -> None: if not all(isinstance(dp, IterDataPipe) for dp in datapipes): raise TypeError( "All inputs are required to be `IterDataPipe` for `ZipIterDataPipe`." 
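The combining pipes annotated above (Forker, Multiplexer, Zipper) compose through their registered functional forms; a quick illustrative sketch:

from torch.utils.data.datapipes.iter import IterableWrapper

source = IterableWrapper(range(5))
left, right = source.fork(num_instances=2)   # ForkerIterDataPipe -> two child pipes
pairs = left.zip(right)                      # ZipperIterDataPipe
mixed = IterableWrapper([0, 1, 2]).mux(IterableWrapper([10, 11, 12]))  # MultiplexerIterDataPipe

print(list(pairs))  # [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
print(list(mixed))  # [0, 10, 1, 11, 2, 12]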
diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 5b627a190e8a8..1d8efef4849bf 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -50,7 +50,7 @@ def __init__( mode: str = "r", encoding: Optional[str] = None, length: int = -1, - ): + ) -> None: super().__init__() self.datapipe: Iterable[str] = datapipe self.mode: str = mode diff --git a/torch/utils/data/datapipes/iter/streamreader.py b/torch/utils/data/datapipes/iter/streamreader.py index 4c3af4f12a81f..ece25b3467cdb 100644 --- a/torch/utils/data/datapipes/iter/streamreader.py +++ b/torch/utils/data/datapipes/iter/streamreader.py @@ -32,7 +32,7 @@ class StreamReaderIterDataPipe(IterDataPipe[tuple[str, bytes]]): def __init__( self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: Optional[int] = None - ): + ) -> None: self.datapipe = datapipe self.chunk = chunk diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index 21a412ff91609..c11d0bcd17d99 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -37,7 +37,7 @@ class ConcaterMapDataPipe(MapDataPipe): datapipes: tuple[MapDataPipe] - def __init__(self, *datapipes: MapDataPipe): + def __init__(self, *datapipes: MapDataPipe) -> None: if len(datapipes) == 0: raise ValueError("Expected at least one DataPipe, but got nothing") if not all(isinstance(dp, MapDataPipe) for dp in datapipes): diff --git a/torch/utils/data/datapipes/utils/decoder.py b/torch/utils/data/datapipes/utils/decoder.py index f4cc55838ae08..3b907ffebdd22 100644 --- a/torch/utils/data/datapipes/utils/decoder.py +++ b/torch/utils/data/datapipes/utils/decoder.py @@ -168,7 +168,7 @@ class ImageHandler: - pilrgba: pil None rgba """ - def __init__(self, imagespec): + def __init__(self, imagespec) -> None: if imagespec not in list(imagespecs.keys()): raise AssertionError(f"unknown image specification: {imagespec}") self.imagespec = imagespec.lower() @@ -335,13 +335,13 @@ class Decoder: handlers until some handler returns something other than None. 
""" - def __init__(self, *handler, key_fn=extension_extract_fn): + def __init__(self, *handler, key_fn=extension_extract_fn) -> None: self.handlers = list(handler) if handler else [] self.key_fn = key_fn # Insert new handler from the beginning of handlers list to make sure the new # handler having the highest priority - def add_handler(self, *handler): + def add_handler(self, *handler) -> None: if not handler: return self.handlers = list(handler) + self.handlers diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index b77ff892e6662..c800dd6a05826 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -205,7 +205,7 @@ def __init__(self, *tensors: Tensor) -> None: def __getitem__(self, index): return tuple(tensor[index] for tensor in self.tensors) - def __len__(self): + def __len__(self) -> int: return self.tensors[0].size(0) @@ -292,7 +292,7 @@ def __getitems__(self, indices: list): tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch] return tuple_batch - def __len__(self): + def __len__(self) -> int: return self._length @@ -327,7 +327,7 @@ def __init__(self, datasets: Iterable[Dataset]) -> None: raise AssertionError("ConcatDataset does not support IterableDataset") self.cumulative_sizes = self.cumsum(self.datasets) - def __len__(self): + def __len__(self) -> int: return self.cumulative_sizes[-1] def __getitem__(self, idx): @@ -374,7 +374,7 @@ def __iter__(self): raise AssertionError("ChainDataset only supports IterableDataset") yield from d - def __len__(self): + def __len__(self) -> int: total = 0 for d in self.datasets: if not isinstance(d, IterableDataset): @@ -412,7 +412,7 @@ def __getitems__(self, indices: list[int]) -> list[_T_co]: else: return [self.dataset[self.indices[idx]] for idx in indices] - def __len__(self): + def __len__(self) -> int: return len(self.indices) diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index 8867109c1e0b7..052db781d6a8d 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -15,7 +15,7 @@ DataPipeGraph = dict[int, tuple[DataPipe, "DataPipeGraph"]] -def _stub_unpickler(): +def _stub_unpickler() -> str: return "STUB" diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index bb97558256bec..9030150116800 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -58,7 +58,7 @@ def apply_sharding( """ graph = traverse_dps(datapipe) - def _helper(graph, prev_applied=None): + def _helper(graph, prev_applied=None) -> None: for dp, sub_graph in graph.values(): applied = None if _is_sharding_datapipe(dp): diff --git a/torch/utils/data/standard_pipes.ipynb b/torch/utils/data/standard_pipes.ipynb index c40058bca7699..e05b602c840bd 100644 --- a/torch/utils/data/standard_pipes.ipynb +++ b/torch/utils/data/standard_pipes.ipynb @@ -24,7 +24,7 @@ "source": [ "# Example IterDataPipe\n", "class ExampleIterPipe(IterDataPipe):\n", - " def __init__(self, range = 20):\n", + " def __init__(self, range = 20) -> None:\n", " self.range = range\n", " def __iter__(self):\n", " yield from self.range" diff --git a/torch/utils/data/typing.ipynb b/torch/utils/data/typing.ipynb index 1b1aa8c9da72f..0f546a2b3c3b5 100644 --- a/torch/utils/data/typing.ipynb +++ b/torch/utils/data/typing.ipynb @@ -33,7 +33,7 @@ "import functools\n", "ipython = get_ipython()\n", "def showtraceback(self, exc_tuple=None, filename=None, tb_offset=None,\n", - " exception_only=False, running_compiled_code=False):\n", + " exception_only=False, 
running_compiled_code=False) -> None:\n", " try:\n", " try:\n", " etype, value, tb = self._get_exc_info(exc_tuple)\n", @@ -227,7 +227,7 @@ "metadata": {}, "outputs": [], "source": [ - "def print_helper(cls, obj):\n", + "def print_helper(cls, obj) -> None:\n", " print(f\"DataPipe[{cls.type}]\\nInstance type: {obj.type}\")" ] }, @@ -313,7 +313,7 @@ "\n", "class DP(IterDataPipe):\n", " @argument_validation\n", - " def __init__(self, dp: IterDataPipe[Union[int, tuple]]):\n", + " def __init__(self, dp: IterDataPipe[Union[int, tuple]]) -> None:\n", " self.dp = dp\n", "\n", " def __iter__(self):\n", @@ -411,7 +411,7 @@ "from torch.utils.data import runtime_validation, runtime_validation_disabled\n", "\n", "class DP(IterDataPipe[tuple[int, T_co]]):\n", - " def __init__(self, datasource):\n", + " def __init__(self, datasource) -> None:\n", " self.ds = datasource\n", "\n", " @runtime_validation\n", @@ -606,7 +606,7 @@ ], "source": [ "class DP(IterDataPipe[T]):\n", - " def __init__(self, ds):\n", + " def __init__(self, ds) -> None:\n", " self.ds = ds\n", "\n", " def __iter__(self):\n", @@ -621,7 +621,7 @@ "outputs": [], "source": [ "class DP(IterDataPipe[T]):\n", - " def __init__(self, ds):\n", + " def __init__(self, ds) -> None:\n", " self.ds = ds\n", "\n", " @runtime_validation\n", @@ -744,7 +744,7 @@ "outputs": [], "source": [ "class DP(IterDataPipe[Union[int, str]]):\n", - " def __init__(self, label):\n", + " def __init__(self, label) -> None:\n", " if label == 'int':\n", " self.reinforce_type(int)\n", " elif label == 'str':\n", diff --git a/torch/utils/file_baton.py b/torch/utils/file_baton.py index 3d51d9efb339f..5b4f55d8c88dd 100644 --- a/torch/utils/file_baton.py +++ b/torch/utils/file_baton.py @@ -7,7 +7,7 @@ class FileBaton: """A primitive, file-based synchronization utility.""" - def __init__(self, lock_file_path, wait_seconds=0.1, warn_after_seconds=None): + def __init__(self, lock_file_path, wait_seconds=0.1, warn_after_seconds=None) -> None: """ Create a new :class:`FileBaton`. @@ -23,7 +23,7 @@ def __init__(self, lock_file_path, wait_seconds=0.1, warn_after_seconds=None): self.fd = None self.warn_after_seconds = warn_after_seconds - def try_acquire(self): + def try_acquire(self) -> bool | None: """ Try to atomically create a file under exclusive access. @@ -37,7 +37,7 @@ def try_acquire(self): except FileExistsError: return False - def wait(self): + def wait(self) -> None: """ Periodically sleeps for a certain amount until the baton is released. 
@@ -56,7 +56,7 @@ def wait(self): f'{self.warn_after_seconds} seconds.', stacklevel=2) has_warned = True - def release(self): + def release(self) -> None: """Release the baton and removes its file.""" if self.fd is not None: os.close(self.fd) diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 634e03439d4f4..41e5bc056e258 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -38,7 +38,7 @@ def register_fun(flop_formula: Callable[_P, _T]) -> Callable[_P, _T]: if not get_raw: flop_formula = shape_wrapper(flop_formula) - def register(target): + def register(target) -> None: if not isinstance(target, torch._ops.OpOverloadPacket): raise ValueError( f"register_flop_formula(targets): expected each target to be " @@ -624,7 +624,7 @@ def convert_num_with_suffix(number, suffix): # Return the value and the suffix as a string return value + suffixes[index] -def convert_to_percent_str(num, denom): +def convert_to_percent_str(num, denom) -> str: if denom == 0: return "0%" return f"{num / denom:.2%}" @@ -664,7 +664,7 @@ def __init__( mods: Optional[Union[torch.nn.Module, list[torch.nn.Module]]] = None, depth: int = 2, display: bool = True, - custom_mapping: Optional[dict[Any, Any]] = None): + custom_mapping: Optional[dict[Any, Any]] = None) -> None: super().__init__() self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int)) self.depth = depth @@ -787,7 +787,7 @@ def _count_flops(self, func_packet, out, args, kwargs): class _FlopCounterMode(TorchDispatchMode): supports_higher_order_operators = True - def __init__(self, counter: FlopCounterMode): + def __init__(self, counter: FlopCounterMode) -> None: self.counter = counter def _execute_with_isolated_flop_counting(self, branch_fn, operands): diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 93ce3c50dfcf2..29d02cb30d338 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -47,12 +47,12 @@ class CurrentState(Enum): DONE = 2 class HipifyResult: - def __init__(self, current_state, hipified_path): + def __init__(self, current_state, hipified_path) -> None: self.current_state = current_state self.hipified_path = hipified_path self.status = "" - def __str__(self): + def __str__(self) -> str: return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}") HipifyFinalResult = dict[str, HipifyResult] @@ -75,11 +75,11 @@ def __str__(self): class InputError(Exception): # Exception raised for errors in the input. - def __init__(self, message): + def __init__(self, message) -> None: super().__init__(message) self.message = message - def __str__(self): + def __str__(self) -> str: return f"Input error: {self.message}" @@ -109,7 +109,7 @@ class bcolors: # keep them (e.g. in the CI), this can be used to remove files. 
class GeneratedFileCleaner: """Context Manager to clean up generated files""" - def __init__(self, keep_intermediates=False): + def __init__(self, keep_intermediates=False) -> None: self.keep_intermediates = keep_intermediates self.files_to_clean = set() self.dirs_to_clean = [] @@ -123,7 +123,7 @@ def open(self, fn, *args, **kwargs): # pyrefly: ignore [not-iterable] return open(fn, *args, **kwargs) - def makedirs(self, dn, exist_ok=False): + def makedirs(self, dn, exist_ok=False) -> None: parent, n = os.path.split(dn) if not n: parent, n = os.path.split(parent) @@ -222,7 +222,7 @@ def preprocess_file_and_save_result( HIPIFY_FINAL_RESULT[fin_path] = result -def compute_stats(stats): +def compute_stats(stats) -> None: unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]} # Print the number of unsupported calls @@ -616,7 +616,7 @@ def get_hip_file_path(rel_filepath, is_pytorch_extension=False): return os.path.join(dirpath, root + ext) -def is_out_of_place(rel_filepath): +def is_out_of_place(rel_filepath) -> bool: if os.path.isabs(rel_filepath): raise AssertionError("rel_filepath must be a relative path") if rel_filepath.startswith("torch/"): @@ -629,7 +629,7 @@ def is_out_of_place(rel_filepath): # Keep this synchronized with includes/ignores in build_amd.py -def is_pytorch_file(rel_filepath): +def is_pytorch_file(rel_filepath) -> bool: if os.path.isabs(rel_filepath): raise AssertionError("rel_filepath must be a relative path") if rel_filepath.startswith("aten/"): @@ -653,7 +653,7 @@ def is_cusparse_file(rel_filepath): return False -def is_special_file(rel_filepath): +def is_special_file(rel_filepath) -> bool: if is_pytorch_file(rel_filepath): if "sparse" in rel_filepath.lower(): return True @@ -678,20 +678,20 @@ class TrieNode: A special char '' represents end of word """ - def __init__(self): + def __init__(self) -> None: self.children = {} class Trie: """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern. The corresponding Regex should match much faster than a simple Regex union.""" - def __init__(self): + def __init__(self) -> None: """Initialize the trie with an empty root node.""" self.root = TrieNode() self._hash = hashlib.md5(usedforsecurity=False) self._digest = self._hash.digest() - def add(self, word): + def add(self, word) -> None: """Add a word to the Trie. """ self._hash.update(word.encode()) self._digest = self._hash.digest() @@ -1011,7 +1011,7 @@ def repl(m): hipify_result.current_state = CurrentState.DONE return hipify_result -def file_specific_replacement(filepath, search_string, replace_string, strict=False): +def file_specific_replacement(filepath, search_string, replace_string, strict=False) -> None: with openf(filepath, "r+") as f: contents = f.read() if strict: @@ -1023,7 +1023,7 @@ def file_specific_replacement(filepath, search_string, replace_string, strict=Fa f.truncate() -def file_add_header(filepath, header): +def file_add_header(filepath, header) -> None: with openf(filepath, "r+") as f: contents = f.read() if header[0] != "<" and header[-1] != ">": @@ -1089,7 +1089,7 @@ def extract_arguments(start, string): return arguments -def str2bool(v): +def str2bool(v : str) -> bool: """ArgumentParser doesn't support type=bool. 
Thus, this helper method will convert from possible string types to True / False.""" if v.lower() in ('yes', 'true', 't', 'y', '1'): diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index 3c022a4e85508..8e89d3ec9b3a0 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -80,7 +80,7 @@ def unserializable_hook(f): return f -def warn_if_has_hooks(tensor): +def warn_if_has_hooks(tensor) -> None: if tensor._backward_hooks: for k in tensor._backward_hooks: hook = tensor._backward_hooks[k] @@ -101,7 +101,7 @@ class BackwardHook: - Calling the user hook once both output and input gradients are available """ - def __init__(self, module, user_hooks, user_pre_hooks): + def __init__(self, module, user_hooks, user_pre_hooks) -> None: self.user_hooks = user_hooks self.user_pre_hooks = user_pre_hooks self.module = module @@ -124,7 +124,7 @@ def _unpack_none(self, indices, values): return tuple(res) - def _set_user_hook(self, grad_fn): + def _set_user_hook(self, grad_fn) -> None: def hook(grad_input, _): if self.grad_outputs is None: # This happens because the gradient in your nn.Module flows to @@ -190,7 +190,7 @@ def _apply_on_tensors(self, fn, args): return out, tensors_idx def setup_input_hook(self, args): - def fn(grad_fn): + def fn(grad_fn) -> None: self._set_user_hook(grad_fn) res, input_idx = self._apply_on_tensors(fn, args) @@ -199,7 +199,7 @@ def fn(grad_fn): return res def setup_output_hook(self, args): - def fn(grad_fn): + def fn(grad_fn) -> None: def hook(_, grad_output): self.grad_outputs = self._pack_with_none(self.output_tensors_index, grad_output, diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py index b6b09937eb90c..11bb4e442b296 100644 --- a/torch/utils/mkldnn.py +++ b/torch/utils/mkldnn.py @@ -3,7 +3,7 @@ class MkldnnLinear(torch.jit.ScriptModule): - def __init__(self, dense_module, dtype): + def __init__(self, dense_module, dtype) -> None: super().__init__() self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype)) if dense_module.bias is not None: @@ -39,7 +39,7 @@ class _MkldnnConvNd(torch.jit.ScriptModule): __constants__ = ['stride', 'padding', 'dilation', 'groups'] - def __init__(self, dense_module): + def __init__(self, dense_module) -> None: super().__init__() self.stride = dense_module.stride @@ -74,7 +74,7 @@ def forward(self, x): class MkldnnConv1d(_MkldnnConvNd): - def __init__(self, dense_module, dtype): + def __init__(self, dense_module, dtype) -> None: super().__init__(dense_module) self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype)) @@ -87,7 +87,7 @@ def __setstate__(self, state): class MkldnnConv2d(_MkldnnConvNd): - def __init__(self, dense_module, dtype): + def __init__(self, dense_module, dtype) -> None: super().__init__(dense_module) self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight( @@ -109,7 +109,7 @@ def __setstate__(self, state): self.training = state[2] class MkldnnConv3d(_MkldnnConvNd): - def __init__(self, dense_module, dtype): + def __init__(self, dense_module, dtype) -> None: super().__init__(dense_module) self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight( @@ -134,7 +134,7 @@ def __setstate__(self, state): class MkldnnBatchNorm(torch.jit.ScriptModule): __constants__ = ['exponential_average_factor', 'eps'] - def __init__(self, dense_module): + def __init__(self, dense_module) -> None: super().__init__() if dense_module.training: @@ -186,7 +186,7 @@ def forward(self, x): ) class MkldnnPrelu(torch.jit.ScriptModule): - def __init__(self, dense_module, dtype): + 
def __init__(self, dense_module, dtype) -> None: super().__init__() self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype)) diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index 2ba3ea36088ce..16d1ab1c6dd1a 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -428,7 +428,7 @@ def get_info_and_burn_skeleton(path_or_bytesio, **kwargs): return page -def main(argv, *, stdout=None): +def main(argv, *, stdout=None) -> None: warnings.warn("torch.utils.model_dump is deprecated and will be removed in a future PyTorch release.", stacklevel=2) parser = argparse.ArgumentParser() parser.add_argument("--style", choices=["json", "html"]) diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py index 4c7dec0481522..7b5a8aad4dda9 100644 --- a/torch/utils/module_tracker.py +++ b/torch/utils/module_tracker.py @@ -68,12 +68,12 @@ def __init__(self) -> None: self._has_callback = False self._hooks: list[RemovableHandle] = [] - def _maybe_set_engine_callback(self): + def _maybe_set_engine_callback(self) -> None: # This assumes no concurrent calls to backward if self._has_callback: return - def callback(): + def callback() -> None: self.parents = {"Global"} self._has_callback = False @@ -99,7 +99,7 @@ def _get_mod_name(self, mod): return mod_name def _get_append_fn(self, name, is_bw): - def fn(*args): + def fn(*args) -> None: if is_bw: self._maybe_set_engine_callback() if name in self.parents: @@ -113,7 +113,7 @@ def fn(*args): return fn def _get_pop_fn(self, name, is_bw): - def fn(*args): + def fn(*args) -> None: if name in self.parents: self.parents.remove(name) else: @@ -125,7 +125,7 @@ def fn(*args): return fn - def _fw_pre_hook(self, mod, input): + def _fw_pre_hook(self, mod, input) -> None: name = self._get_mod_name(mod) self._get_append_fn(name, False)() @@ -136,7 +136,7 @@ def _fw_pre_hook(self, mod, input): register_multi_grad_hook(tensors, self._get_pop_fn(name, True)) ) - def _fw_post_hook(self, mod, input, output): + def _fw_post_hook(self, mod, input, output) -> None: name = self._get_mod_name(mod) self._get_pop_fn(name, False)() diff --git a/torch/utils/show_pickle.py b/torch/utils/show_pickle.py index cd8b6c2b8ab9c..269ba3fbda423 100644 --- a/torch/utils/show_pickle.py +++ b/torch/utils/show_pickle.py @@ -11,14 +11,14 @@ __all__ = ["FakeObject", "FakeClass", "DumpUnpickler", "main"] class FakeObject: - def __init__(self, module, name, args): + def __init__(self, module, name, args) -> None: self.module = module self.name = name self.args = args # NOTE: We don't distinguish between state never set and state set to None. 
self.state = None - def __repr__(self): + def __repr__(self) -> str: state_str = "" if self.state is None else f"(state={self.state!r})" return f"{self.module}.{self.name}{self.args!r}{state_str}" @@ -26,7 +26,7 @@ def __setstate__(self, state): self.state = state @staticmethod - def pp_format(printer, obj, stream, indent, allowance, context, level): + def pp_format(printer, obj, stream, indent, allowance, context, level) -> None: if not obj.args and obj.state is None: stream.write(repr(obj)) return @@ -45,12 +45,12 @@ def pp_format(printer, obj, stream, indent, allowance, context, level): class FakeClass: - def __init__(self, module, name): + def __init__(self, module, name) -> None: self.module = module self.name = name self.__new__ = self.fake_new # type: ignore[assignment] - def __repr__(self): + def __repr__(self) -> str: return f"{self.module}.{self.name}" def __call__(self, *args): @@ -66,7 +66,7 @@ def __init__( file, *, catch_invalid_utf8=False, - **kwargs): + **kwargs) -> None: super().__init__(file, **kwargs) self.catch_invalid_utf8 = catch_invalid_utf8 @@ -82,7 +82,7 @@ def persistent_load(self, pid): # from their pickle (__getstate__) functions. Install a custom loader # for strings that catches the decode exception and replaces it with # a sentinel object. - def load_binunicode(self): + def load_binunicode(self) -> None: strlen, = struct.unpack(" sys.maxsize: raise Exception("String too long.") # noqa: TRY002 @@ -104,7 +104,7 @@ def dump(cls, in_stream, out_stream): return value -def main(argv, output_stream=None): +def main(argv, output_stream=None) -> int | None: if len(argv) != 2: # Don't spam stderr if not using stdout. if output_stream is not None: diff --git a/torch/utils/tensorboard/_embedding.py b/torch/utils/tensorboard/_embedding.py index 28385426c280c..73413e219d0ef 100644 --- a/torch/utils/tensorboard/_embedding.py +++ b/torch/utils/tensorboard/_embedding.py @@ -21,7 +21,7 @@ def _gfile_join(a, b): return fs.join(a, b) -def make_tsv(metadata, save_path, metadata_header=None): +def make_tsv(metadata, save_path, metadata_header=None) -> None: if not metadata_header: metadata = [str(x) for x in metadata] else: @@ -37,7 +37,7 @@ def make_tsv(metadata, save_path, metadata_header=None): # https://github.com/tensorflow/tensorboard/issues/44 image label will be squared -def make_sprite(label_img, save_path): +def make_sprite(label_img, save_path) -> None: from PIL import Image from io import BytesIO @@ -74,13 +74,13 @@ def get_embedding_info(metadata, label_img, subdir, global_step, tag): return info -def write_pbtxt(save_path, contents): +def write_pbtxt(save_path, contents) -> None: config_path = _gfile_join(save_path, "projector_config.pbtxt") with tf.io.gfile.GFile(config_path, "wb") as f: f.write(tf.compat.as_bytes(contents)) -def make_mat(matlist, save_path): +def make_mat(matlist, save_path) -> None: with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f: for x in matlist: x = [str(i.item()) for i in x] diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index 859f80e691ce5..5a052016130b1 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -42,7 +42,7 @@ def __init__( tensor_size=None, op_type="UnSpecified", attributes="", - ): + ) -> None: # TODO; Specify a __slots__ for this class or potentially # used namedtuple instead self.debugName = debugName @@ -52,7 +52,7 @@ def __init__( self.attributes = attributes self.scope = scope - def 
__repr__(self): + def __repr__(self) -> str: repr = [] repr.append(str(type(self))) repr.extend( @@ -64,7 +64,7 @@ def __repr__(self): class NodePy(NodeBase): - def __init__(self, node_cpp, valid_methods): + def __init__(self, node_cpp, valid_methods) -> None: super().__init__(node_cpp) valid_methods = valid_methods[:] self.inputs = [] @@ -89,7 +89,7 @@ def __init__(self, node_cpp, valid_methods): class NodePyIO(NodePy): - def __init__(self, node_cpp, input_or_output=None): + def __init__(self, node_cpp, input_or_output=None) -> None: super().__init__(node_cpp, methods_IO) try: tensor_size = node_cpp.type().sizes() @@ -109,7 +109,7 @@ def __init__(self, node_cpp, input_or_output=None): class NodePyOP(NodePy): - def __init__(self, node_cpp): + def __init__(self, node_cpp) -> None: super().__init__(node_cpp, methods_OP) # Replace single quote which causes strange behavior in TensorBoard # TODO: See if we can remove this in the future @@ -140,32 +140,32 @@ class GraphPy: and scope_name_appeared. """ - def __init__(self): + def __init__(self) -> None: self.nodes_op = [] self.nodes_io = OrderedDict() self.unique_name_to_scoped_name = {} self.shallowest_scope_name = "default" self.scope_name_appeared = [] - def append(self, x): + def append(self, x) -> None: if isinstance(x, NodePyIO): self.nodes_io[x.debugName] = x if isinstance(x, NodePyOP): self.nodes_op.append(x) - def printall(self): + def printall(self) -> None: print("all nodes") for node in self.nodes_op: print(node) for key in self.nodes_io: print(self.nodes_io[key]) - def find_common_root(self): + def find_common_root(self) -> None: for fullscope in self.scope_name_appeared: if fullscope: self.shallowest_scope_name = fullscope.split("/")[0] - def populate_namespace_from_OP_to_IO(self): + def populate_namespace_from_OP_to_IO(self) -> None: for node in self.nodes_op: for node_output, outputSize in zip(node.outputs, node.outputstensor_size, strict=True): self.scope_name_appeared.append(node.scopeName) diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 1b6a2bb9bb66f..74befc366c199 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -115,7 +115,7 @@ def _tensor_to_list(t: torch.Tensor) -> list[Any]: } -def _calc_scale_factor(tensor): +def _calc_scale_factor(tensor) -> int: converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor return 1 if converted.dtype == np.uint8 else 255 diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 0f533ae5b0f57..2dd8ac3db667b 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -50,7 +50,7 @@ class FileWriter: training. """ - def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""): + def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix="") -> None: """Create a `FileWriter` and an event file. On construction the writer creates a new event file in `log_dir`. @@ -81,7 +81,7 @@ def get_logdir(self): """Return the directory where event file will be written.""" return self.event_writer.get_logdir() - def add_event(self, event, step=None, walltime=None): + def add_event(self, event, step=None, walltime=None) -> None: """Add an event to the event file. 
Args: @@ -98,7 +98,7 @@ def add_event(self, event, step=None, walltime=None): event.step = int(step) self.event_writer.add_event(event) - def add_summary(self, summary, global_step=None, walltime=None): + def add_summary(self, summary, global_step=None, walltime=None) -> None: """Add a `Summary` protocol buffer to the event file. This method wraps the provided summary in an `Event` protocol buffer @@ -114,7 +114,7 @@ def add_summary(self, summary, global_step=None, walltime=None): event = event_pb2.Event(summary=summary) self.add_event(event, global_step, walltime) - def add_graph(self, graph_profile, walltime=None): + def add_graph(self, graph_profile, walltime=None) -> None: """Add a `Graph` and step stats protocol buffer to the event file. Args: @@ -133,7 +133,7 @@ def add_graph(self, graph_profile, walltime=None): event = event_pb2.Event(tagged_run_metadata=trm) self.add_event(event, None, walltime) - def add_onnx_graph(self, graph, walltime=None): + def add_onnx_graph(self, graph, walltime=None) -> None: """Add a `Graph` protocol buffer to the event file. Args: @@ -144,7 +144,7 @@ def add_onnx_graph(self, graph, walltime=None): event = event_pb2.Event(graph_def=graph.SerializeToString()) self.add_event(event, None, walltime) - def flush(self): + def flush(self) -> None: """Flushes the event file to disk. Call this method to make sure that all pending events have been written to @@ -152,14 +152,14 @@ def flush(self): """ self.event_writer.flush() - def close(self): + def close(self) -> None: """Flushes the event file to disk and close the file. Call this method when you do not need the summary writer anymore. """ self.event_writer.close() - def reopen(self): + def reopen(self) -> None: """Reopens the EventFileWriter. Can be called after `close()` to add more events in the same directory. @@ -188,7 +188,7 @@ def __init__( max_queue=10, flush_secs=120, filename_suffix="", - ): + ) -> None: """Create a `SummaryWriter` that will write out events and summaries to the event file. Args: @@ -299,7 +299,7 @@ def add_hparams( hparam_domain_discrete=None, run_name=None, global_step=None, - ): + ) -> None: """Add a set of hyperparameters to be compared in TensorBoard. Args: @@ -355,7 +355,7 @@ def add_scalar( walltime=None, new_style=False, double_precision=False, - ): + ) -> None: """Add scalar data to summary. Args: @@ -388,7 +388,7 @@ def add_scalar( ) self._get_file_writer().add_summary(summary, global_step, walltime) - def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None): + def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None) -> None: """Add many scalar data to summary. Args: @@ -439,7 +439,7 @@ def add_tensor( tensor, global_step=None, walltime=None, - ): + ) -> None: """Add tensor data to summary. Args: @@ -473,7 +473,7 @@ def add_histogram( bins="tensorflow", walltime=None, max_bins=None, - ): + ) -> None: """Add histogram to summary. Args: @@ -520,7 +520,7 @@ def add_histogram_raw( bucket_counts, global_step=None, walltime=None, - ): + ) -> None: """Add histogram with raw data. Args: @@ -585,7 +585,7 @@ def add_histogram_raw( def add_image( self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW" - ): + ) -> None: """Add image data to summary. Note that this requires the ``pillow`` package. @@ -636,7 +636,7 @@ def add_image( def add_images( self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW" - ): + ) -> None: """Add batched image data to summary. 
Note that this requires the ``pillow`` package. @@ -688,7 +688,7 @@ def add_image_with_boxes( rescale=1, dataformats="CHW", labels=None, - ): + ) -> None: """Add image and draw bounding boxes on the image. Args: @@ -767,7 +767,7 @@ def add_figure( dataformats="CHW", ) - def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None): + def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None) -> None: """Add video data to summary. Note that this requires the ``moviepy`` package. @@ -789,7 +789,7 @@ def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None): def add_audio( self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None - ): + ) -> None: """Add audio data to summary. Args: @@ -807,7 +807,7 @@ def add_audio( audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime ) - def add_text(self, tag, text_string, global_step=None, walltime=None): + def add_text(self, tag, text_string, global_step=None, walltime=None) -> None: """Add text data to summary. Args: @@ -826,13 +826,13 @@ def add_text(self, tag, text_string, global_step=None, walltime=None): text(tag, text_string), global_step, walltime ) - def add_onnx_graph(self, prototxt): + def add_onnx_graph(self, prototxt) -> None: torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph") self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt)) def add_graph( self, model, input_to_model=None, verbose=False, use_strict_trace=True - ): + ) -> None: """Add graph data to summary. Args: @@ -867,7 +867,7 @@ def add_embedding( global_step=None, tag="default", metadata_header=None, - ): + ) -> None: """Add embedding projector data to summary. Args: @@ -973,7 +973,7 @@ def add_pr_curve( num_thresholds=127, weights=None, walltime=None, - ): + ) -> None: """Add precision recall curve. Plotting a precision-recall curve lets you understand your model's @@ -1026,7 +1026,7 @@ def add_pr_curve_raw( num_thresholds=127, weights=None, walltime=None, - ): + ) -> None: """Add precision recall curve with raw data. Args: @@ -1062,7 +1062,7 @@ def add_pr_curve_raw( def add_custom_scalars_multilinechart( self, tags, category="default", title="untitled" - ): + ) -> None: """Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*. Args: @@ -1080,7 +1080,7 @@ def add_custom_scalars_multilinechart( def add_custom_scalars_marginchart( self, tags, category="default", title="untitled" - ): + ) -> None: """Shorthand for creating marginchart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*, @@ -1101,7 +1101,7 @@ def add_custom_scalars_marginchart( layout = {category: {title: ["Margin", tags]}} self._get_file_writer().add_summary(custom_scalars(layout)) - def add_custom_scalars(self, layout): + def add_custom_scalars(self, layout) -> None: """Create special chart by collecting charts tags in 'scalars'. NOTE: This function can only be called once for each SummaryWriter() object. @@ -1134,7 +1134,7 @@ def add_mesh( config_dict=None, global_step=None, walltime=None, - ): + ) -> None: """Add meshes or 3D point clouds to TensorBoard. The visualization is based on Three.js, @@ -1192,7 +1192,7 @@ def add_mesh( mesh(tag, vertices, colors, faces, config_dict), global_step, walltime ) - def flush(self): + def flush(self) -> None: """Flushes the event file to disk. 
Call this method to make sure that all pending events have been written to @@ -1203,7 +1203,7 @@ def flush(self): for writer in self.all_writers.values(): writer.flush() - def close(self): + def close(self) -> None: if self.all_writers is None: return # ignore double close for writer in self.all_writers.values(): diff --git a/torch/utils/throughput_benchmark.py b/torch/utils/throughput_benchmark.py index 3f06f6220eef2..d4b94e0b13a39 100644 --- a/torch/utils/throughput_benchmark.py +++ b/torch/utils/throughput_benchmark.py @@ -3,7 +3,7 @@ import torch._C -def format_time(time_us=None, time_ms=None, time_s=None): +def format_time(time_us=None, time_ms=None, time_s=None) -> str: """Define time formatting.""" if sum([time_us is not None, time_ms is not None, time_s is not None]) != 1: raise AssertionError("Expected only one of time_us, time_ms, time_s is given.") @@ -27,7 +27,7 @@ def format_time(time_us=None, time_ms=None, time_s=None): class ExecutionStats: - def __init__(self, c_stats, benchmark_config): + def __init__(self, c_stats, benchmark_config) -> None: self._c_stats = c_stats self.benchmark_config = benchmark_config @@ -49,7 +49,7 @@ def total_time_seconds(self): return self.num_iters * ( self.latency_avg_ms / 1000.0) / self.benchmark_config.num_calling_threads - def __str__(self): + def __str__(self) -> str: return '\n'.join([ "Average latency per example: " + format_time(time_ms=self.latency_avg_ms), f"Total number of iterations: {self.num_iters}", @@ -93,7 +93,7 @@ class ThroughputBenchmark: >>> print("Number of iterations: {}".format(stats.num_iters)) """ - def __init__(self, module): + def __init__(self, module) -> None: if isinstance(module, torch.jit.ScriptModule): self._benchmark = torch._C.ThroughputBenchmark(module._c) else: @@ -109,7 +109,7 @@ def run_once(self, *args, **kwargs): """ return self._benchmark.run_once(*args, **kwargs) - def add_input(self, *args, **kwargs): + def add_input(self, *args, **kwargs) -> None: """ Store a single input to a module into the benchmark memory and keep it there. diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 9587a8d682e5b..0002b40025c18 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -15,14 +15,14 @@ def observe_garbage(observer): enabled = True - def disable(): + def disable() -> None: # when GC runs during exit, things like `sys` will already be unloaded # so we have to disable the callback to avoid hitting errors. 
nonlocal enabled enabled = False atexit.register(disable) - def gc_callback(phase, info): + def gc_callback(phase, info) -> None: nonlocal enabled if not enabled: return @@ -66,7 +66,7 @@ def do_collect(*args, **kwargs): gc.callbacks.append(gc_callback) # provide a way to disarm the callback - def remove(): + def remove() -> None: gc.callbacks.remove(gc_callback) return remove @@ -103,15 +103,15 @@ def annotated_references(obj): """ references: dict[int, list[str]] = {} - def add_reference(name, obj): + def add_reference(name, obj) -> None: references.setdefault(id(obj), []).append(name) - def add_attrs(*attrs): + def add_attrs(*attrs) -> None: for attr in attrs: if hasattr(obj, attr): add_reference(attr, getattr(obj, attr)) - def add_cell_references(): + def add_cell_references() -> None: try: add_attrs("cell_contents") except ValueError: @@ -121,7 +121,7 @@ def add_cell_references(): # annotate pass - def add_function_references(): + def add_function_references() -> None: add_attrs("__defaults__", "__closure__", "__globals__", @@ -134,23 +134,23 @@ def add_function_references(): "__kwdefaults__") - def add_sequence_references(): + def add_sequence_references() -> None: for position, item in enumerate(obj): add_reference(f"[{position}]", item) - def add_dict_references(): + def add_dict_references() -> None: for key, value in obj.items(): add_reference("key", key) add_reference(f"[{repr(key)}]", value) - def add_set_references(): + def add_set_references() -> None: for elt in obj: add_reference("element", elt) - def add_bound_method_references(): + def add_bound_method_references() -> None: add_attrs("__self__", "__func__", "im_class") - def add_weakref_references(): + def add_weakref_references() -> None: # For subclasses of weakref, we can't reliably distinguish the # callback (if any) from other attributes. if type(obj) is weakref.ref: @@ -160,7 +160,7 @@ def add_weakref_references(): add_reference("__callback__", target) - def add_frame_references(): + def add_frame_references() -> None: f_locals = obj.f_locals add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals") # Some badly-behaved code replaces the f_locals dict with @@ -170,7 +170,7 @@ def add_frame_references(): for name, local in obj.f_locals.items(): add_reference(f"local {name}", local) - def add_getset_descriptor_references(): + def add_getset_descriptor_references() -> None: add_attrs("__objclass__", "__name__", "__doc__") type_based_references = { @@ -473,7 +473,7 @@ def to_html(nodes): def observe_tensor_cycles(callback): torch.cuda.memory._record_memory_history(max_entries=100000) - def observer(garbage): + def observer(garbage) -> None: if garbage: if not any(is_cuda_tensor(obj) for obj in garbage): logger.info("No CUDA Tensors found in garbage") @@ -497,7 +497,7 @@ def warn_tensor_cycles(): """ logger.info("Watching Python reference cycles for CUDA Tensors.") - def write_and_log(html): + def write_and_log(html) -> None: with NamedTemporaryFile('w', suffix='.html', delete=False) as f: f.write(html) logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index cd829e531b46c..f71912b59f53a 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -28,7 +28,7 @@ class _IterationGuard: # exits. # This technique should be relatively thread-safe (since sets are). 
- def __init__(self, weakcontainer): + def __init__(self, weakcontainer) -> None: # Don't create cycles self.weakcontainer = ref(weakcontainer) @@ -75,7 +75,7 @@ def __exit__(self, e, t, b): class WeakIdRef(weakref.ref): __slots__ = ["_id"] - def __init__(self, key, callback=None): + def __init__(self, key, callback=None) -> None: # Unlike stock weakref, which preserves hash semantics of the # original object but lazily defers hash calls until the first # time the user attempts to hash the weakref, we can eagerly @@ -119,7 +119,7 @@ def __eq__(self, other): class _WeakHashRef(weakref.ref): __slots__ = ["_id"] - def __init__(self, key, callback=None): + def __init__(self, key, callback=None) -> None: # Unlike stock weakref, which preserves hash semantics of the # original object but lazily defers hash calls until the first # time the user attempts to hash the weakref, we can eagerly @@ -151,12 +151,12 @@ def __eq__(self, other): # This is directly adapted from cpython/Lib/weakref.py class WeakIdKeyDictionary(MutableMapping): - def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED + def __init__(self, dict=None, ref_type=WeakIdRef) -> None: # CHANGED self.data = {} self.ref_type = ref_type # CHANGED - def remove(k, selfref=ref(self)): + def remove(k, selfref=ref(self)) -> None: self = selfref() if self is not None: if self._iterating: @@ -175,7 +175,7 @@ def remove(k, selfref=ref(self)): if dict is not None: self.update(dict) - def _commit_removals(self): + def _commit_removals(self) -> None: # NOTE: We don't need to call this method before mutating the dict, # because a dead weakref never compares equal to a live weakref, # even if they happened to refer to equal objects. @@ -193,29 +193,29 @@ def _commit_removals(self): except KeyError: pass - def _scrub_removals(self): + def _scrub_removals(self) -> None: d = self.data self._pending_removals = [k for k in self._pending_removals if k in d] self._dirty_len = False - def __delitem__(self, key): + def __delitem__(self, key) -> None: self._dirty_len = True del self.data[self.ref_type(key)] # CHANGED def __getitem__(self, key): return self.data[self.ref_type(key)] # CHANGED - def __len__(self): + def __len__(self) -> int: if self._dirty_len and self._pending_removals: # self._pending_removals may still contain keys which were # explicitly removed, we have to scrub them (see issue #21173). 
self._scrub_removals() return len(self.data) - len(self._pending_removals) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__} at {id(self):#x}>" - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: self.data[self.ref_type(key, self._remove)] = value # CHANGED def copy(self): @@ -243,7 +243,7 @@ def __deepcopy__(self, memo): def get(self, key, default=None): return self.data.get(self.ref_type(key), default) # CHANGED - def __contains__(self, key): + def __contains__(self, key) -> bool: try: wr = self.ref_type(key) # CHANGED except TypeError: @@ -303,7 +303,7 @@ def setdefault(self, key, default=None): self.ref_type(key, self._remove), default ) # CHANGED - def update(self, dict=None, **kwargs): # type: ignore[override] + def update(self, dict=None, **kwargs) -> None: # type: ignore[override] d = self.data if dict is not None: if not hasattr(dict, "items"): @@ -351,7 +351,7 @@ class TensorWeakRef: ref: WeakRef[Tensor] - def __init__(self, tensor: Tensor): + def __init__(self, tensor: Tensor) -> None: if not isinstance(tensor, Tensor): raise AssertionError(f"expected torch.Tensor, got {type(tensor)}.") self.ref = weakref.ref(tensor) diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 6f1671e4e7a43..194684e3388e4 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -94,7 +94,7 @@ def is_initialized(): return _initialized and not _is_in_bad_fork() -def _lazy_call(callable, **kwargs): +def _lazy_call(callable, **kwargs) -> None: if is_initialized(): callable() else: @@ -108,7 +108,7 @@ def _lazy_call(callable, **kwargs): _queued_calls.append((callable, traceback.format_stack())) -def init(): +def init() -> None: r"""Initialize PyTorch's XPU state. This is a Python API about lazy initialization that avoids initializing XPU until the first time it is accessed. Does nothing if the XPU state is @@ -117,7 +117,7 @@ def init(): _lazy_init() -def _lazy_init(): +def _lazy_init() -> None: global _initialized, _queued_calls if is_initialized() or hasattr(_tls, "is_initializing"): return @@ -158,7 +158,7 @@ def _lazy_init(): class _DeviceGuard: - def __init__(self, index: int): + def __init__(self, index: int) -> None: self.idx = index self.prev_idx = -1 @@ -178,7 +178,7 @@ class device: this argument is a negative integer or ``None``. """ - def __init__(self, device: Any): + def __init__(self, device: Any) -> None: self.idx = _get_device_index(device, optional=True) self.prev_idx = -1 @@ -200,7 +200,7 @@ class device_of(device): obj (Tensor or Storage): object allocated on the selected device. """ - def __init__(self, obj): + def __init__(self, obj) -> None: idx = obj.get_device() if obj.is_xpu else -1 super().__init__(idx) @@ -324,7 +324,7 @@ class StreamContext: cur_stream: Optional["torch.xpu.Stream"] - def __init__(self, stream: Optional["torch.xpu.Stream"]): + def __init__(self, stream: Optional["torch.xpu.Stream"]) -> None: self.stream = stream self.idx = _get_device_index(None, True) if self.idx is None: @@ -362,7 +362,7 @@ def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext: return StreamContext(stream) -def _set_stream_by_id(stream_id, device_index, device_type): +def _set_stream_by_id(stream_id, device_index, device_type) -> None: r"""set stream specified by the stream id, device index and device type Args: stream_id (int): not visible to the user, used to assigned to the specific stream. 
@@ -376,7 +376,7 @@ def _set_stream_by_id(stream_id, device_index, device_type): ) -def set_stream(stream: Stream): +def set_stream(stream: Stream) -> None: r"""Set the current stream.This is a wrapper API to set the stream. Usage of this function is discouraged in favor of the ``stream`` context manager. @@ -495,7 +495,7 @@ def _set_rng_state_offset( """ final_device = _get_device(device) - def cb(): + def cb() -> None: default_generator = _get_generator(final_device) default_generator.set_offset(offset) diff --git a/torch/xpu/random.py b/torch/xpu/random.py index 8cd74d385defd..ec770225aef39 100644 --- a/torch/xpu/random.py +++ b/torch/xpu/random.py @@ -53,7 +53,7 @@ def set_rng_state( elif isinstance(device, int): device = torch.device("xpu", device) - def cb(): + def cb() -> None: idx = device.index if idx is None: idx = current_device() @@ -87,7 +87,7 @@ def manual_seed(seed: int) -> None: """ seed = int(seed) - def cb(): + def cb() -> None: idx = current_device() default_generator = torch.xpu.default_generators[idx] default_generator.manual_seed(seed) @@ -105,7 +105,7 @@ def manual_seed_all(seed: int) -> None: """ seed = int(seed) - def cb(): + def cb() -> None: for i in range(device_count()): default_generator = torch.xpu.default_generators[i] default_generator.manual_seed(seed) @@ -123,7 +123,7 @@ def seed() -> None: the seed on one GPU. To initialize all GPUs, use :func:`seed_all`. """ - def cb(): + def cb() -> None: idx = current_device() default_generator = torch.xpu.default_generators[idx] default_generator.seed() @@ -137,7 +137,7 @@ def seed_all() -> None: It's safe to call this function if XPU is not available; in that case, it is silently ignored. """ - def cb(): + def cb() -> None: random_seed = 0 seeded = False for i in range(device_count()): diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index a1d78305f0a5e..2f10f1f14dd67 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -96,7 +96,7 @@ def __eq__(self, o): def __hash__(self): return hash((self.sycl_queue, self.device)) - def __repr__(self): + def __repr__(self) -> str: return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})" @@ -166,7 +166,7 @@ def synchronize(self) -> None: def _as_parameter_(self): return ctypes.c_void_p(self.sycl_event) - def __repr__(self): + def __repr__(self) -> str: if self.sycl_event: return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})" else: From eaf4815c1f05987d4433cd280a3f8d307085da69 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 03:37:51 +0000 Subject: [PATCH 177/651] Remove workarounds for older Python (#167173) This PR removes workarounds for older Python. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167173 Approved by: https://github.com/albanD --- torch/__init__.py | 1 - torch/distributed/checkpoint/default_planner.py | 10 +++------- torch/nn/modules/container.py | 3 +-- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index b64961a9c56f6..6ce2549964abb 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1026,7 +1026,6 @@ def sym_fresh_size(expr): except ImportError: import torch._C as _C_for_compiled_check - # The __file__ check only works for Python 3.7 and above. 
if _C_for_compiled_check.__file__ is None: raise ImportError( textwrap.dedent( diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index ee0029ec7d63b..2f68e7f842264 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -5,7 +5,6 @@ import io import logging import operator -from collections import ChainMap from functools import reduce from typing import Any, cast, Optional, Union @@ -137,12 +136,9 @@ def _create_global_plan( global_plan, metadata = create_default_global_save_plan(deduped_plans) if self.flatten_state_dict: - # | does not work for Python 3.8 or older version. - # merged_mappings = reduce( - # lambda x, y: x | y, (p.planner_data for p in global_plan) - # ) - planner_data_dict = [p.planner_data for p in global_plan] - merged_mappings = dict(ChainMap(*planner_data_dict)) + merged_mappings = reduce( + lambda x, y: x | y, (p.planner_data for p in global_plan) + ) metadata = dataclasses.replace(metadata, planner_data=merged_mappings) if not _validate_global_plan(global_plan, metadata): diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 1132dc2bb0d4d..7200db650dfa6 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -519,8 +519,7 @@ class ModuleDict(Module): :meth:`~torch.nn.ModuleDict.update`). Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping - types (e.g., Python's plain ``dict`` before Python version 3.6) does not - preserve the order of the merged mapping. + types does not preserve the order of the merged mapping. Args: modules (iterable, optional): a mapping (dictionary) of (string: module) From 7aedf3a576420936c5fa080312590a8e6542b295 Mon Sep 17 00:00:00 2001 From: "Cui, Yifeng" Date: Fri, 7 Nov 2025 03:49:39 +0000 Subject: [PATCH 178/651] Update torch-xpu-ops commit pin (#166945) Update the torch-xpu-ops commit to [intel/torch-xpu-ops@9aac5a](https://github.com/intel/torch-xpu-ops/commit/9aac5a1ddf50d75f929d572df51bb368b32da14e), includes: - Enable FP8 concat/where/flip/index_put/index.Tensor on XPU backend - Remove BUILD_SPLIT_KERNEL_LIB flag - Fix the initialization order of ProcessGroupXCCL - Separates communication initialization logic from getXCCLComm - Fix segmentation fault in NLLLoss kernel Pull Request resolved: https://github.com/pytorch/pytorch/pull/166945 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 91aee0c2a0ffa..a5031de150288 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -8d373ba272f9fed348c7684bac4a0c2663844bbd +9aac5a1ddf50d75f929d572df51bb368b32da14e From d325aa1877be5a83e9d0f0756d564c3937825ccd Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Fri, 7 Nov 2025 04:22:53 +0000 Subject: [PATCH 179/651] [vision hash update] update the pinned vision hash (#167032) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vision hash. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167032 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 183e9fb4b06e1..1c6bf359618d5 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -cfbc5c2f1c798991715a6b06bb3ce46478c4487c +ca2212438fdd8ce29b66999ed70ed54b0f9372d1 From 8a721888286734470a7d99d5a752ebf020a4bc90 Mon Sep 17 00:00:00 2001 From: Dev Sashidhar Date: Fri, 7 Nov 2025 04:43:08 +0000 Subject: [PATCH 180/651] Raise error for 1D (size > 1) -> 0D parameter loads (#166335) Fixes #165873 # Title Fix load_state_dict: raise error for 1D (size > 1) -> 0D parameter loads ## Summary This PR fixes a bug where loading a 1D tensor (size > 1) into a scalar (0D) parameter would silently take the first element instead of raising an error. The fix preserves backward compatibility for 1D tensors of size 1 while catching genuine shape mismatches. ## Motivation Previously, loading a 1D tensor like torch.randn(32000) into a 0D scalar parameter would silently slice the first element, leading to silent data loss and potential bugs. This change ensures users get a clear error when there's a genuine shape mismatch. ## Behavior change Before: 1D tensor (any length) -> 0D scalar -> silently coerced using input_param[0] After: - 1D tensor (size == 1) -> 0D scalar -> allowed (backward compatibility) - 1D tensor (size > 1) -> 0D scalar -> raises RuntimeError with size mismatch message In torch/nn/modules/module.py, _load_from_state_dict, added input_param.shape[0] == 1 check to the backward compatibility condition to only allow single-element 1D tensors. ## Tests Added test_scalar_param_1d_tensor_raises to verify that loading 1D tensors of size > 1 raises an error, while size 1 loads successfully. 
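## Example
For illustration only, a minimal sketch of the new behavior described above (the module and parameter names here are hypothetical and not part of this patch):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # 0D (scalar) parameter
        self.threshold = nn.Parameter(torch.tensor(0.0))

m = M()

# 1D tensor of size 1 -> still loads into the 0D parameter (backward compatibility)
m.load_state_dict({"threshold": torch.tensor([1.0])})

# 1D tensor of size > 1 -> now raises instead of silently taking the first element
try:
    m.load_state_dict({"threshold": torch.randn(3)})
except RuntimeError as e:
    print(e)  # size mismatch for threshold: ...
```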
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166335 Approved by: https://github.com/mikaylagawarecki --- test/nn/test_load_state_dict.py | 23 +++++++++++++++++++++++ torch/nn/modules/module.py | 1 + 2 files changed, 24 insertions(+) diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index 074ac6273689a..3d20787ac4456 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -60,6 +60,29 @@ def test_load_state_dict_type(self): ): m.load_state_dict(2) + @swap([True, False]) + @skipIfTorchDynamo("dynamo installs weakrefs on some params") + def test_scalar_param_1d_tensor_raises(self): + class SimpleModule(nn.Module): + def __init__(self): + super().__init__() + self.threshold = nn.Parameter(torch.tensor(0.0)) + + def forward(self, x): + return x + + m = SimpleModule() + + # Test that [3] -> scalar raises error + sd = {"threshold": torch.randn(3)} + with self.assertRaisesRegex(RuntimeError, "size mismatch for threshold"): + m.load_state_dict(sd) + + # Test that [1] -> scalar is allowed (backward compatibility) + sd = {"threshold": torch.tensor([1.0])} + m.load_state_dict(sd) + self.assertEqual(m.threshold.item(), 1.0) + @swap([True, False]) @skipIfTorchDynamo("dynamo installs weakrefs on some params") def test_load_state_dict(self): diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 33bf35a1d852a..a3d723a0d294d 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -2438,6 +2438,7 @@ def _load_from_state_dict( not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1 + and input_param.shape[0] == 1 ): input_param = input_param[0] From 3f03f84ce2f927a5a26e30c43a235a13571b7110 Mon Sep 17 00:00:00 2001 From: shunting314 Date: Wed, 5 Nov 2025 23:40:39 -0800 Subject: [PATCH 181/651] [inductor] fix dashbaord regression due to mix order reduction (#166938) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The PR includes a misc list of fixes for the regressions I see from the dashboard: 1. the dashboard may use very small shape for rmsnorm backward. The data set can be fully cached in L2 thus mix order reduction does not show much benefit and may even has worse perf. Disable mix order reduction for small workload 2. disable the autotuning of split size by default to avoid the compilation time hit 3. avoid mix order reduction if there is non-contiguous memory access. Previously the check is only done for shared buffers accessed by both reductions. It turns out to be necessary to expand the check for buffers only accessed by one reduction. Check test test_avoid_non_coalesced_access which is simplified from a TIMM model. Note that larger XBLOCK could fix the perf problem and make mix order reduction still applicable. But I don't think that's high priority. With larger XBLOCK, the kernel would consume much more shared memory/registers. That could also cause perf issue. Dashboard result [here](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2029%20Oct%202025%2003%3A40%3A22%20GMT&stopTime=Wed%2C%2005%20Nov%202025%2004%3A40%3A22%20GMT&granularity=hour&mode=training&dtype=amp&deviceName=cuda%20(h100)&lBranch=gh/shunting314/257/head&lCommit=b6f4a24ea5f7574d6b1d3b854022aa09d70593db&rBranch=main&rCommit=22a745737a09b0600bb0b85b4c0bbb9fb627f137). 
Screenshot 2025-11-04 at 10 58 48 PM - the perf drop for TIMM (default) is not real, it's due to one more model passed the accuracy test - the perf drop for HF (cudagraphs) is not real. I checked each individual models that showed regressed on the dashboard. And they fall into the following categories - showed regressed, but absolute execution get reduced. e.g. OPTForCausalLM - showed regressed, but has slight speedup on h100 dev server: MobileBertForMaskedLM . speedup from 57.847709ms to 56.711640 ms - showed regressed, but the PR does not change the kernels generated (skip mix order reduction due to small workload or other reasons). e.g. XGLMForCausalLM, AlbertForMaskedLM . Note that the neutral result on the dashboard is expected due to small workload size. For large workload, we see about 1.5x geomean for rmsnorm/layernorm backward on average and 2.2x for some shapes used by internal model. For 8GPU torchtitan training on llama3, we see 4% TPS (tokens per second) improvement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166938 Approved by: https://github.com/jansel ghstack dependencies: #166669 --- test/inductor/test_mix_order_reduction.py | 23 ++++++++++++++-- torch/_inductor/codegen/simd.py | 6 ++++- torch/_inductor/config.py | 8 ++++-- torch/_inductor/scheduler.py | 33 ++++++++++++++++++----- 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 0dcc37ee359d8..d7c5d886f1a2f 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -117,6 +117,25 @@ def outer_red(): metrics.codegen_mix_order_reduction, ) + def test_avoid_non_coalesced_access(self): + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x, y): + return (x + y).sum(dim=-1), x.sum(dim=(0, 1)) + + x = torch.randn(128, 256, 768, device=GPU_TYPE) + y = torch.randn(128, 768, 256, device=GPU_TYPE).transpose(1, 2) + self.check_numeric(f, (x, y)) + + # we skip mix order reduction for such kernel since + # we force XBLOCK to be 1, the access to tensor y would be + # very inefficient. + # TODO: support XBLOCK larger than 1. But in that case, we + # would have bigger restriction on rnumel to avoid exploding + # shared memory. 
+ self.assertEqual(metrics.codegen_mix_order_reduction, 0) + @inductor_config.patch(coordinate_descent_tuning=True) def test_XBLOCK_coordest_tuning(self): """ @@ -199,8 +218,8 @@ def test_multi_workspace_allocation(self): def f(x, y): return x.sum(dim=0), x.sum(dim=1), y.sum(dim=0), y.sum(dim=1) - x = torch.randn(4096, 32, device=GPU_TYPE) - y = torch.randn(4098, 34, device=GPU_TYPE) + x = torch.randn(4096 * 64, 32, device=GPU_TYPE) + y = torch.randn(4098 * 64, 34, device=GPU_TYPE) self.check_numeric(f, (x, y)) expected_mix_order_reduction = ( diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 24394bc87cf41..14bf46db4a18c 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1648,7 +1648,11 @@ def _pick_split_size(): if ( not torch._inductor.config.deterministic and config.triton.mix_order_reduction_split_size is None - and config.triton.mix_order_reduction_autotune_split_size + and ( + config.triton.mix_order_reduction_autotune_split_size + or config.max_autotune + or config.coordinate_descent_tuning + ) ): def _bench(candidate_split_size): diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 10e3d2bb5211a..8aaea94da1266 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1575,11 +1575,15 @@ class triton: enable_pdl = False mix_order_reduction = ( - os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0") == "1" + os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0" if is_fbcode() else "1") + == "1" ) mix_order_reduction_split_size: Optional[int] = None - mix_order_reduction_autotune_split_size = True + mix_order_reduction_autotune_split_size = ( + os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION_AUTOTUNE_SPLIT_SIZE", "0") + == "1" + ) class aot_inductor: diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 020067c83999c..61fa832878ada 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -219,6 +219,9 @@ def has_common_read( # TODO add a cache @classmethod def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + """ + Check whether we can fuse two reductions with mix loop orders. + """ if not config.triton.mix_order_reduction: return False @@ -246,6 +249,13 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: nrow = sympy.Max(g1[0], g1[1]) ncol = sympy.Min(g1[0], g1[1]) + # the fused version has worse perf than non-fused version for + # small workload. When a workload is small enough, data can be + # fully cached by L2 + size_thres = 5 * 2**20 + if not V.graph.sizevars.statically_known_geq(nrow * ncol, size_thres): + return False + # We require more more row than columns since # 1, we prefer doing persistent reduction for each row # 2, we will split the reduction across the rows @@ -262,8 +272,19 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: (node1, node2) if g1[1] == ncol else (node2, node1) ) + # We previously only check the contiguous_node has contiguous + # access to common_reads. But that turns out to be not enough. + # The contiguous node may access a buffer that's node use by + # other_ndoe. 
If that ascess is non-contiugous, generating + # mix-order reduction can be inefficient especially when we + # force XBLOCK to be 1 + # if not all( + # cls.is_contiguous_load(buf, contiguous_node) for buf in common_reads + # ): + # return False if not all( - cls.is_contiguous_load(buf, contiguous_node) for buf in common_reads + cls.is_contiguous_load(dep.name, contiguous_node) + for dep in contiguous_node.read_writes.reads ): return False @@ -306,7 +327,6 @@ def are_mix_order_reductions( def is_contiguous_load(cls, buf: str, parent_node: BaseSchedulerNode) -> bool: from torch._inductor.loop_body import MemoryUsageType - n_congituous_read = 0 for node in parent_node.get_nodes(): assert isinstance(node, SchedulerNode) loop_body = node._body @@ -328,10 +348,11 @@ def is_contiguous_load(cls, buf: str, parent_node: BaseSchedulerNode) -> bool: var_symbols, var_symbols, ) - n_congituous_read += stride_vars[-1] == 1 - if n_congituous_read > 0: - return True - return False + + # stride==0 means a broadcast + if not (stride_vars[-1] == 0 or stride_vars[-1] == 1): + return False + return True @dataclasses.dataclass From bf8297afe0b2114f0658a94fcc900a0c5f4cb9f8 Mon Sep 17 00:00:00 2001 From: shunting314 Date: Thu, 6 Nov 2025 10:54:14 -0800 Subject: [PATCH 182/651] [inductor] let mix-order-red tune XBLOCK and num-stages (#167161) A few improvements for autotuning - while testing mix order reduction for internal workloads, Paul found that tuning num-stages could be very helpful for triton kernel. The idea is illustrated on his diff: https://www.internalfb.com/diff/D86341591 - when rnumel is small, larger XBLOCK could be helpful for perf This PR adds the ability to autotune num-stages and XBLOCK. This brings further 19% speedup for RMSNorm BWD on B200. Testing result: eager 11 data points compiled 11 data points, 17.07x speedup (was 14.39x before the PR. 
The PR brings further 19% speedup) quack 11 data points, 12.72x speedup liger 11 data points, 11.75x speedup compiled-no-fusion 11 data points, 9.93x speedup RMSNormBackward_bench Pull Request resolved: https://github.com/pytorch/pytorch/pull/167161 Approved by: https://github.com/jansel ghstack dependencies: #166669, #166938 --- test/inductor/test_mix_order_reduction.py | 61 ++++++++++++++++++- torch/_inductor/codegen/simd.py | 11 ++++ torch/_inductor/codegen/triton.py | 19 ++++-- torch/_inductor/config.py | 1 + .../runtime/coordinate_descent_tuner.py | 36 ++++++++++- torch/_inductor/runtime/triton_heuristics.py | 21 +++---- 6 files changed, 127 insertions(+), 22 deletions(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index d7c5d886f1a2f..592e42ce41735 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -117,6 +117,39 @@ def outer_red(): metrics.codegen_mix_order_reduction, ) + def test_xmask(self): + """ + Make sure xmask is setup properly + """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x): + return x.sum(dim=0), x.sum(dim=1) + + M, N = 32768 + 1023, 768 + EXTRA_ROW = 1 + buf = torch.randn(M + EXTRA_ROW, N, device=GPU_TYPE) + x = buf[:M, :] + # make sure wrong xmask error loud if read excess elements + buf[M:, :] = 1000000 + + opt_f = torch.compile( + f, + options={ + "triton.mix_order_reduction_initial_xblock": 2, + }, + ) + + ref = f(x) + act = opt_f(x) + + self.assertTrue(same(ref, act, tol=1e-3), f"ref:\n{ref}\nact:\n{act}") + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + def test_avoid_non_coalesced_access(self): if not inductor_config.triton.mix_order_reduction: self.skipTest("Mix order reduction not enabled") @@ -237,8 +270,23 @@ def f(x, y): ], ) @parametrize("split_reductions", (False, True)) - @parametrize("shape", ((32768, 2048), (32768, 768), (32769, 768))) - def test_rms_norm_bwd(self, wdtype, split_reductions, shape): + @parametrize("shape", ((32768, 2048), (32768, 768), (32768 + 1023, 768))) + @parametrize("max_autotune", (False, True)) + @parametrize("initial_xblock", (1, 2)) + def test_rms_norm_bwd( + self, wdtype, split_reductions, shape, max_autotune, initial_xblock + ): + # max_autotune can be slow and cost resource, trim down the tests + # for max autotune + if max_autotune and not ( + wdtype == torch.bfloat16 + and not split_reductions + and shape in ((32768, 768), (32769, 768)) + and initial_xblock == 1 + and inductor_config.triton.mix_order_reduction + ): + self.skipTest("Skip non-critical tests to save resources.") + def f(x, w, eps): orig_dtype = x.dtype @@ -267,6 +315,15 @@ def fwd_bwd(f): f, options={ "split_reductions": split_reductions, + "triton.mix_order_reduction_initial_xblock": initial_xblock, + **( + { + "max_autotune": True, + "coordinate_descent_tuning": True, + } + if max_autotune + else {} + ), }, ) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 14bf46db4a18c..f3b5de1f0ab46 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -594,6 +594,17 @@ def dense_size_list(self) -> list[str]: sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" return sizes + def create_constant_mask(self, entry) -> str: + x = entry.prefix + if entry.tensor_dim is None: + sizestr = self.dense_size_str() + return f"{x}mask = tl.full({sizestr}, True, tl.int1)" + 
sizes = ["None"] * self.triton_tensor_ndim() + sizes[entry.tensor_dim] = ":" + suffix = ", ".join(sizes) + out = f"{x}mask = tl.full([{x.upper()}BLOCK], True, tl.int1)[{suffix}]" + return out + def dense_size_str(self) -> str: sizes = self.dense_size_list() return f"[{', '.join(sizes)}]" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 8426f46887d06..4ac481478196a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -4571,14 +4571,20 @@ def codegen_body(self): ) accumname2var[name] = self.cse.namedvar(name, dtype=torch.float) self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)") - self.body.writeline("for _ in range(0, split_size, XBLOCK):") + self.body.writeline( + "for _ in tl.range(0, split_size, XBLOCK, num_stages=NUM_STAGES):" + ) with self.body.indent(offset=1): + # generate xmask if it's not constant + if not self._has_constant_xmask(): + entry = self.range_trees[0] + assert entry.prefix == "x" + x = entry.prefix + self.body.writeline(f"{x}mask = {entry.name} < {x}numel") self.body.splice(self.indexing_code) self.body.writelines( [ "xindex += XBLOCK", - # TODO we force XBLOCK==1 for now so there is - # no need to update the xmask ] ) self.body.splice(self.loads) @@ -5038,6 +5044,7 @@ def add_constexpr_arg(arg_name): if self.mix_order_reduction: add_constexpr_arg("RSPLIT_SIZE") + add_constexpr_arg("NUM_STAGES") triton_meta_signature = signature_to_meta( signature, size_dtype=self.index_dtype, argdefs=argdefs @@ -5586,9 +5593,9 @@ def iteration_ranges_codegen_header( ] ) if self._has_constant_mask(entry): - sizes = self.dense_size_str() - code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)") - else: + code.writeline(self.create_constant_mask(entry)) + elif not (x == "x" and self.mix_order_reduction): + # mix order reduction should generate xmask inside the loop code.writeline(f"{x}mask = {entry.name} < {x}numel") diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 8aaea94da1266..bfa854b37030d 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1578,6 +1578,7 @@ class triton: os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0" if is_fbcode() else "1") == "1" ) + mix_order_reduction_initial_xblock = 1 mix_order_reduction_split_size: Optional[int] = None mix_order_reduction_autotune_split_size = ( diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 7ea22bdcddf0b..36bd64cbae280 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -53,6 +53,7 @@ def __init__( self, is_mm=False, is_native_matmul=False, + is_mix_order_reduction=False, name="unknown", size_hints=None, inductor_meta=None, @@ -65,6 +66,7 @@ def __init__( # tl.dot also does not support size smaller than 16; we put this restriction. self.is_native_matmul = is_native_matmul assert not (self.is_mm and self.is_native_matmul) + self.is_mix_order_reduction = is_mix_order_reduction self.cached_benchmark_results = {} self.name = name self.size_hints = size_hints @@ -123,6 +125,12 @@ def tunable_fields(self): out.append("num_stages") out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul + if self.is_mix_order_reduction: + # unlike TritonConfig.num_stages, this one is + # put in TritonConfig.kwargs["NUM_STAGES"] and is used to + # control the stage of pipelining of tl.range. 
+ out.append("NUM_STAGES") + return [f for f in out if f not in self.frozen_fields] def value_too_large(self, name: str, val: int) -> bool: @@ -146,15 +154,23 @@ def value_too_small(self, name: str, val: int) -> bool: # Break if value becomes 0/neg return val <= 0 - def get_neighbour_values(self, name, orig_val, radius=1, include_self=False): + def get_neighbour_values(self, name, orig_val, radius=None, include_self=False): """ Get neighbour values in 'radius' steps. The original value is not returned as it's own neighbour. """ + if radius is None: + radius = 1 + if name == "NUM_STAGES": + # we see cases that + # NUM_STAGES=1 is better than NUM_STAGES=2 + # while NUM_STAGES=1 is worse than NUM_STAGES=3 + radius = max(radius, 2) + assert radius >= 1 def update(cur_val, inc=True): - if name == "num_stages": + if name in ["num_stages", "NUM_STAGES"]: if inc: return cur_val + 1 else: @@ -191,6 +207,15 @@ def has_improvement(baseline, test): threshold = 0.001 # 0.1% return test is not None and test < baseline * (1 - threshold) + def is_valid_config(self, config) -> bool: + if self.is_mix_order_reduction: + # Mix order reduction has an extra constraint that + # we should not tune XBLOCK beyond RSPLIT_SIZE + xblock = config.kwargs["XBLOCK"] + split_size = config.kwargs["RSPLIT_SIZE"] + return xblock <= split_size + return True + def check_all_tuning_directions( self, # pyrefly: ignore [missing-attribute] @@ -209,10 +234,11 @@ def check_all_tuning_directions( old_value = get_field(best_config, field) if old_value is None: continue + radius = self.inductor_meta.get("coordinate_descent_search_radius", 1) candidate_values = self.get_neighbour_values( field, old_value, - radius=self.inductor_meta.get("coordinate_descent_search_radius", 1), + radius=radius, include_self=True, ) candidate_values_list.append(candidate_values) @@ -225,6 +251,8 @@ def check_all_tuning_directions( candidate_config = copy.deepcopy(best_config) for new_val, field in zip(choice, effective_fields): set_field(candidate_config, field, new_val) + if not self.is_valid_config(candidate_config): + continue cmp_res, candidate_timing = self.compare_config( func, candidate_config, best_config, best_timing ) @@ -302,6 +330,8 @@ def autotune( candidate_config = copy.deepcopy(best_config) set_field(candidate_config, name, next_val) + if not self.is_valid_config(candidate_config): + continue cmp_res, candidate_timing = self.compare_config( func, candidate_config, best_config, best_timing ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 363d62d02303d..c2709073a64c1 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -330,13 +330,14 @@ def __init__( log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"]) self.size_hints = size_hints + self.is_mix_order_reduction = self.inductor_meta.get("RSPLIT_SIZE") is not None self.coordesc_tuner = CoordescTuner( is_mm=False, is_native_matmul=triton_meta.get("native_matmul", False), + is_mix_order_reduction=self.is_mix_order_reduction, name=self.fn.__name__, size_hints=size_hints, inductor_meta=self.inductor_meta, - frozen_fields=self.get_coordesc_frozen_fields(), ) self.filename = filename @@ -366,13 +367,6 @@ def __init__( # Mode for launch grid calculation self.grid_mode: Literal["python", "cpp"] = "python" - def get_coordesc_frozen_fields(self) -> OrderedSet[str]: - out: OrderedSet[str] = OrderedSet() - if self.inductor_meta.get("RSPLIT_SIZE"): - # We fix XBLOCK for mix 
order reduction - out.add("XBLOCK") - return out - def is_statically_launchable(self): """ Checks if every compiled kernel is statically launchable, which @@ -3421,8 +3415,12 @@ def persistent_reduction( for c in configs: c.kwargs["RSPLIT_SIZE"] = inductor_meta.get("RSPLIT_SIZE") + c.kwargs["NUM_STAGES"] = 1 + # small XBLOCK to use less registers/smem - c.kwargs["XBLOCK"] = 1 + c.kwargs["XBLOCK"] = ( + torch._inductor.config.triton.mix_order_reduction_initial_xblock + ) rnumel_hint = size_hints["r0_"] @@ -3760,8 +3758,9 @@ class MixOrderReductionGrid(GridExpr): def generate(self, meta: dict[str, int]) -> None: split_size = meta.get("RSPLIT_SIZE") xblock = meta.get("XBLOCK") - assert split_size - assert xblock == 1, "Mix order reduction force XBLOCK=1 right now" + assert split_size, "Missing RSPLIT_SIZE" + assert xblock, "Missing XBLOCK" + assert split_size % xblock == 0, f"{split_size=}, {xblock=}" self.x_grid = self.ceildiv("xnumel", split_size) From 91b626e2ef549b0b2995be7f7502c2632414b687 Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 6 Nov 2025 14:07:59 -0800 Subject: [PATCH 183/651] [dynamo] unimplemented -> unimplemented_v2 for the rest of variables/misc.py (#167001) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167001 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos --- test/dynamo/test_reorder_logs.py | 2 +- torch/_dynamo/graph_break_registry.json | 122 ++++++++++++++++++++++ torch/_dynamo/variables/misc.py | 130 +++++++++++++++++++----- 3 files changed, 225 insertions(+), 29 deletions(-) diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index be6bf8085af27..a147b216e7703 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -67,7 +67,7 @@ def test_ignore_logger(self, ignore_method, fn, should_ignore_logger): self.assertEqual(len(counters["graph_break"]), 0) else: self.assertIn("moo", printed_output) - self.assertEqual(len(counters["graph_break"]), 1) + self.assertGreater(len(counters["graph_break"]), 0) class ReorderLogsTests(torch._dynamo.test_case.TestCase): diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index b21d81910abb1..a37723614a8b7 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2950,5 +2950,127 @@ "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." ] } + ], + "GB0289": [ + { + "Gb_type": "unsupported method call on `typing` variable", + "Context": "typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}", + "Explanation": "`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.", + "Hints": [ + "Avoid calling the {name} method on {self.value}.", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0290": [ + { + "Gb_type": "attempted to trace numpy.* function as a method", + "Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + "Explanation": "Tracing numpy.* functions as methods is not supported.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." 
+ ] + } + ], + "GB0291": [ + { + "Gb_type": "logging.Logger method not supported for non-export cases", + "Context": "method: {self.value}.{name}, args: {args}, kwargs: {kwargs}", + "Explanation": "logging.Logger methods are not supported for non-export cases.", + "Hints": [ + "Add the logging method to `torch._dynamo.config.ignore_logger_methods`." + ] + } + ], + "GB0292": [ + { + "Gb_type": "constant-like method call with unsupported return type", + "Context": "{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}", + "Explanation": "Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0293": [ + { + "Gb_type": "attempted to trace numpy function with config.trace_numpy=False", + "Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + "Explanation": "Attempted to trace numpy function {self.value} while `torch._dynamo.config.trace_numpy` was set to False.", + "Hints": [ + "Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions." + ] + } + ], + "GB0294": [ + { + "Gb_type": "attempted to trace numpy function unsupported by PyTorch", + "Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + "Explanation": "Can't find numpy function {self.value} in torch._numpy.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0295": [ + { + "Gb_type": "cannot reconstruct NullVariable in Python < 3.11", + "Context": "", + "Explanation": "Attempted to generate PUSH_NULL instruction in Python < 3.11; where this instruction does not exist.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0296": [ + { + "Gb_type": "attempted to reorder a debugging function that can't actually be reordered", + "Context": "fn: {self.value}, args: {args}, kwargs: {kwargs}", + "Explanation": "`torch.compile` can only reorder functions where the arguments are Tensors, constants, or string formatters.", + "Hints": [ + "Avoid calling the logging function {self.value} with args that are not supported." + ] + } + ], + "GB0297": [ + { + "Gb_type": "random.Random() with improper arguments", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "random.Random() with > 1 arg or with kwargs is not supported.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0298": [ + { + "Gb_type": "attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True", + "Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + "Explanation": "Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` is set to True.", + "Hints": [ + "Set `torch._dynamo.config.use_numpy_random_stream` to False.", + "Avoid calling {self.value}."
+ ] + } + ], + "GB0299": [ + { + "Gb_type": "constant-like method call with non-constant args", + "Context": "{self._error_prefix}.{name}(*{args}, **{kwargs})", + "Explanation": "Attempted to call {self._error_prefix}.{name} with non-constant args.", + "Hints": [ + "Ensure that the args to the method call are constant (int, str, etc.)." + ] + } + ], + "GB0300": [ + { + "Gb_type": "numpy function that produces a const collection type encountered non-const arguments", + "Context": "numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + "Explanation": "numpy function {self.value} that produces a const collection type (e.g. np.dtype, np.iinfo/np.finfo) received arguments that are not constant.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } ] } diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 4845d5d9acc93..7942e2fbd7bfa 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -39,7 +39,7 @@ create_instruction, ) from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..mutation_guard import unpatched_nn_module_init from ..source import ( @@ -1382,7 +1382,15 @@ def call_method( if name == "__getitem__" and len(args) == 1: new_typing = self.value[args[0].as_python_constant()] return TypingVariable(new_typing) - unimplemented("unsupported method call on typing variable") + unimplemented_v2( + gb_type="unsupported method call on `typing` variable", + context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}", + explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.", + hints=[ + f"Avoid calling the {name} method on {self.value}.", + *graph_break_hints.SUPPORTABLE, + ], + ) def var_getattr(self, tx: "InstructionTranslator", name: str): from .builder import SourcelessBuilder, VariableBuilder @@ -1493,16 +1501,28 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if not config.trace_numpy: - unimplemented(f"numpy.{self.value}()") + unimplemented_v2( + gb_type="attempted to trace numpy function with config.trace_numpy=False", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + explanation=f"Attempted to trace numpy function {self.value} " + "while `torch._dynamo.config.trace_numpy` was set to False.", + hints=[ + "Set `torch._dynamo.config.trace_numpy` to True to trace numpy functions.", + ], + ) from ..utils import numpy_to_tensor_wrapper from .tensor import NumpyNdarrayVariable func = get_np_to_tnp_map().get(self.value) if func is None: - unimplemented( - f"Can't find numpy function {self.value} in torch._numpy. " - " Please file an issue to request support for this function." 
+ unimplemented_v2( + gb_type="attempted to trace numpy function unsupported by PyTorch", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], ) # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo) @@ -1516,20 +1536,32 @@ def call_function( **{k: v.as_python_constant() for k, v in kwargs.items()}, ) ) - except NotImplementedError: - unimplemented( - f"{self.value.__name__} with non-const args: {args} {kwargs}" + except AsPythonConstantNotImplementedError: + unimplemented_v2( + gb_type="numpy function that produces a const collection type encountered non-const arguments", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"numpy function {self.value} that produces a const collection type " + "(e.g. np.dtype, np.iinfo/np.finfo) " + "received arguments that are not constant.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) else: if ( func.__module__ == "torch._numpy.random" and config.use_numpy_random_stream ): - msg = f"delegate '{func.__qualname__}' to NumPy itself via " - msg += ( - f"config.use_numpy_random_stream={config.use_numpy_random_stream}" + unimplemented_v2( + gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", + explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` " + "is set to True.", + hints=[ + "Set `torch._dynamo.config.use_numpy_random_stream` to False.", + f"Avoid calling {self.value}.", + ], ) - unimplemented(msg) args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs) @@ -1559,7 +1591,14 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - unimplemented("numpy") + unimplemented_v2( + gb_type="attempted to trace numpy.* function as a method", + context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", + explanation="Tracing numpy.* functions as methods is not supported.", + hints=[ + *graph_break_hints.DIFFICULT, + ], + ) def as_python_constant(self): return self.value @@ -1584,7 +1623,15 @@ def __repr__(self) -> str: def reconstruct(self, codegen: "PyCodegen"): if sys.version_info < (3, 11): - unimplemented("cannot reconstruct NullVariable in < Python 3.11") + unimplemented_v2( + gb_type="cannot reconstruct NullVariable in Python < 3.11", + context="", + explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; " + "where this instruction does not exist.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) codegen.append_output(create_instruction("PUSH_NULL")) @@ -1665,9 +1712,14 @@ def call_function(self, tx: "InstructionTranslator", args, kwargs): return if not self.can_reorder_logs(self.value, args, kwargs): - unimplemented( - f"Reordering debugging function {self.value} " - f"with inputs {args} {kwargs} is not yet implemented." 
+ unimplemented_v2( + gb_type="attempted to reorder a debugging function that can't actually be reordered", + context=f"fn: {self.value}, args: {args}, kwargs: {kwargs}", + explanation="`torch.compile` can only reorder functions where the arguments " + "are Tensors, constants, or string formatters.", + hints=[ + f"Avoid calling the logging function {self.value} with args that are not supported.", + ], ) tx.debug_locals.append((self, list(args))) @@ -1719,10 +1771,13 @@ def call_method( function = getattr(method, "__func__", None) if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): return variables.ConstantVariable.create(None) - unimplemented( - "Logger not supported for non-export cases. " - "To avoid graph breaks caused by logger in compile-mode, it is recommended to" - " disable logging by adding logging methods to config.ignore_logger_methods" + unimplemented_v2( + gb_type="logging.Logger method not supported for non-export cases", + context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}", + explanation="logging.Logger methods are not supported for non-export cases.", + hints=[ + "Add the logging method to `torch._dynamo.config.ignore_logger_methods.", + ], ) @@ -1759,7 +1814,14 @@ def call_method( cargs = [x.as_python_constant() for x in args] ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} except NotImplementedError: - unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})") + unimplemented_v2( + gb_type="constant-like method call with non-constant args", + context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})", + explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.", + hints=[ + "Ensure that the args to the method call are constant (int, str, etc.).", + ], + ) result = getattr(self.value, name)(*cargs, **ckwargs) @@ -1768,7 +1830,14 @@ def call_method( if isinstance(result, re.Match): return ConstantRegexMatchVariable(result) - unimplemented(f"{self._error_prefix}.{name}() -> {result}") + unimplemented_v2( + gb_type="constant-like method call with unsupported return type", + context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}", + explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: result = getattr(self.value, name) @@ -1831,10 +1900,15 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) def call_function(self, tx: "InstructionTranslator", args, kwargs): - if len(args) > 1: - unimplemented("random.Random() with > 1 arg") - elif kwargs: - unimplemented("random.Random() with kwargs") + if len(args) > 1 or kwargs: + unimplemented_v2( + gb_type="random.Random() with improper arguments", + context=f"args: {args}, kwargs: {kwargs}", + explanation="random.Random() with > 1 arg or with kwargs is not supported.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] return RandomVariable( seed=seed, mutation_type=variables.base.ValueMutationNew() From 643b3bc8f3cb97340404101b8d975a53ad653c1d Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 6 Nov 2025 14:51:19 -0800 Subject: [PATCH 184/651] [dynamo] unimplemented -> unimplemented_v2 in variables/higher_order_ops.py (#167146) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167146 Approved by: https://github.com/Lucaskabela, 
https://github.com/mlazos ghstack dependencies: #167001 --- test/dynamo/test_higher_order_ops.py | 11 +- torch/_dynamo/graph_break_registry.json | 503 ++++++++++++++++++ torch/_dynamo/variables/higher_order_ops.py | 535 ++++++++++++++------ 3 files changed, 899 insertions(+), 150 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 204e5114320f6..2c348ffe388f0 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -3354,7 +3354,7 @@ def outer_body_fn(x, y): x = torch.randn(2, 4) y = torch.ones(4) - msg = "hints_wrapper - key hints not provided" + msg = "hints_wrapper: improper args/kwargs" with self.assertRaisesRegex(RuntimeError, msg): torch.compile(fn_with_hints, backend=cnt)(x, y) @@ -4516,12 +4516,9 @@ def wrapper_fn(model, params, inputs, targets): model, params, inputs, targets ) self.assertEqual(len(counters["graph_break"]), 1) - self.assertEqual( - { - "torch.func.functional_call capture is disabled, it can be " - "turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1, - }, - dict(counters["graph_break"]), + self.assertIn( + "torch.func.functional_call capture is disabled", + next(iter(counters["graph_break"].keys())), ) self.assertEqual(actual, expected) diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index a37723614a8b7..638487e417e63 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3072,5 +3072,508 @@ "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." ] } + ], + "GB0301": [ + { + "Gb_type": "HOP: non torch.Tensor leaf", + "Context": "args types: {[type(a.realize()) for a in args]}", + "Explanation": "Expected all leaves to be of torch.Tensor type.", + "Hints": [] + } + ], + "GB0302": [ + { + "Gb_type": "HOP: non-callable variable", + "Context": "arg name: {arg_name}, func_var type: {str(func_var)}", + "Explanation": "{arg_name} should be a callable but is of type {str(func_var)}.", + "Hints": [] + } + ], + "GB0303": [ + { + "Gb_type": "torch.while_loop: improper args/kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "torch.while_loop expects 4 positional arguments (got {len(args)}) and no keyword arguments (got {len(kwargs)}) Usage: while_loop(cond_fn, body_fn, operands)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0304": [ + { + "Gb_type": "torch.while_loop: improper additional_inputs", + "Context": "str(additional_inputs)", + "Explanation": "Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0305": [ + { + "Gb_type": "invalid set_subgraph_inputs and sub_kwargs settings", + "Context": "set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", + "Explanation": "`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", + "Hints": [ + "Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.", + "Dynamo has detected that tracing the code will result in an error when running in eager. 
Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0306": [ + { + "Gb_type": "unsupported HigherOrderOperator", + "Context": "str(value)", + "Explanation": "Unable to create higher order operator variable for {value.__name__}.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0307": [ + { + "Gb_type": "unsupported HigherOrderOperator function call", + "Context": "str(self.value)", + "Explanation": "Unable to trace calling higher order operator variable for {self.value.__name__}.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0308": [ + { + "Gb_type": "torch.while_loop: unsupported cond_fn return type", + "Context": "str(cond_r)", + "Explanation": "Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0309": [ + { + "Gb_type": "torch.cond: improper args/kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "torch.cond expects 4 positional arguments (got {len(args)}) and no keyword arguments (got {len(kwargs)}) Usage: cond(pred, cond_fn, body_fn, operands)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0310": [ + { + "Gb_type": "torch.cond: improper predicate", + "Context": "str(pred)", + "Explanation": "Expected `pred` to be a bool or a boolean tensor with a single item but got {str(type(pred))} with original python type {str(pred.python_type())}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0311": [ + { + "Gb_type": "torch.cond: improper operands", + "Context": "str(operands)", + "Explanation": "Expected `operands` to be a list/tuple but got {operands.python_type()}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0312": [ + { + "Gb_type": "torch.cond: improper operands contents", + "Context": "str(operands)", + "Explanation": "Expected `operands` to be a list/tuple of pytrees that only consists of tensor leaves.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0313": [ + { + "Gb_type": "torch.cond: differing branch outputs", + "Context": "true_spec: {true_spec.treespec}, false_spec: {false_spec.treespec}, same_spec: {same_spec}", + "Explanation": "Expected branches to return the same pytree structure.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0314": [ + { + "Gb_type": "HOP body output unsupported", + "Context": "non-tensor outputs: {non_tensor_output}", + "Explanation": "HigherOrderOperator body's output must consist of tensors or ints/bools only but got {out.python_type()}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0315": [ + { + "Gb_type": "torch.associative_scan: improper xs", + "Context": "str(xs)", + "Explanation": "Expected xs to be a list/tuple but got {xs.python_type()}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0316": [ + { + "Gb_type": "torch.associative_scan: improper additional_inputs", + "Context": "str(additional_inputs)", + "Explanation": "Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0317": [ + { + "Gb_type": "torch.associative_scan: zero-sized tensor", + "Context": "str(xs_vars[0])", + "Explanation": "associative_scan() operator doesn't support zero-sized tensors during tracing.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0318": [ + { + "Gb_type": "torch.associative_scan: combine_fn improper number of leaves", + "Context": "str(_combine_treespec.as_python_constant())", + "Explanation": "combine_fn needs to produce one pytree for the output but combine_fn produces the pytree {_combine_treespec.as_python_constant()}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0319": [ + { + "Gb_type": "torch.associative_scan: mismatched input/output tree structure", + "Context": "xs: {xs_treespec.as_python_constant()}, output: {_combine_treespec.as_python_constant()}", + "Explanation": "The tree structure of the xs and the outs of the combine_fn are expected to be identical, but got xs: {xs_treespec.as_python_constant()} vs output: {_combine_treespec.as_python_constant()}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0320": [ + { + "Gb_type": "torch.scan: improper xs", + "Context": "str(xs)", + "Explanation": "Expected xs to be a list/tuple but got {xs.python_type()}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0321": [ + { + "Gb_type": "torch.scan: improper init", + "Context": "str(init)", + "Explanation": "Expected init to be a list/tuple with at least one element but got {init.python_type()}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0322": [ + { + "Gb_type": "torch.scan: no init leaves", + "Context": "", + "Explanation": "Expected init leaves.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch."
+ ] + } + ], + "GB0323": [ + { + "Gb_type": "torch.scan: improper additional_inputs", + "Context": "str(additional_inputs)", + "Explanation": "Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0324": [ + { + "Gb_type": "torch.scan: zero-sized tensor", + "Context": "str(xs_vars[0])", + "Explanation": "associative_scan() operator doesn't support zero-sized tensors during tracing.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0325": [ + { + "Gb_type": "torch.map: kwargs not supported", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "torch.map expects no keyword arguments (got {len(kwargs)})", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0326": [ + { + "Gb_type": "torch.map: improper inputs", + "Context": "str(sample_shape)", + "Explanation": "torch.map doesn't support scalar or non-zero sized tensors during tracing.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0327": [ + { + "Gb_type": "executorch_call_delegate: kwargs not supported", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "executorch_call_delegate expects no keyword arguments (got {len(kwargs)})", + "Hints": [] + } + ], + "GB0328": [ + { + "Gb_type": "torch.func.functional_call capture is disabled", + "Context": "", + "Explanation": "torch.func.functional_call capture is disabled", + "Hints": [ + "Set `torch._dynamo.config.inline_inbuilt_nn_modules=True` to enable." + ] + } + ], + "GB0329": [ + { + "Gb_type": "WrapHigherOrderVariable: kwargs unexpected", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "kwargs should have been flattened into lifted args.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0330": [ + { + "Gb_type": "wrap_with_set_grad_enabled: unexpected kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "wrap_with_set_grad_enabled expects no keyword arguments (got {len(kwargs)}).", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0331": [ + { + "Gb_type": "wrap_with_set_grad_enabled: non-constant grad_enabled", + "Context": "str(grad_enabled)", + "Explanation": "wrap_with_set_grad_enabled expects grad_enabled argument to be a constant.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0332": [ + { + "Gb_type": "wrap_with_set_grad_enabled: unexpected freevars", + "Context": "str(body_lifted_freevars)", + "Explanation": "wrap_with_set_grad_enabled expects no freevars.", + "Hints": [] + } + ], + "GB0333": [ + { + "Gb_type": "wrap_with_autocast: unexpected kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "wrap_with_autocast expects no keyword arguments (got {len(kwargs)}).", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." 
+ ] + } + ], + "GB0334": [ + { + "Gb_type": "wrap_with_autocast: unexpected freevars", + "Context": "str(body_lifted_freevars)", + "Explanation": "wrap_with_autocast expects no freevars.", + "Hints": [] + } + ], + "GB0335": [ + { + "Gb_type": "hints_wrapper: improper args/kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "hints_wrapper expects 3 positional arguments (got {len(args)}) and 1 keyword argument (got {len(kwargs)}). Usage: hints_wrapper(body_fn, args, kwargs, hints=...). args is expected to be list/tuple and kwargs is expected to be a dict.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0336": [ + { + "Gb_type": "out_dtype: unexpected kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "out_dtype expects no keyword arguments (got {len(kwargs)}).", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0337": [ + { + "Gb_type": "strict_mode: unexpected kwargs", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "strict_mode higher order op expects no keyword arguments (got {len(kwargs)}).", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0338": [ + { + "Gb_type": "invoke_subgraph: kwargs unexpected", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "kwargs should have been flattened into lifted args.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0339": [ + { + "Gb_type": "torch.while_loop: infinite loop detected", + "Context": "str(cond_r)", + "Explanation": "Infinite loop detected because while_loop's cond_fn always returns the same value {pred}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0340": [ + { + "Gb_type": "torch.cond: unsupported branch return type", + "Context": "str(ret_val)", + "Explanation": "Expected branches to return a possibly nested pytree of tensors or constant ints.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0341": [ + { + "Gb_type": "torch.associative_scan: improper args", + "Context": "args: {args}", + "Explanation": "torch.associative_scan expects 2 positional arguments (got {len(args)}) Usage: associative_scan(combine_fn, xs)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0342": [ + { + "Gb_type": "torch.scan: improper combine_fn", + "Context": "str(combine_fn_var)", + "Explanation": "Expected combine_fn to be wrapped as functools.partial in scan user-facing api or a graph module if we're re-exporting but got {combine_fn_var.python_type()}.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0343": [ + { + "Gb_type": "torch.scan: improper combine_fn number of returns", + "Context": "str(combine_result_vars)", + "Explanation": "Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0344": [ + { + "Gb_type": "wrap_with_autocast: expected constant arg", + "Context": "str(args)", + "Explanation": "wrap_with_autocast expects device_type, dtype, enabled, and cache_enabled arguments to be constants.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0345": [ + { + "Gb_type": "strict_mode: improper args", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "strict_mode higher order op expects flat inputs (list/tuple/dict)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0346": [ + { + "Gb_type": "autograd.Function.apply: non-function or method forward", + "Context": "str(self.fwd_graph)", + "Explanation": "Expected forward function to be a function or method.", + "Hints": [] + } + ], + "GB0347": [ + { + "Gb_type": "autograd.Function.apply: _materialize_non_diff_grads mutation", + "Context": "", + "Explanation": "Mutations to autograd.Function.ctx._materialize_non_diff_grads are not supported.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0348": [ + { + "Gb_type": "autograd.Function.apply: non-function or method backward", + "Context": "str(self.bwd_graph)", + "Explanation": "Expected backward function to be a function or method.", + "Hints": [] + } + ], + "GB0349": [ + { + "Gb_type": "cannot unwrap variable for check_meta_consistency", + "Context": "str(var)", + "Explanation": "Expected {var} to be TensorVariable, SymNodeVariable, or ConstantVariable", + "Hints": [] + } + ], + "GB0350": [ + { + "Gb_type": "torch.cond: unsupported branch return type (constant non-int)", + "Context": "str(ret_val)", + "Explanation": "Constants returned from branches must be ints.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0351": [ + { + "Gb_type": "HOP body taking non-Tensor as input", + "Context": "str(sub_args)", + "Explanation": "{description} with body that accepts non-Tensors as input. Got type {a.python_type()} at index {idx}.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. 
Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0352": [ + { + "Gb_type": "autograd.Function.apply: non-function or method backward (2)", + "Context": "str(self.bwd_graph)", + "Explanation": "Expected backward function to be a function or method.", + "Hints": [] + } ] } diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index c330a700fd66b..3f084cd00f59c 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -48,10 +48,8 @@ from .. import graph_break_hints, variables from ..exc import ( - IncorrectUsage, ObservedException, UncapturedHigherOrderOpError, - unimplemented, unimplemented_v2, Unsupported, ) @@ -163,7 +161,12 @@ def _unwrap_var(var): elif isinstance(var, ConstantVariable): return var.as_python_constant() else: - unimplemented(f"Cannot unwrap var {var}") + unimplemented_v2( + gb_type="cannot unwrap variable for check_meta_consistency", + context=str(var), + explanation=f"Expected {var} to be TensorVariable, SymNodeVariable, or ConstantVariable", + hints=[], + ) unwrapped1 = [_unwrap_var(var) for var in vars1] unwrapped2 = [_unwrap_var(var) for var in vars2] @@ -295,8 +298,11 @@ def _check_all_tensorvariable(args): from . import TensorVariable if not all(type(a.realize()) is TensorVariable for a in args): - unimplemented( - f"Expected all leaves to be of torch.Tensor type, but got {[type(a.realize()) for a in args]}." + unimplemented_v2( + gb_type="HOP: non torch.Tensor leaf", + context=f"args types: {[type(a.realize()) for a in args]}", + explanation="Expected all leaves to be of torch.Tensor type.", + hints=[], ) @@ -307,8 +313,11 @@ def _check_supported_callable_arg( BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant() ) if not is_callable: - unimplemented( - f"{arg_name} should be a Callable but is of type {str(func_var)}." + unimplemented_v2( + gb_type="HOP: non-callable variable", + context=f"arg name: {arg_name}, func_var type: {str(func_var)}", + explanation=f"{arg_name} should be a callable but is of type {str(func_var)}.", + hints=[], ) @@ -334,13 +343,16 @@ def _call_while_loop( ) args.append(v) - if kwargs: - unimplemented(f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}") - - if len(args) != 4: - unimplemented( - f"Expected 4 arguments but got {len(args)}.\n" - f"Usage: while_loop(cond_fn, body_fn, operands)", + if kwargs or len(args) != 4: + unimplemented_v2( + gb_type="torch.while_loop: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.while_loop expects 4 positional arguments (got {len(args)}) " + f"and no keyword arguments (got {len(kwargs)}) " + "Usage: while_loop(cond_fn, body_fn, operands)", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # cond_fn and body_fn input check @@ -352,10 +364,13 @@ def _call_while_loop( # additional_inputs input check if not isinstance(additional_inputs, (ListVariable, TupleVariable)): - unimplemented( - f"Expected additional_inputs to be a list/tuple but got " - f"{additional_inputs.python_type()}. It seems to be an " - f"internal error, please report an issue to PyTorch." 
+ unimplemented_v2( + gb_type="torch.while_loop: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) @@ -454,15 +469,25 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: cond_r.proxy.node.meta["example_value"], include_contiguity=False ) if cond_r_meta.dtype != torch.bool or cond_r_meta.shape != torch.Size([]): - unimplemented( - f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}" + unimplemented_v2( + gb_type="torch.while_loop: unsupported cond_fn return type", + context=str(cond_r), + explanation=f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) elif isinstance(cond_r, ConstantVariable): # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False pred = cond_r.as_python_constant() if pred: - unimplemented( - f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}" + unimplemented_v2( + gb_type="torch.while_loop: infinite loop detected", + context=str(cond_r), + explanation=f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) else: return operands @@ -770,9 +795,14 @@ def validate_args_and_maybe_create_graph_inputs( # If `a` cannot be put into a graph else: # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic"). - unimplemented( - f"{description} with body that accepts non-Tensors as input. " - f"Got: {a.python_type()}" + unimplemented_v2( + gb_type="HOP body taking non-Tensor as input", + context=str(sub_args), + explanation=f"{description} with body that accepts non-Tensors as input. 
" + f"Got type {a.python_type()} at index {idx}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) args.append(new_arg) return args @@ -928,7 +958,15 @@ def speculate_subgraph( # See NOTE [Temporary argument `set_subgraph_inputs`] if sub_kwargs and set_subgraph_inputs != "automatic": - unimplemented("Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.") + unimplemented_v2( + gb_type="invalid set_subgraph_inputs and sub_kwargs settings", + context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", + explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", + hints=[ + "Use `set_subgraph_inputs='automatic'` when passing `sub_kwargs`.", + *graph_break_hints.USER_ERROR, + ], + ) try: # ensure guards on args get installed in parent subgraph @@ -1215,7 +1253,14 @@ def make(value, source=None, **kwargs): if isinstance(value, BaseHOP): return BaseHOPVariable(value, source, **kwargs) - unimplemented(f"HigherOrderOperator {value.__name__}") + unimplemented_v2( + gb_type="unsupported HigherOrderOperator", + context=str(value), + explanation=f"Unable to create higher order operator variable for {value.__name__}.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) def call_function( self, @@ -1236,7 +1281,14 @@ def _call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - unimplemented(f"HigherOrderOperator {self.value.__name__}") + unimplemented_v2( + gb_type="unsupported HigherOrderOperator function call", + context=str(self.value), + explanation=f"Unable to trace calling higher order operator variable for {self.value.__name__}.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) def as_python_constant(self): return self.value @@ -1286,15 +1338,18 @@ def _call_function( ) args.append(v) - if kwargs: - unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}") - # TODO(voz): Support fake tensor dispatch for recursive # ops - see torch/dispatch/_dispatcher.py - if len(args) != 4: - unimplemented( - f"Expected 4 arguments but got {len(args)}.\n" - f"Usage: cond(pred, true_fn, false_fn, operands)", + if len(args) != 4 or kwargs: + unimplemented_v2( + gb_type="torch.cond: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) " + f"and no keyword arguments (got {len(kwargs)}) " + "Usage: cond(pred, cond_fn, body_fn, operands)", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # Specialize into one of the branches since pred is constant @@ -1312,24 +1367,39 @@ def _call_function( # predicate if type(pred) not in (ConstantVariable, TensorVariable, SymNodeVariable): - unimplemented( - f"Expected pred to be bool or a boolean tensor with single " - f"item but got {str(type(pred))} " - f"with original python type {str(pred.python_type())}.", + unimplemented_v2( + gb_type="torch.cond: improper predicate", + context=str(pred), + explanation="Expected `pred` to be a bool or a boolean tensor with a single item " + f"but got {str(type(pred))} with original python type {str(pred.python_type())}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # operands if not isinstance(operands, (ListVariable, TupleVariable)): - unimplemented( - f"Expected operands to be a list/tuple but got " - f"{operands.python_type()}", + unimplemented_v2( + gb_type="torch.cond: improper operands", + context=str(operands), + explanation="Expected `operands` to be a list/tuple " + f"but got 
{operands.python_type()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) + operands_seq = operands.unpack_var_sequence(tx) if not only_consist_of( operands, (TensorVariable, ConstantVariable, SymNodeVariable) ): - unimplemented( - "Expect operands to be a tuple of pytrees that only consists of tensor leaves." + unimplemented_v2( + gb_type="torch.cond: improper operands contents", + context=str(operands), + explanation="Expected `operands` to be a list/tuple of pytrees that only consists of tensor leaves.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # branches @@ -1377,15 +1447,23 @@ def speculate_branch(branch): tx.fake_mode.epoch += 1 if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)): - unimplemented( - "Expected branches to return a possibly nested pytree of tensors " - "or constant ints but it consists of others.", + unimplemented_v2( + gb_type="torch.cond: unsupported branch return type", + context=str(ret_val), + explanation="Expected branches to return a possibly nested pytree of tensors or constant ints.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) for ret in ret_val.unpack_var_sequence(tx): if isinstance(ret, ConstantVariable) and ret.python_type() is not int: - unimplemented( - "Expected branches to return a possibly nested pytree of tensors " - f"or constant ints but it consists of others {ret.python_type()}.", + unimplemented_v2( + gb_type="torch.cond: unsupported branch return type (constant non-int)", + context=str(ret_val), + explanation="Constants returned from branches must be ints.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) return ret_val, ret_spec, ret_graph, ret_lifted_freevars @@ -1405,7 +1483,14 @@ def speculate_branch(branch): ).as_python_constant() # 3.14: NotImplemented cannot be converted to bool if same_spec is not NotImplemented and not same_spec: - unimplemented("Expected branches to return the same pytree structure.") + unimplemented_v2( + gb_type="torch.cond: differing branch outputs", + context=f"true_spec: {true_spec.treespec}, false_spec: {false_spec.treespec}, same_spec: {same_spec}", + explanation="Expected branches to return the same pytree structure.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) ( true_graph, @@ -1500,8 +1585,14 @@ def validate_subgraph_output_types(output: VariableTracker): isinstance(out, ConstantVariable) and out.python_type() in (int, bool) ): continue - unimplemented( - f"HigherOrderOperator body's output must consist of tensors or ints only but got {out.python_type()}" + unimplemented_v2( + gb_type="HOP body output unsupported", + context=f"non-tensor outputs: {non_tensor_output}", + explanation="HigherOrderOperator body's output must consist of tensors or ints/bools only " + f"but got {out.python_type()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) @@ -1563,9 +1654,14 @@ def arg_extractor(combine_fn, xs, additional_inputs): # This is the standard case when the user calls the frontend # and the frontend invokes dynamo if len(args) != 2: - unimplemented( - f"Expected 2 positional arguments but got {len(args)}.\n" - f"Usage: associative_scan(combine_fn, xs)", + unimplemented_v2( + gb_type="torch.associative_scan: improper args", + context=f"args: {args}", + explanation=f"torch.associative_scan expects 2 positional arguments (got {len(args)}) " + "Usage: associative_scan(combine_fn, xs)", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) xs_treespec = args[0].keywords["spec"] @@ -1584,28 +1680,39 @@ def arg_extractor(combine_fn, xs, additional_inputs): # xs input check if 
not isinstance(xs, (ListVariable, TupleVariable)): - unimplemented( - f"Expected xs to be a list/tuple but got " - f"{xs.python_type()}. It seems to be an " - f"internal error, please report an issue to PyTorch." + unimplemented_v2( + gb_type="torch.associative_scan: improper xs", + context=str(xs), + explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) xs_vars = xs.unpack_var_sequence(tx) _check_all_tensorvariable(xs_vars) # additional_inputs input check if not isinstance(additional_inputs, (ListVariable, TupleVariable)): - unimplemented( - f"Expected additional_inputs to be a list/tuple but got " - f"{additional_inputs.python_type()}. It seems to be an " - f"internal error, please report an issue to PyTorch." + unimplemented_v2( + gb_type="torch.associative_scan: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) additional_inputs_vars = additional_inputs.unpack_var_sequence(tx) _check_all_tensorvariable(additional_inputs_vars) scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] if scan_length == 0: - unimplemented( - "associative_scan() operator doesn't support zero-sized tensors during tracing." + unimplemented_v2( + gb_type="torch.associative_scan: zero-sized tensor", + context=str(xs_vars[0]), + explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # Trace the subgraph @@ -1652,9 +1759,14 @@ def arg_extractor(combine_fn, xs, additional_inputs): # Check whether the combine_fn returns one child tree for the output. if _combine_treespec.as_python_constant().num_leaves < 1: - unimplemented( - f"combine_fn needs to produce one pytree for the output " - f"but combine_fn produces the pytree {_combine_treespec.as_python_constant()}." + unimplemented_v2( + gb_type="torch.associative_scan: combine_fn improper number of leaves", + context=str(_combine_treespec.as_python_constant()), + explanation="combine_fn needs to produce one pytree for the output " + f"but combine_fn produces the pytree {_combine_treespec.as_python_constant()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # Check whether the outs produced by combine_fn has the same treespec as xs @@ -1666,9 +1778,14 @@ def arg_extractor(combine_fn, xs, additional_inputs): ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( xs_treespec, _combine_treespec ).as_python_constant(): - unimplemented( - f"The tree structure of the xs and the outs of the combine_fn are are expected to be identical, but got " - f"xs: {xs_treespec.as_python_constant()} vs output: {_combine_treespec.as_python_constant()}." 
+ unimplemented_v2( + gb_type="torch.associative_scan: mismatched input/output tree structure", + context=f"xs: {xs_treespec.as_python_constant()}, output: {_combine_treespec.as_python_constant()}", + explanation="The tree structure of the xs and the outs of the combine_fn are are expected to be identical, but got " + f"xs: {xs_treespec.as_python_constant()} vs output: {_combine_treespec.as_python_constant()}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # We set include contiguity=False because we have vmap x HOP tests, where if @@ -1772,10 +1889,14 @@ def _check_combine_fn_is_normalized(combine_fn_var): variables.FunctoolsPartialVariable, ), ): - unimplemented( - f"Expected combine_fn to be wrapped as functools.partial in scan user-facing api " - f"or a graph module if we're re-exporting but got " - f"{combine_fn.python_type()}. Please report an issue to PyTorch if you're seeing this." + unimplemented_v2( + gb_type="torch.scan: improper combine_fn", + context=str(combine_fn_var), + explanation="Expected combine_fn to be wrapped as functools.partial in scan user-facing api " + f"or a graph module if we're re-exporting but got {combine_fn_var.python_type()}.", + hints=[ + *graph_break_hints.DIFFICULT, + ], ) return isinstance( combine_fn_var, @@ -1809,34 +1930,57 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) # xs input check if not isinstance(xs, (ListVariable, TupleVariable)): - unimplemented( - f"Expected xs to be a list/tuple but got " - f"{xs.python_type()}. It seems to be an " - f"internal error, please report an issue to PyTorch." + unimplemented_v2( + gb_type="torch.scan: improper xs", + context=str(xs), + explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) # init input check if not isinstance(init, (ListVariable, TupleVariable)): - unimplemented( - f"Expected init to be a list/tuple with at least one element but got " - f"{init.python_type()}. It seems to be an " - f"internal error, please report an issue to PyTorch." + unimplemented_v2( + gb_type="torch.scan: improper init", + context=str(init), + explanation=f"Expected init to be a list/tuple with at least one element but got {init.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) + if len(init_vars) == 0: - unimplemented( - "scan() operator requires init leaves. It seems to be an " - "internal error, please report an issue to PyTorch." + unimplemented_v2( + gb_type="torch.scan: no init leaves", + context="", + explanation="Expected init leaves.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) + # additional_inputs input check if not isinstance(additional_inputs, (ListVariable, TupleVariable)): - unimplemented( - f"Expected additional_inputs to be a list/tuple but got " - f"{additional_inputs.python_type()}. It seems to be an " - f"internal error, please report an issue to PyTorch." 
+ unimplemented_v2( + gb_type="torch.scan: improper additional_inputs", + context=str(additional_inputs), + explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) # scan_length check scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] if scan_length == 0: - unimplemented("NYI: scan() operator doesn't support zero scan_length.") + unimplemented_v2( + gb_type="torch.scan: zero-sized tensor", + context=str(xs_vars[0]), + explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + *graph_break_hints.SUPPORTABLE, + ], + ) _check_all_tensorvariable(init_vars) _check_all_tensorvariable(xs_vars) _check_all_tensorvariable(additional_inputs_vars) @@ -1885,8 +2029,13 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) else: if len(combine_result_vars) != 2: - unimplemented( - f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}" + unimplemented_v2( + gb_type="torch.scan: improper combine_fn number of returns", + context=str(combine_result_vars), + explanation=f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) carry_tree, out_vars = combine_result_vars carry_vars, _ = _make_inlined(tx, pytree.tree_flatten)( @@ -1970,8 +2119,13 @@ def _call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) if len(kwargs) > 0: - unimplemented( - "torch.ops.higher_order.map: kwargs are not supported in the map operator." + unimplemented_v2( + gb_type="torch.map: kwargs not supported", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"torch.map expects no keyword arguments (got {len(kwargs)})", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) _check_supported_callable_arg(tx, args[0], "map_fn") @@ -1985,8 +2139,13 @@ def _call_function( sample_shape = get_fake_value(unpacked_xs[0].as_proxy().node, tx).size() if len(sample_shape) < 1 or sample_shape[0] == 0: - unimplemented( - "map() operator doesn't support scalar or zero-sized tensors during tracing." + unimplemented_v2( + gb_type="torch.map: improper inputs", + context=str(sample_shape), + explanation="torch.map doesn't support scalar or non-zero sized tensors during tracing.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) # To get the example output from map() we will need to provide at least one sample to @@ -2074,8 +2233,11 @@ def _call_function( # executorch_call_delegate sits at a higher level than dynamo, but # there's no real solution to this issue yet. if len(kwargs) > 0: - unimplemented( - "executorch_call_delegate: kwargs arguments were not enabled." 
+ unimplemented_v2( + gb_type="executorch_call_delegate: kwargs not supported", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"executorch_call_delegate expects no keyword arguments (got {len(kwargs)})", + hints=[], ) if isinstance(args[0], variables.NNModuleVariable): lowered_module = tx.output.get_submodule(args[0].module_key) @@ -2131,10 +2293,13 @@ def call_function( self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] ) -> VariableTracker: if not torch._dynamo.config.inline_inbuilt_nn_modules: - unimplemented( - "torch.func.functional_call capture is disabled, " - "it can be turned on by setting " - "`torch._dynamo.config.inline_inbuilt_nn_modules=True`" + unimplemented_v2( + gb_type="torch.func.functional_call capture is disabled", + context="", + explanation="torch.func.functional_call capture is disabled", + hints=[ + "Set `torch._dynamo.config.inline_inbuilt_nn_modules=True` to enable.", + ], ) return super().call_function(tx, args, kwargs) @@ -2238,7 +2403,14 @@ def _call_function( ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap") if len(p_kwargs) > 0: - unimplemented("kwargs should have been flattened into lifted args") + unimplemented_v2( + gb_type="WrapHigherOrderVariable: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) flat_example_value = pytree.tree_map_only( torch.fx.Proxy, @@ -2266,14 +2438,26 @@ def call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) if kwargs: - unimplemented( - f"wrap_with_set_grad_enabled: Got unexpected kwargs: {list(kwargs.keys())}" + unimplemented_v2( + gb_type="wrap_with_set_grad_enabled: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"wrap_with_set_grad_enabled expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) grad_enabled, fn_var, *rest_args = args if not isinstance(grad_enabled, ConstantVariable): - unimplemented("grad_enabled must be a constant") + unimplemented_v2( + gb_type="wrap_with_set_grad_enabled: non-constant grad_enabled", + context=str(grad_enabled), + explanation="wrap_with_set_grad_enabled expects grad_enabled argument to be a constant.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) _check_supported_callable_arg(tx, fn_var, "enable_grad_fn") @@ -2294,8 +2478,11 @@ def call_function( ) if len(body_lifted_freevars) > 0: - unimplemented( - f"wrap_with_set_grad_enabled: Got unexpected freevars {body_lifted_freevars}" + unimplemented_v2( + gb_type="wrap_with_set_grad_enabled: unexpected freevars", + context=str(body_lifted_freevars), + explanation="wrap_with_set_grad_enabled expects no freevars.", + hints=[], ) body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) @@ -2338,16 +2525,27 @@ def call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) if kwargs: - unimplemented( - f"wrap_with_autocast: Got unexpected kwargs: {list(kwargs.keys())}" + unimplemented_v2( + gb_type="wrap_with_autocast: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"wrap_with_autocast expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args for arg in [device_type, dtype, enabled, cache_enabled]: if not isinstance(arg, ConstantVariable): - unimplemented( - "device_type, dtype, 
enabled, cache_enabled must be constants" + unimplemented_v2( + gb_type="wrap_with_autocast: expected constant arg", + context=str(args), + explanation="wrap_with_autocast expects device_type, dtype, enabled, " + "and cache_enabled arguments to be constants.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], ) _check_supported_callable_arg(tx, fn_var, "autocast") @@ -2374,8 +2572,11 @@ def call_function( ) if len(body_lifted_freevars) > 0: - unimplemented( - f"wrap_with_autocast: Got unexpected freevars {body_lifted_freevars}" + unimplemented_v2( + gb_type="wrap_with_autocast: unexpected freevars", + context=str(body_lifted_freevars), + explanation="wrap_with_autocast expects no freevars.", + hints=[], ) body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) @@ -2406,7 +2607,7 @@ def call_function( class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( - reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." + reason="hints_wrapper doesn't work unless it is captured completely with torch.compile." ) def _call_function( self, tx, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]" @@ -2414,27 +2615,27 @@ def _call_function( _check_supported_callable_arg(tx, args[0], "body_fn") # inputs - if len(args) != 3: - unimplemented( - f"Expected 3 arguments but got {len(args)}.\n" - f"Usage: hints_wrapper(body_fn, args, kwargs, hints).\n" - f"kwargs required to be provided explicitly." + if ( + len(args) != 3 + or not isinstance(args[1], (ListVariable, TupleVariable)) + or not isinstance(args[2], ConstDictVariable) + or len(kwargs) != 1 + or "hints" not in kwargs + ): + unimplemented_v2( + gb_type="hints_wrapper: improper args/kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"hints_wrapper expects 3 positional arguments (got {len(args)}) " + f"and 1 keyword argument (got {len(kwargs)}). " + "Usage: hints_wrapper(body_fn, args, kwargs, hints=...). 
" + "args is expected to be list/tuple and kwargs is expected to be a dict.", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) - if not isinstance(args[1], (ListVariable, TupleVariable)): - unimplemented( - f"Expected a tuple but got {args[1].python_type()}", - ) operands = args[1].unpack_var_sequence(tx) - if not isinstance(args[2], ConstDictVariable): - unimplemented( - f"Expected a dict but got {args[2].python_type()}", - ) - - if "hints" not in kwargs: - raise IncorrectUsage("hints_wrapper - key hints not provided") - ( (body_r, treespec), body_graph, @@ -2487,7 +2688,14 @@ def _call_function( from .builder import wrap_fx_proxy if len(kwargs) > 0: - unimplemented("out_dtype does not handle kwargs") + unimplemented_v2( + gb_type="out_dtype: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"out_dtype expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) p_args = tuple(arg.as_proxy() for arg in args) op = p_args[0] @@ -2526,11 +2734,23 @@ def _call_function( # TODO (tmanlaibaatar) support pytree here for arg in unpacked_sequence: if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): - unimplemented("strict_mode HOO only works for flat inputs for now") + unimplemented_v2( + gb_type="strict_mode: improper args", + context=f"args: {args}, kwargs: {kwargs}", + explanation="strict_mode higher order op expects flat inputs (list/tuple/dict)", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) if kwargs: - unimplemented( - f"strict_mode HOO received unexpected kwargs: {list(kwargs.keys())}" + unimplemented_v2( + gb_type="strict_mode: unexpected kwargs", + context=f"args: {args}, kwargs: {kwargs}", + explanation=f"strict_mode higher order op expects no keyword arguments (got {len(kwargs)}).", + hints=[ + *graph_break_hints.USER_ERROR, + ], ) ( @@ -3048,7 +3268,12 @@ def bwd(ctx, grad, x): ) fwd_args = [fwd_fn.obj, ctx, *args] else: - unimplemented("non-function or method") + unimplemented_v2( + gb_type="autograd.Function.apply: non-function or method forward", + context=str(self.fwd_graph), + explanation="Expected forward function to be a function or method.", + hints=[], + ) # Speculate subgraph on the fwd (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph( @@ -3068,7 +3293,14 @@ def bwd(ctx, grad, x): "_materialize_non_diff_grads" in tx.output.side_effects.store_attr_mutations[ctx] ): - unimplemented("NYI") + unimplemented_v2( + gb_type="autograd.Function.apply: _materialize_non_diff_grads mutation", + context="", + explanation="Mutations to autograd.Function.ctx._materialize_non_diff_grads are not supported.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) bwd_tracer = torch._dynamo.output_graph.SubgraphTracer( tx.output, @@ -3096,7 +3328,12 @@ def bwd(ctx, grad, x): ) bwd_args = [bwd_fn.obj, *bwd_args] else: - unimplemented("non-function or method") + unimplemented_v2( + gb_type="autograd.Function.apply: non-function or method backward", + context=str(self.bwd_graph), + explanation="Expected backward function to be a function or method.", + hints=[], + ) def is_strict_for(v: VariableTracker): if isinstance(v, variables.TensorVariable): @@ -3147,7 +3384,12 @@ def is_strict_for(v: VariableTracker): UserDefinedClassVariable(self.bwd_graph.__class__), ) else: - unimplemented("non-function or method") + unimplemented_v2( + gb_type="autograd.Function.apply: non-function or method backward (2)", + context=str(self.bwd_graph), + explanation="Expected backward function to be a function or 
method.", + hints=[], + ) with mock.patch( "torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops", @@ -3499,7 +3741,14 @@ def _call_function( ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "invoke_subgraph") if len(p_kwargs) > 0: - unimplemented("kwargs should have been flattened into lifted args") + unimplemented_v2( + gb_type="invoke_subgraph: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) flat_example_value = pytree.tree_map_only( torch.fx.Proxy, From bd7e18bc57969be0facdda0a5fda872af934bc8b Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 6 Nov 2025 14:51:20 -0800 Subject: [PATCH 185/651] [dynamo] unimplemented -> unimplemented_v2 in torch/_subclasses/meta_utils.py (#167159) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167159 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos ghstack dependencies: #167001, #167146 --- torch/_subclasses/meta_utils.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index f56800367af45..ded569f70ef64 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1413,10 +1413,14 @@ def tensor_visitor_fn( # TODO: Handle this better in Dynamo? # There are checks there now, but this can still be triggered by a dense # tensor graph input that is a view of a strided NT. - from torch._dynamo.exc import unimplemented - - unimplemented( - "strided nested tensors are not supported by meta conversion" + from torch._dynamo.exc import unimplemented_v2 + + # NOTE this graph break will NOT be present in Dynamo's graph break registry + unimplemented_v2( + gb_type="attempted to apply meta conversion to strided nested tensor", + context=str(t), + explanation="This is not supported.", + hints=[], ) elif t.is_mkldnn: is_leaf = t.is_leaf @@ -1450,10 +1454,13 @@ def tensor_visitor_fn( r = self._backward_error(r) elif t.is_functorch_wrapped: if t.is_view: - from torch._dynamo.exc import unimplemented + from torch._dynamo.exc import unimplemented_v2 - unimplemented( - "view functorch tensors are not supported by meta conversion" + unimplemented_v2( + gb_type="attempted to apply meta conversion to view functorch tensor", + context=str(t), + explanation="This is not supported.", + hints=[], ) # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor) From 3261149aa3897230237660ccc2e65ed40c8bf543 Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 6 Nov 2025 14:51:20 -0800 Subject: [PATCH 186/651] [dynamo] remove old unimplemented() call (#167149) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167149 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos ghstack dependencies: #167001, #167146, #167159 --- torch/_dynamo/exc.py | 40 +++------------------------------------- 1 file changed, 3 insertions(+), 37 deletions(-) diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 857c694ad92c6..d252f7c2a3b36 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -28,7 +28,6 @@ import json import logging -import os import re import textwrap import typing @@ -451,42 +450,6 @@ def handle_observed_exception(tx: Any) -> None: ) -def unimplemented_with_warning( - e: Exception, code: types.CodeType, msg: str -) -> NoReturn: - # This function calls unimplemented internally and eventually graph 
breaks - # or falls to eager. unimplemented itself does not print any user warnings, - # i.e., its very silent. This helper function is intended when an error is - # encountered in the torch.compile stack which is worth showing as warning - # to the user. For example, if AOT Autograd backend fails with a fake tensor - # exception, its ok to fallback to eager but not silently. Here, we can use - # this function to log the message and the stack trace. - graph_break_msg = format_error_msg_verbose(e, code) - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "dynamo_graph_break_reason", - "encoding": "string", - }, - payload_fn=lambda: graph_break_msg, - ) - graph_breaks_log.debug("%s", graph_break_msg) - log.warning(msg) - unimplemented(msg, from_exc=e) - - -_NOTHING = object() - - -def unimplemented( - msg: str, *, from_exc: Any = _NOTHING, case_name: Optional[str] = None -) -> NoReturn: - assert msg != os.environ.get("BREAK", False) - if from_exc is not _NOTHING: - raise Unsupported(msg, case_name=case_name) from from_exc - raise Unsupported(msg, case_name=case_name) - - def unimplemented_v2_with_warning( e: Exception, code: types.CodeType, @@ -587,6 +550,9 @@ def get_gbid_documentation_link(gb_type: str) -> Optional[str]: return None +_NOTHING = object() + + # TODO replace old unimplemented later def unimplemented_v2( gb_type: str, From faba6e205f00939a5892404df989a62595e74e29 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Thu, 6 Nov 2025 21:45:58 -0800 Subject: [PATCH 187/651] [pallas backend] use dlpack directly (#167243) previous version does not work on jax 0.8 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167243 Approved by: https://github.com/yf225, https://github.com/jansel --- test/inductor/test_pallas.py | 23 +++++++---------------- torch/_inductor/codegen/pallas.py | 9 ++------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index 2d4e6af002ab0..3ba84e8cd2b8c 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -305,8 +305,8 @@ def contiguous_mul(x): expected = contiguous_mul(x) self.assertEqual(result, expected) - # Test 3: Non-contiguous views will fail at runtime with JAX/Pallas - # This demonstrates that the Pallas backend requires contiguous memory layout + # Test 3: Non-contiguous views should work with the simplified dlpack approach + # The direct dlpack conversion handles non-contiguous tensors correctly def operate_on_tensor(x): return x.sin() @@ -319,21 +319,12 @@ def operate_on_tensor(x): x_t = x.t() # Non-contiguous view self.assertFalse(x_t.is_contiguous()) - # This will fail because JAX/Pallas cannot handle non-contiguous layout via DLPack - # The error indicates that our contiguous-only approach is correct - with self.assertRaises((RuntimeError, Exception)) as cm: - result = compiled(x_t) - - # Verify the error is related to layout/contiguous issues - error_msg = str(cm.exception) - self.assertTrue( - "layout" in error_msg.lower() - or "contiguous" in error_msg.lower() - or "non-default" in error_msg.lower(), - f"Expected layout/contiguous error, got: {error_msg}", - ) + # With the simplified dlpack approach, non-contiguous tensors now work + result = compiled(x_t) + expected = operate_on_tensor(x_t) + self.assertEqual(result, expected) - # But if we make it contiguous first, it should work + # Contiguous tensors should also continue to work x_t_contiguous = x_t.contiguous() 
self.assertTrue(x_t_contiguous.is_contiguous()) result = compiled(x_t_contiguous) diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index da437a4e8ee3c..8587368407323 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -291,7 +291,6 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove import jax import jax.numpy as jnp from jax.experimental import pallas as pl - from torch.utils import dlpack as torch_dlpack """, strip=True, ) @@ -330,9 +329,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove # Convert inputs to JAX arrays code.writeline("# Convert Torch -> JAX for inputs") for inp in input_params: - code.writeline( - f"{inp}_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack({inp}))" - ) + code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})") # Get output spec from PyTorch tensor code.writeline("# Prepare output spec from PyTorch tensor") @@ -362,9 +359,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove # Copy result back code.writeline("# Copy result back into the provided torch output tensor") - code.writeline( - "res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))" - ) + code.writeline("res_t = torch.from_dlpack(res)") code.writeline(f"{output_param}.copy_(res_t)") return code.getvalue() From 5b2ad2d5dcdb4cc3a41a5d8fc627580238414c22 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 6 Nov 2025 14:44:23 -0800 Subject: [PATCH 188/651] [user-streams] Add fallbacks for record and wait event (#167260) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167260 Approved by: https://github.com/shunting314 ghstack dependencies: #167175, #167176, #167180, #167195 --- test/dynamo/test_streams.py | 15 +++++++++++++++ torch/_inductor/lowering.py | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index b736a5750e3a6..64fb41afa531e 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -3,6 +3,7 @@ import re import unittest import weakref +from unittest.mock import patch import torch import torch._dynamo.test_case @@ -491,6 +492,20 @@ def test_run_opcheck_wait_record(self): torch.accelerator.set_stream(original_stream) reset_user_object_tracking() + @requires_cuda + def test_inductor_lowering(self): + with patch("torch._inductor.config.implicit_fallbacks", False): + + @torch.compile() + def fn(x): + e = torch.Event() + x += x + 1 + e.record() + return x + + inp = (torch.ones(2, 2, device="cuda"),) + fn(*inp) + def test_is_marked_side_effectful(self): self.assertIn( torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 2df224caf61a9..3e6ffd46f80f1 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2379,6 +2379,10 @@ def warn_triton_random(): fallback_randn_generator = fallback_handler(aten.randn.generator) make_fallback(aten.randint) +# TODO: mlazos reevaluate if we want to codegen something different +make_fallback(torch.ops.streams.record_event.default) +make_fallback(torch.ops.streams.wait_event.default) + @register_lowering(aten.rand) def rand(*args, **kwargs): From 7318ed627beac8bb9e9be1a16fa4441ebfeade55 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 6 Nov 2025 14:44:24 -0800 Subject: [PATCH 189/651] [user-streams] Trace events with the new ops (#167177) Pull Request resolved: 
https://github.com/pytorch/pytorch/pull/167177 Approved by: https://github.com/anijain2305 ghstack dependencies: #167175, #167176, #167180, #167195, #167260 --- test/dynamo/test_ctx_manager.py | 16 +++++- test/dynamo/test_streams.py | 31 ++++++++++++ torch/_dynamo/utils.py | 4 ++ torch/_dynamo/variables/builder.py | 21 ++++++-- torch/_dynamo/variables/streams.py | 66 ++++++++++++++++++++++++- torch/_dynamo/variables/torch.py | 2 +- torch/_dynamo/variables/user_defined.py | 28 +++++++++++ 7 files changed, 159 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 0433354b953b9..780a660227bf1 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -408,6 +408,9 @@ def fn(x, s0, s1): self.assertEqual(ref0, res0) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skip( + "Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257" + ) def test_cuda_event_reconstruct(self): def fn(x): e = torch.cuda.Event() @@ -425,6 +428,9 @@ def fn(x): self.assertEqual(cnts.op_count, 3) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skip( + "Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257" + ) def test_cuda_event_across_graph_break(self): def fn(x): e = torch.cuda.Event() @@ -446,9 +452,12 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref[0], res[0]) self.assertEqual(cnts.frame_count, 2) - self.assertEqual(cnts.op_count, 9) + self.assertEqual(cnts.op_count, 10) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skip( + "Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257" + ) def test_cuda_event_created_outside_of_graph(self): user_stream = torch.cuda.Stream() event = torch.cuda.Event() @@ -478,9 +487,12 @@ def run_iters(fn, compile=False): res = run_iters(func, compile=True) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 3) + self.assertEqual(cnts.op_count, 4) @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + @unittest.skip( + "Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257" + ) def test_cuda_event_method_create_stream_outside_of_compile(self): def fn(x, cur_stream, new_stream): x = torch.mul(x, 1) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 64fb41afa531e..9ded5522a41b4 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -446,6 +446,37 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): """, ) + @requires_cuda + def test_event_tracing(self): + def fn(x) -> None: + e = torch.Event() + e.record() + x.add_(1) + return x + + inp = (torch.ones(2, 2, device="cuda"),) + ( + _, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]"): + # + record_event = torch.ops.streams.record_event.default(0, 1); record_event = None + + # + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1) + copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = None + return (copy_,) +""", + ) + @requires_cuda def test_run_opcheck_fork_join(self): from torch._dynamo.variables.streams import fork_stream, join_stream diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 644081ab68579..4bff421c7d385 100644 
--- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4768,6 +4768,10 @@ def build_stream(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Stream: return torch._C.Stream(*args, **kwargs) +def build_event(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Event: + return torch._C.Event(*args, **kwargs) + + class CompileTimeInstructionCounter: _counter: int = 0 _id: int = -1 diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index e436a07bd0dcb..0c74055973bf8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1083,6 +1083,7 @@ def build_key_value(i, k, v): return EventVariable( event_proxy, value, + index, source=self.source, ) elif ( @@ -3004,16 +3005,28 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe return SymNodeVariable(proxy, example_value, **options) elif ( isinstance(example_value, torch.Stream) - and proxy.node.target == get_external_object_by_index + and proxy.node.target is get_external_object_by_index ) or proxy.node.target in [ device_interface.current_stream for _, device_interface in get_registered_device_interfaces() ]: set_example_value(proxy.node, example_value) index = None - if proxy.node.target == get_external_object_by_index: + if proxy.node.target is get_external_object_by_index: index = proxy.node.args[0] return StreamVariable(proxy, example_value, index, **options) + elif ( + isinstance(example_value, torch.Event) + and proxy.node.target is get_external_object_by_index + ) or proxy.node.target in [ + device_interface.current_stream + for _, device_interface in get_registered_device_interfaces() + ]: + index = None + if proxy.node.target is get_external_object_by_index: + index = proxy.node.args[0] + set_example_value(proxy.node, example_value) + return EventVariable(proxy, example_value, index, **options) elif ( inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, torch.Event) @@ -3022,7 +3035,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe for _, device_interface in get_registered_device_interfaces() ]: set_example_value(proxy.node, example_value) - return EventVariable(proxy, example_value, **options) + return EventVariable(proxy, example_value, None, **options) elif proxy.node.target == "query" and proxy.node.op == "call_method": set_example_value(proxy.node, example_value) return ConstantVariable(example_value, **options) @@ -3033,7 +3046,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe and proxy.node.op == "call_method" ): set_example_value(proxy.node, example_value) - return EventVariable(proxy, example_value, **options) + return EventVariable(proxy, example_value, None, **options) elif isinstance(example_value, int) and ( proxy.node.target in [ diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 6aa6e43a2a00e..79a0d0eb9ba23 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -326,12 +326,19 @@ def fn(index: int, codegen: "PyCodegen") -> None: class EventVariable(VariableTracker): - def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None: + def __init__( + self, + proxy: Proxy, + value: torch.Event, + user_object_index: Optional[int], + **kwargs: Any, + ) -> None: if proxy is not None and "example_value" in proxy.node.meta: assert proxy.node.meta["example_value"] == value super().__init__(**kwargs) self.proxy = proxy self.value = value + 
self.user_object_index = user_object_index def call_method( self, @@ -343,7 +350,29 @@ def call_method( from ..utils import proxy_args_kwargs from .builder import wrap_fx_proxy_cls - if name in ("wait", "record", "synchronize"): + if name == "wait": + tx.output.create_proxy( + "call_function", + torch.ops.streams.wait_event, + ( + self.user_object_index, + EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + ), + {}, + ) + return ConstantVariable(None) + elif name == "record": + tx.output.create_proxy( + "call_function", + torch.ops.streams.record_event, + ( + self.user_object_index, + EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + ), + {}, + ) + return ConstantVariable(None) + elif name == "synchronize": tx.output.create_proxy( "call_method", name, *proxy_args_kwargs([self] + args, kwargs) ) @@ -373,6 +402,39 @@ def call_method( def as_proxy(self) -> Proxy: return self.proxy + @staticmethod + def _get_stream_arg( + tx: "InstructionTranslator", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "StreamVariable": + stream_arg = None + if args: + stream_arg = args[0] + elif kwargs: + stream_arg = kwargs.get("stream") + + if not stream_arg: + stream_arg = tx.symbolic_stream_state.cur_stream() + + return stream_arg # type: ignore[return-value] + + @staticmethod + def make_construct_in_graph_event_fn( + args: TupleVariable, kwargs: ConstDictVariable + ) -> Callable[[int, "PyCodegen"], None]: + def fn(index: int, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.utils.__name__, "build_event" + ) + ) + codegen(args) + codegen(kwargs) + codegen.extend_output(create_call_function(2, False)) + + return fn + def reconstruct(self, codegen: "PyCodegen") -> None: # If we got here, this event is fully subsumed by the graph - this means it is # not an input or global diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 30c1b8c2cf186..e8e246be968eb 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1270,7 +1270,7 @@ def handle_get_device_module(self, tx, *args, **kwargs): # pyrefly: ignore [unbound-name] return VariableTracker.build(tx, module, new_source) - @register(torch.accelerator.current_stream) + @register(torch.accelerator.current_stream, torch.cuda.current_stream) def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): unimplemented_v2( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index cea5be48a6b30..a65ee6b1e0bf6 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -839,6 +839,34 @@ def deque_signature(iterable=None, maxlen=None): "call_function", get_external_object_by_index, (ind,), {} ), ) + elif issubclass(self.value, torch.Event): + from .constant import ConstantVariable + from .lists import TupleVariable + + # Register newly created event for reconstruction + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + var_args = TupleVariable(list(args)) + event = self.value( + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + from ..graph_bytecode_inputs import register_graph_created_object + from .streams import EventVariable + + ind = register_graph_created_object( + event, + EventVariable.make_construct_in_graph_event_fn( + var_args, var_kwargs + 
), + ) + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", get_external_object_by_index, (ind,), {} + ), + ) else: tensor_variable = wrap_fx_proxy( tx=tx, From 57dd6a0656e0189f37dac1298a5659743b2fc044 Mon Sep 17 00:00:00 2001 From: Randy Shuai Date: Fri, 7 Nov 2025 07:08:48 +0000 Subject: [PATCH 190/651] [OC][Torch] Extend autotune options for OC OBA 200x shapes (#166931) Summary: Add four best configs for shapes of the OC OBA 200x model: ``` M=2048 N=2048 K=12288 triton_mm_35 0.1526 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 M=2048 N=52416 K=1536 triton_mm_12 0.4604 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 M=2048 N=12288 K=2048 triton_mm_9 0.1444 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=8, num_consumer_groups=0, num_buffers_warp_spec=0 M=2048 N=2048 K=52416 triton_mm_35 0.6505 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, USE_FAST_ACCUM=True, kpack=2, matrix_instr_nonkdim=16, waves_per_eu=0, num_stages=2, num_warps=4, num_consumer_groups=0, num_buffers_warp_spec=0 ``` Test Plan: Run tritonbench for torch fp8(_scaled_mm) for all above shapes, e.g. ``` TRITON_PRINT_AUTOTUNING=1 buck2 run mode/opt-amd-gpu -c fbcode.enable_gpu_sections=true //pytorch/tritonbench:run -- --op fp8_gemm --only pt2_fp8_gemm --metrics tflops,accuracy --m 2048 --n 2048 --k 12288 ``` Differential Revision: D86158497 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166931 Approved by: https://github.com/jananisriram --- torch/_inductor/template_heuristics/triton.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 8cbbf5073d5ef..9df8d114ef67b 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -422,8 +422,12 @@ def __init__(self) -> None: GemmConfig(32, 256, 64, 6, 4), GemmConfig(64, 16, 256, 5, 4), GemmConfig(64, 32, 256, 5, 4), + GemmConfig(64, 128, 128, 2, 4), GemmConfig(64, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 2, 4), GemmConfig(128, 256, 128, 4, 8), + GemmConfig(256, 128, 128, 2, 4), + GemmConfig(256, 128, 128, 2, 8), ] self.scaled_persistent_mm_configs: list[BaseConfig] = [ From 0968e74266273cf8bcdd8ac4488c6ef25240ce03 Mon Sep 17 00:00:00 2001 From: amdfaa <107946068+amdfaa@users.noreply.github.com> Date: Fri, 7 Nov 2025 07:37:34 +0000 Subject: [PATCH 191/651] [ROCm][CI] Run PR-Based workflow runs on mi300 nodes. (#167225) This PR is meant to swap the PR-based ciflow tags from the mi200 nodes (less stable) to the mi300 nodes (more stable). This will ensure that developers see consistent testing on their PRs as well as on main. 
This PR does all of the following: - Rename rocm.yml to rocm-mi200.yml : for clarity - Add ciflow/rocm-mi200 trigger to rocm-mi200.yml : for devs who want to opt-in to single-GPU unit tests on MI200 - Move ciflow/rocm trigger from rocm-mi200.yml to rocm-mi300.yml : so PRs target MI300 runners by default - Rename inductor-rocm.yml to inductor-rocm-mi200.yml : for clarity - Remove ciflow/inductor-rocm trigger from inductor-rocm-mi200.yml : prevent MI200 inductor config unit tests being triggered by default - Add ciflow/inductor-rocm-mi200 trigger to inductor-rocm-mi200.yml : for devs who want to opt-in to inductor config unit tests on MI200 - Move ciflow/periodic trigger from periodic-rocm-mi200.yml to periodic-rocm-mi300.yml : so PRs target MI300 runners by default Pull Request resolved: https://github.com/pytorch/pytorch/pull/167225 Approved by: https://github.com/jeffdaily, https://github.com/huydhn Co-authored-by: Jithun Nair --- .github/pytorch-probot.yml | 5 ++++- .../workflows/{inductor-rocm.yml => inductor-rocm-mi200.yml} | 2 +- .github/workflows/inductor-rocm-mi300.yml | 1 + .github/workflows/periodic-rocm-mi200.yml | 1 - .github/workflows/periodic-rocm-mi300.yml | 1 + .github/workflows/{rocm.yml => rocm-mi200.yml} | 2 +- .github/workflows/rocm-mi300.yml | 1 + .github/workflows/upload-test-stats.yml | 4 ++-- 8 files changed, 11 insertions(+), 6 deletions(-) rename .github/workflows/{inductor-rocm.yml => inductor-rocm-mi200.yml} (98%) rename .github/workflows/{rocm.yml => rocm-mi200.yml} (98%) diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index c15ba606398f6..fe6881d9318d5 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -2,8 +2,8 @@ tracking_issue: 24422 ciflow_tracking_issue: 64124 ciflow_push_tags: - ciflow/b200 -- ciflow/b200-symm-mem - ciflow/b200-distributed +- ciflow/b200-symm-mem - ciflow/binaries - ciflow/binaries_libtorch - ciflow/binaries_wheel @@ -22,6 +22,8 @@ ciflow_push_tags: - ciflow/inductor-perf-test-nightly-xpu - ciflow/inductor-periodic - ciflow/inductor-rocm +- ciflow/inductor-rocm-mi200 +- ciflow/inductor-rocm-mi300 - ciflow/linux-aarch64 - ciflow/mps - ciflow/nightly @@ -33,6 +35,7 @@ ciflow_push_tags: - ciflow/quantization-periodic - ciflow/riscv64 - ciflow/rocm +- ciflow/rocm-mi200 - ciflow/rocm-mi300 - ciflow/rocm-mi355 - ciflow/rocm-navi31 diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm-mi200.yml similarity index 98% rename from .github/workflows/inductor-rocm.yml rename to .github/workflows/inductor-rocm-mi200.yml index 8dbc785e20f16..c33104d39dcbd 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm-mi200.yml @@ -7,7 +7,7 @@ on: branches: - release/* tags: - - ciflow/inductor-rocm/* + - ciflow/inductor-rocm-mi200/* workflow_dispatch: concurrency: diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index 732ec7eb85f3e..dee10a0db3c16 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -7,6 +7,7 @@ on: - release/* tags: - ciflow/inductor-rocm/* + - ciflow/inductor-rocm-mi300/* workflow_dispatch: concurrency: diff --git a/.github/workflows/periodic-rocm-mi200.yml b/.github/workflows/periodic-rocm-mi200.yml index 6b65bf05cbde0..18e7b60570bf8 100644 --- a/.github/workflows/periodic-rocm-mi200.yml +++ b/.github/workflows/periodic-rocm-mi200.yml @@ -11,7 +11,6 @@ on: - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled 
tests push: tags: - - ciflow/periodic/* - ciflow/periodic-rocm-mi200/* branches: - release/* diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index 4d8890e69fc73..ce68ee8bc8e03 100644 --- a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -11,6 +11,7 @@ on: - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests push: tags: + - ciflow/periodic/* - ciflow/periodic-rocm-mi300/* branches: - release/* diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm-mi200.yml similarity index 98% rename from .github/workflows/rocm.yml rename to .github/workflows/rocm-mi200.yml index 6f37d3e4f65a4..f40910ae0f61f 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm-mi200.yml @@ -5,7 +5,7 @@ on: branches: - release/* tags: - - ciflow/rocm/* + - ciflow/rocm-mi200/* workflow_dispatch: schedule: - cron: 29 8 * * * # about 1:29am PDT diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index c50111d068d24..d20b37be20876 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -6,6 +6,7 @@ on: - main - release/* tags: + - ciflow/rocm/* - ciflow/rocm-mi300/* workflow_dispatch: schedule: diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 24c3ab3db84f3..3bd88d717908b 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -13,13 +13,13 @@ on: - slow - unstable-periodic - inductor-periodic - - rocm + - rocm-mi200 - rocm-mi300 - rocm-mi355 - inductor-micro-benchmark - inductor-micro-benchmark-x86 - inductor-cu124 - - inductor-rocm + - inductor-rocm-mi200 - inductor-rocm-mi300 - mac-mps - linux-aarch64 From 35d2da32bd88f10eac038d21c4f753b2bc171e1b Mon Sep 17 00:00:00 2001 From: amdfaa <107946068+amdfaa@users.noreply.github.com> Date: Fri, 7 Nov 2025 07:38:56 +0000 Subject: [PATCH 192/651] [ROCm][CI] Separate out rocm from slow workflow (#167262) Running slow.yml on every commit is straining our limited MI200 capacity. Reducing the frequency in line with other MI200-based workflows as per https://github.com/pytorch/pytorch/pull/167220 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167262 Approved by: https://github.com/jeffdaily Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> --- .github/pytorch-probot.yml | 1 + .github/workflows/slow-rocm-mi200.yml | 81 +++++++++++++++++++++++++ .github/workflows/slow.yml | 30 --------- .github/workflows/upload-test-stats.yml | 1 + 4 files changed, 83 insertions(+), 30 deletions(-) create mode 100644 .github/workflows/slow-rocm-mi200.yml diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index fe6881d9318d5..8de0df02a132c 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -41,6 +41,7 @@ ciflow_push_tags: - ciflow/rocm-navi31 - ciflow/s390 - ciflow/slow +- ciflow/slow-rocm-mi200 - ciflow/torchbench - ciflow/triton_binaries - ciflow/trunk diff --git a/.github/workflows/slow-rocm-mi200.yml b/.github/workflows/slow-rocm-mi200.yml new file mode 100644 index 0000000000000..c564857dca9ce --- /dev/null +++ b/.github/workflows/slow-rocm-mi200.yml @@ -0,0 +1,81 @@ +# This workflow is dedicated to host slow jobs that are run only periodically because +# they are too slow to run in every commit. 
The list of slow tests can be found in +# https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json +name: slow-rocm-mi200 + +on: + push: + branches: + - release/* + tags: + - ciflow/slow/* + - ciflow/slow-rocm-mi200/* + schedule: + - cron: 0 */3 * * * + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + llm-td: + if: github.repository_owner == 'pytorch' + name: before-test + uses: ./.github/workflows/llm_td_retrieval.yml + permissions: + id-token: write + contents: read + + target-determination: + name: before-test + uses: ./.github/workflows/target_determination.yml + needs: llm-td + permissions: + id-token: write + contents: read + + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-rocm-py3_10-build: + name: linux-jammy-rocm-py3.10 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-rocm-py3.10 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + sync-tag: rocm-build + test-matrix: | + { include: [ + { config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] }, + { config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] }, + ]} + secrets: inherit + + linux-jammy-rocm-py3_10-test: + permissions: + id-token: write + contents: read + name: linux-jammy-rocm-py3.10 + uses: ./.github/workflows/_rocm-test.yml + needs: + - linux-jammy-rocm-py3_10-build + - target-determination + with: + build-environment: linux-jammy-rocm-py3.10 + docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} + secrets: inherit diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index d4992a2ddb2cf..c14caee9a336c 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -105,36 +105,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }} secrets: inherit - linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 - test-matrix: | - { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] }, - { config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] }, - ]} - secrets: inherit - - linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read - name: linux-jammy-rocm-py3.10 - uses: 
./.github/workflows/_rocm-test.yml - needs: - - linux-jammy-rocm-py3_10-build - - target-determination - with: - build-environment: linux-jammy-rocm-py3.10 - docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-py3_10-clang18-asan-build: name: linux-jammy-py3.10-clang18-asan uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 3bd88d717908b..39b82acdfd274 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -11,6 +11,7 @@ on: - inductor - unstable - slow + - slow-rocm-mi200 - unstable-periodic - inductor-periodic - rocm-mi200 From 05b8214e6ab3e59320a85772f182d793b1c24057 Mon Sep 17 00:00:00 2001 From: Yarong Mu Date: Fri, 7 Nov 2025 08:23:02 +0000 Subject: [PATCH 193/651] Added a couple of utils for Pallas TPU backend. (#167264) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167264 Approved by: https://github.com/oulgen --- torch/utils/_pallas.py | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py index 25cc635dbb178..2d93e7f32c58e 100644 --- a/torch/utils/_pallas.py +++ b/torch/utils/_pallas.py @@ -57,6 +57,21 @@ def has_jax_cuda_backend() -> bool: return False +@functools.cache +def has_jax_tpu_backend() -> bool: + """Check if JAX has TPU backend support.""" + if not has_jax_package(): + return False + try: + import jax # type: ignore[import-not-found] + + # Check if TPU backend is available + devices = jax.devices("tpu") + return len(devices) > 0 + except Exception: + return False + + @functools.cache def has_pallas() -> bool: """ @@ -65,18 +80,22 @@ def has_pallas() -> bool: Requirements: - JAX package installed - Pallas (jax.experimental.pallas) available - - CUDA backend available (for GPU support) + - A compatible backend (CUDA or TPU) is available in both PyTorch and JAX. """ if not has_pallas_package(): return False - # Only enable Pallas if CUDA is available - # (Pallas primarily targets GPU workloads) - if not torch.cuda.is_available(): - return False + # Check for is CUDA is available or if JAX has GPU/CUDA backend + has_cuda = torch.cuda.is_available() and has_jax_cuda_backend() - # Check if JAX has GPU/CUDA backend - if not has_jax_cuda_backend(): - return False + # Check for TPU backend + has_tpu_torch = False + try: + import torch_xla.core.xla_model as xm + + has_tpu_torch = xm.xla_device_count() > 0 + except ImportError: + pass + has_tpu = has_tpu_torch and has_jax_tpu_backend() - return True + return has_cuda or has_tpu From 4cf1d1af225e061b38ec53e2ac7c6ed31049cf9e Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Fri, 7 Nov 2025 08:55:47 +0000 Subject: [PATCH 194/651] [Inductor][Tritonparse] Ensure inductor meta has config_args (#167261) Summary: Before calling the tritonparse hook with `config_args`, ensure that we set `config_args` within `inductor_meta`. This way, even if it is not set, the hook still gets run and we can at least get the launch arguments. 
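As a rough standalone sketch of the resulting behavior (the names `call_launch_hook`, `example_hook`, and the dict arguments below are illustrative stand-ins, not the actual tritonparse/inductor API):

```python
import inspect

def call_launch_hook(hook, launch_args, inductor_meta):
    # Always forward the basic launch arguments.
    call_kwargs = {"launch_args": launch_args}
    # Attach inductor_args only when the hook accepts that parameter AND the
    # metadata actually carries config_args; either way the hook still runs.
    params = inspect.signature(hook).parameters
    if "inductor_args" in params and "config_args" in inductor_meta:
        call_kwargs["inductor_args"] = inductor_meta["config_args"]
    hook(**call_kwargs)

def example_hook(launch_args, inductor_args=None):
    print(launch_args, inductor_args)

call_launch_hook(example_hook, {"grid": (1,)}, {})                      # hook runs, inductor_args stays None
call_launch_hook(example_hook, {"grid": (1,)}, {"config_args": "cfg"})  # hook runs with the config args
```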
Test Plan: Tritonparse tests Differential Revision: D86463732 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167261 Approved by: https://github.com/FindHao --- torch/_inductor/runtime/triton_heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c2709073a64c1..cf20456407d48 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -832,7 +832,7 @@ def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: # only add inductor_args if the hook takes it sig = inspect.signature(hook) params = sig.parameters - if "inductor_args" in params: + if "inductor_args" in params and "config_args" in self.inductor_meta: call_kwargs["inductor_args"] = self.inductor_meta["config_args"] hook(**call_kwargs) From 3d59e8aadf085186a3a1800da1001e93b90cdb07 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 09:21:48 +0000 Subject: [PATCH 195/651] [14/N] Apply ruff UP035 rule (#167208) This PR continues to apply the `UP035` ruff rule and add `collections.abc` to dynamo checks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167208 Approved by: https://github.com/mlazos --- torch/_dynamo/config.py | 3 ++- torch/_functorch/config.py | 2 +- torch/_inductor/config.py | 3 ++- torch/utils/_config_module.py | 8 +++++++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 0c95408401c79..8682ac1cb3a44 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -14,8 +14,9 @@ import os import sys import tempfile +from collections.abc import Callable from os.path import abspath, dirname -from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, TYPE_CHECKING, Union from torch._environment import is_fbcode from torch.utils._config_module import Config, get_tristate_env, install_config_module diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 3dd2529b1b107..ae9221c38cb83 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Callable +from collections.abc import Callable """ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index bfa854b37030d..e00cacb59abe6 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1,6 +1,7 @@ import os import sys -from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union +from collections.abc import Callable +from typing import Any, Literal, Optional, TYPE_CHECKING, Union import torch import torch._inductor.custom_graph_pass diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index f302a10b8338e..33546eb01b6a0 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -175,7 +175,13 @@ def visit( if ( key.startswith("__") or isinstance(value, (ModuleType, FunctionType)) - or (hasattr(value, "__module__") and value.__module__ == "typing") + or ( + hasattr(value, "__module__") + and ( + value.__module__ == "typing" + or value.__module__.startswith("collections.abc") + ) + ) # Handle from torch.utils._config_module import Config or (isinstance(value, type) and issubclass(value, _Config)) ): From 5a9ae7cefe679ff925a0aa7b9f5782fc93d4ef29 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 5 Nov 2025 15:41:17 -0800 Subject: [PATCH 196/651] [CP] Correctly compile create_cp_block_mask (#167153) Currently we re-compile create_block_mask every time, which is not very efficient and the global compilation also causes some issues. This PR lazily compile the create_block_mask and does it only once. Fixes https://github.com/pytorch/pytorch/issues/167064 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167153 Approved by: https://github.com/drisspg, https://github.com/XilunWu --- .../experimental/_context_parallel/_attention.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch/distributed/tensor/experimental/_context_parallel/_attention.py b/torch/distributed/tensor/experimental/_context_parallel/_attention.py index 09a86081df522..b1903e211a1c1 100644 --- a/torch/distributed/tensor/experimental/_context_parallel/_attention.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_attention.py @@ -1032,9 +1032,7 @@ def _disable_context_parallel_dispatcher_impl() -> None: _disable_cp_dtensor_dispatcher() -_compiled_create_block_mask = torch.compile( - create_block_mask, dynamic=False, fullgraph=True -) +_compiled_create_block_mask = None def _context_parallel_buffers( @@ -1187,9 +1185,12 @@ def _create_cp_block_mask( f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. " ) - compiled_create_block_mask = torch.compile( - create_block_mask, dynamic=False, fullgraph=True - ) + global _compiled_create_block_mask + if _compiled_create_block_mask is None: + _compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True + ) + compiled_create_block_mask = _compiled_create_block_mask def _rewrite_mask_mod( mask_mod: _mask_mod_signature, From 341e924981f2f747d5ab19a5bcaa2978c37e2795 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 09:47:14 +0000 Subject: [PATCH 197/651] [4/N] Use key in dict for existence checks (#167285) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. 
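A minimal sketch of the rewrite this stack applies (illustrative only, not taken from the diffs below):

```python
d = {"weight": 1, "bias": 2}

# Before: membership test routed through the dict.keys() view.
if "weight" in d.keys():
    pass

# After: test membership directly on the dict -- same result.
if "weight" in d:
    pass

# The same applies when iterating over keys:
for k in d:          # preferred
    pass
for k in d.keys():   # equivalent but needlessly indirect
    pass
```

Both forms give the same result; dropping `.keys()` simply avoids creating the intermediate keys view and the extra attribute lookup.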
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167285 Approved by: https://github.com/mlazos --- test/test_bundled_inputs.py | 2 +- test/test_custom_ops.py | 2 +- test/test_decomp.py | 2 +- test/test_fake_tensor.py | 4 ++-- test/test_fx.py | 10 +++++----- test/test_jit.py | 4 ++-- test/test_nn.py | 14 +++++++------- test/test_serialization.py | 2 +- test/test_static_runtime.py | 2 +- test/test_weak.py | 8 ++++---- torch/_dynamo/compiled_autograd.py | 2 +- torch/_dynamo/eval_frame.py | 2 +- torch/_dynamo/functional_export.py | 2 +- torch/_dynamo/output_graph.py | 2 +- torch/_dynamo/symbolic_convert.py | 2 +- torch/_dynamo/utils.py | 4 ++-- torch/_dynamo/variables/builder.py | 4 ++-- torch/_dynamo/variables/builtin.py | 2 +- torch/_dynamo/variables/constant.py | 4 ++-- torch/_dynamo/variables/dicts.py | 10 ++++------ torch/_dynamo/variables/functions.py | 2 +- torch/_dynamo/variables/higher_order_ops.py | 10 +++++----- torch/_dynamo/variables/iter.py | 6 +++--- torch/_dynamo/variables/nn_module.py | 2 +- 24 files changed, 51 insertions(+), 53 deletions(-) diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py index 221502ae3190a..77acfbe3472b0 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -205,7 +205,7 @@ def foo(self, arg): self.assertEqual(all_info["foo"]["info"], info) # example of how to turn the 'get_inputs_function_name' into the actual list of bundled inputs - for func_name in all_info.keys(): + for func_name in all_info: input_func_name = all_info[func_name]["get_inputs_function_name"][0] func_to_run = getattr(loaded, input_func_name) self.assertEqual(func_to_run(), samples) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 5898f5a346bac..bcc9c377e5049 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -1227,7 +1227,7 @@ def foo_impl(x): from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY - for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys(): + for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY: # Smoke test: should not raise error custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)( foo_impl diff --git a/test/test_decomp.py b/test/test_decomp.py index f5c791c8cbe88..85522ad7e1820 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -1338,7 +1338,7 @@ def test_aten_core_operators(self): # operators, which never appear in AOTAutograd's graph so are never used. 
useful_decomps = { op - for op in decomposition_table.keys() + for op in decomposition_table if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op) } core_decomps = torch._decomp.core_aten_decompositions().keys() diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 692a37b193d5e..de8cbbe8d6ff1 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1862,7 +1862,7 @@ def _read_tensor_and_check(key, sd_loaded, all_bytes, device): for k in sd: _read_tensor_and_check(k, sd_loaded, all_bytes, "cuda") - for k in sd.keys(): + for k in sd: sd[k] = sd[k].to("cuda") with TemporaryFileName() as f, torch.serialization.safe_globals([TwoTensor]): @@ -2482,7 +2482,7 @@ def fn(x, y): def count_invoke_subgraph_keys(): invoke_subgraph_keys = 0 - for cache_key in FakeTensorMode.cache.keys(): + for cache_key in FakeTensorMode.cache: if isinstance(cache_key.key[0], torch._ops.HigherOrderOperator): invoke_subgraph_keys += 1 return invoke_subgraph_keys diff --git a/test/test_fx.py b/test/test_fx.py index 3ad21e64c8ce2..0b177a96ae0b0 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2209,7 +2209,7 @@ def test_interpreter_gc_values(self): interp = Interpreter(symbolic_trace(rn18)) inp = torch.rand(5, 3, 224, 224) out = interp.run(inp) - env_key_names = {n.name for n in interp.env.keys()} + env_key_names = {n.name for n in interp.env} self.assertEqual(env_key_names, {"output"}) def test_interpreter_default_args(self): @@ -3471,12 +3471,12 @@ def module_exists(gm: GraphModule, path: str) -> bool: def parameter_exists(gm: GraphModule, path: str) -> bool: return any(path == name for name, _ in gm.named_parameters()) and any( - path == name for name in gm.state_dict().keys() + path == name for name in gm.state_dict() ) def buffer_exists(gm: GraphModule, path: str) -> bool: return any(path == name for name, _ in gm.named_buffers()) and any( - path == name for name in gm.state_dict().keys() + path == name for name in gm.state_dict() ) # Test that we added the "dropout" submodule @@ -5060,13 +5060,13 @@ def setUpClass(cls): def no(*args, **kwargs): return False - for name in cls.TO_PATCH.keys(): + for name in cls.TO_PATCH: cls.TO_PATCH[name] = getattr(torch.nn.functional, name) setattr(torch.nn.functional, name, no) @classmethod def tearDownClass(cls): - for name in cls.TO_PATCH.keys(): + for name in cls.TO_PATCH: setattr(torch.nn.functional, name, cls.TO_PATCH[name]) diff --git a/test/test_jit.py b/test/test_jit.py index 99d7e711da305..ebb3a8e85c733 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9459,7 +9459,7 @@ def forward(self, input): return self.mods(input) m = M() - self.assertTrue('mods.conv.weight' in m.state_dict().keys()) + self.assertTrue('mods.conv.weight' in m.state_dict()) def test_script_sequential_multi_output_fail(self): class Sub(torch.jit.ScriptModule): @@ -11954,7 +11954,7 @@ def test_dict_keys_values(x): # type: (Dict[str, int]) -> Tuple[str, int] key_str = "" sum = 0 - for key in x.keys(): + for key in x: key_str += key for val in x.values(): sum += val diff --git a/test/test_nn.py b/test/test_nn.py index bedb4b22a01bd..e119902717b1d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1038,13 +1038,13 @@ def check(): self.assertIs(modules[k1], module_dict[k2]) for k in module_dict: self.assertIs(module_dict[k], modules[k]) - for k in module_dict.keys(): + for k in module_dict: self.assertIs(module_dict[k], modules[k]) for k, v in module_dict.items(): self.assertIs(modules[k], v) for k1, m2 in zip(modules, module_dict.values()): 
self.assertIs(modules[k1], m2) - for k in modules.keys(): + for k in modules: self.assertTrue(k in module_dict) check() @@ -1245,13 +1245,13 @@ def check(): self.assertIs(parameters[k1], parameter_dict[k2]) for k in parameter_dict: self.assertIs(parameter_dict[k], parameters[k]) - for k in parameter_dict.keys(): + for k in parameter_dict: self.assertIs(parameter_dict[k], parameters[k]) for k, v in parameter_dict.items(): self.assertIs(v, parameters[k]) for k1, m2 in zip(parameters, parameter_dict.values()): self.assertIs(parameters[k1], m2) - for k in parameters.keys(): + for k in parameters: self.assertTrue(k in parameter_dict) check() @@ -2356,7 +2356,7 @@ def test_state_dict(self): self.assertIn('bn.running_var', state_dict) self.assertIn('bn.running_mean', state_dict) self.assertIn('bn.num_batches_tracked', state_dict) - self.assertFalse(any(k.startswith('empty') for k in state_dict.keys())) + self.assertFalse(any(k.startswith('empty') for k in state_dict)) for k, v in state_dict.items(): param = net for component in k.split('.'): @@ -4123,7 +4123,7 @@ def make_noncontig(tensor): def compare_cpu_gpu(outputs_cpu, outputs_gpu): self.assertEqual(list(outputs_cpu.keys()), list(outputs_gpu.keys())) - for key in outputs_cpu.keys(): + for key in outputs_cpu: if key != 'weights': self.assertEqual(outputs_cpu[key], outputs_gpu[key], atol=5e-5, rtol=0, msg=key) @@ -7281,7 +7281,7 @@ def test_convert_sync_batchnorm(self): self.assertEqual(children[1].__class__, torch.nn.InstanceNorm1d) for layer, converted_layer in zip(comp_module.children(), sync_bn_module.children()): - for key in layer.state_dict().keys(): + for key in layer.state_dict(): self.assertEqual(layer.state_dict()[key].device, converted_layer.state_dict()[key].device) self.assertEqual(layer.state_dict()[key], converted_layer.state_dict()[key]) diff --git a/test/test_serialization.py b/test/test_serialization.py index dcf67fe3ccf14..20f74b6dc6a21 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -948,7 +948,7 @@ def test_skip_data_load(self): with safe_globals([TwoTensor]), skip_data(): sd_loaded = torch.load(f) self.assertNotEqual(sd_loaded, sd) - for k in sd_loaded.keys(): + for k in sd_loaded: sd_loaded[k] = sd_loaded[k].zero_() self.assertEqual(sd_loaded, sd_zeroed) diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index f7efe9b929168..310962311b396 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -556,7 +556,7 @@ def test_fusion_outputs(self): torch._C._fuse_to_static_module(og.graph) assert "StaticSubgraph" in str(og.graph) o_test = og(a, b, b, c) - for i in o_ref.keys(): + for i in o_ref: torch.testing.assert_close(o_ref[i], o_test[i]) def test_create_object(self): diff --git a/test/test_weak.py b/test/test_weak.py index e46268852c983..28fa1436b5c23 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -84,12 +84,12 @@ def check_update(self, klass, dict): weakdict = klass() weakdict.update(dict) self.assertEqual(len(weakdict), len(dict)) - for k in weakdict.keys(): + for k in weakdict: self.assertIn(k, dict, "mysterious new key appeared in weak dict") v = dict.get(k) self.assertIs(v, weakdict[k]) self.assertIs(v, weakdict.get(k)) - for k in dict.keys(): + for k in dict: self.assertIn(k, weakdict, "original key disappeared in weak dict") v = dict[k] self.assertIs(v, weakdict[k]) @@ -328,7 +328,7 @@ def test_write(self): for key, value in self.reference.items(): p[key] = value self.assertEqual(p[key], value) - for key in self.reference.keys(): + for 
key in self.reference: del p[key] self.assertRaises(KeyError, lambda: p[key]) p = self._empty_mapping() @@ -662,7 +662,7 @@ def test_write(self): for key, value in self.reference.items(): p[key] = value self.assertEqual(p[key], value) - for key in self.reference.keys(): + for key in self.reference: del p[key] self.assertRaises(KeyError, lambda: p[key]) p = self._empty_mapping() diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index cace23af20565..20fe8771a7899 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -995,7 +995,7 @@ def remove_unused_sizes(self) -> set[int]: sizes_node = next(it) assert sizes_node.name == "sizes" - for getitem_node in sizes_node.users.keys(): + for getitem_node in sizes_node.users: assert getitem_node.target is operator.getitem if getitem_node.users: used_sizes.append(getitem_node) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 222647eeae9ab..9ff4ae46523c3 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -2138,7 +2138,7 @@ def fakify_with_ambient( # Error if we have any constraints on static values - for k in shape_env.var_to_range.keys(): + for k in shape_env.var_to_range: if isinstance(k, sympy.Integer): constraint_violation_error = ConstraintViolationError( f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n" diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 23b02e69a5640..6258131248e38 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -435,7 +435,7 @@ def _suggest_or_raise_constraint_violation( # Error if we have any constraints on static values - for k in shape_env.var_to_range.keys(): + for k in shape_env.var_to_range: if isinstance(k, sympy.Integer): constraint_violation_error = ConstraintViolationError( f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n" diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index f393b4a269d89..1c6661e53a777 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -3255,7 +3255,7 @@ def create_node( def remove_node(self, node: fx.Node) -> None: if len(node.users) > 0: user_graph_nodes: list[torch.fx.Node] = [] - for user in node.users.keys(): + for user in node.users: # For the case where user.graph == self.graph, that is a real bug and will raise # properly. 
if user.graph != self.graph: diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 83e3edf5d8d6d..f7903b198bcc4 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2809,7 +2809,7 @@ def create_resume( reads = livevars_analysis(self.instructions, resume_inst) all_argnames = tuple( k - for k in self.symbolic_locals.keys() + for k in self.symbolic_locals if k in reads and k not in self.cell_and_freevars() ) argnames_null_set = set(meta.locals_null_keys) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 4bff421c7d385..ca56d9785febe 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2843,7 +2843,7 @@ def key_is_id( def key_to_id(value: Any) -> list[Any]: - return [id(k) if key_is_id(k) else k for k in value.keys()] + return [id(k) if key_is_id(k) else k for k in value] def const_repr(x: Any, *, local: Any) -> str: @@ -3263,7 +3263,7 @@ def get_multiplier() -> float: log_error=log_error, use_larger_multiplier_for_smaller_tensor=use_larger_multiplier_for_smaller_tensor, ) - for key in ref.__dict__.keys() + for key in ref.__dict__ ) else: raise RuntimeError(f"unsupported type: {type(ref).__name__}") diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 0c74055973bf8..9733bc946c308 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -586,7 +586,7 @@ def wrap_mapping_proxy(self, value): # This might be suboptimal compared to dict guards. But mappingproxy is # not very common, so its ok to guard on all keys. self.install_guards(GuardBuilder.MAPPING_KEYS_CHECK) - all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) + all_const = all(ConstantVariable.is_literal(k) for k in value) if not all_const: unimplemented_v2( @@ -732,7 +732,7 @@ def from_tensor(): return self.tx.output.side_effects.track_object_existing(value, result) elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): self.install_guards(GuardBuilder.TYPE_MATCH) - all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) + all_const = all(ConstantVariable.is_literal(k) for k in value) # For all_const, we don't have to guard on anything yet. We guard on # keys lazily by adding a dict_getitem entry for each accessed key. 
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 0f198377605ec..2ac7bc7fe60b4 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2150,7 +2150,7 @@ def call_custom_dict_fromkeys( ) if isinstance(arg, dict): - arg_list = [ConstantVariable.create(k) for k in arg.keys()] + arg_list = [ConstantVariable.create(k) for k in arg] return DictVariableType( # pyrefly: ignore [bad-argument-type] dict.fromkeys(arg_list, value), diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 1793f5c10844e..86d3d87e1f8be 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -192,7 +192,7 @@ def call_method( except NotImplementedError: return super().call_method(tx, name, args, kwargs) - if isinstance(self.value, str) and name in str.__dict__.keys(): + if isinstance(self.value, str) and name in str.__dict__: method = getattr(self.value, name) try: return ConstantVariable.create(method(*const_args, **const_kwargs)) @@ -233,7 +233,7 @@ def call_method( elif isinstance(self.value, bytes) and name == "decode": method = getattr(self.value, name) return ConstantVariable.create(method(*const_args, **const_kwargs)) - elif type(self.value) is complex and name in complex.__dict__.keys(): + elif type(self.value) is complex and name in complex.__dict__: method = getattr(self.value, name) try: return ConstantVariable.create(method(*const_args, **const_kwargs)) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index fb212c3326222..1c3a7011d4cfc 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -802,7 +802,7 @@ def call_method( def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: self.install_dict_keys_match_guard() - return [x.vt for x in self.items.keys()] + return [x.vt for x in self.items] def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -1027,7 +1027,7 @@ def debug_repr(self) -> str: if not self.items: return "set()" else: - return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" + return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" @property def set_items(self) -> set["ConstDictVariable._HashableTracker"]: @@ -1307,7 +1307,7 @@ def debug_repr(self) -> str: if not self.items: return "frozenset()" else: - return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" + return "{" + ",".join(k.vt.debug_repr() for k in self.items) + "}" @property def set_items(self) -> set["ConstDictVariable._HashableTracker"]: @@ -1372,9 +1372,7 @@ def debug_repr(self) -> str: return "dict_keys([])" else: return ( - "dict_keys([" - + ",".join(k.vt.debug_repr() for k in self.items.keys()) - + "])" + "dict_keys([" + ",".join(k.vt.debug_repr() for k in self.items) + "])" ) def install_dict_keys_match_guard(self) -> None: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 0752a413fce6e..5fd903e7bbfdf 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -988,7 +988,7 @@ def __init__( self.generator_cls = generator_cls def __getattr__(self, name): - if name in self.__class__.__dict__.keys(): + if name in self.__class__.__dict__: return getattr(self, name) return getattr(self.vt, name) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 3f084cd00f59c..15f88f45bf7c5 100644 --- 
a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -2374,7 +2374,7 @@ def create_wrapped_node( # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, # all the arguments are lifted. - lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + lifted_args = tuple(arg for arg in body_lifted_freevars) proxy_args = (body_node,) + lifted_args example_value = pytree.tree_map_only( @@ -2660,7 +2660,7 @@ def _call_function( # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, # all the arguments are lifted. - lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + lifted_args = tuple(body_lifted_freevars.keys()) p_args = (body_node, lifted_args, {}) p_kwargs = {} @@ -2777,7 +2777,7 @@ def _call_function( strict_mode_node = make_attr(tx, strict_mode_name) p_args = ( strict_mode_node, - tuple(arg for arg in ret_lifted_freevars.keys()), + tuple(ret_lifted_freevars.keys()), ) flat_example_value = pytree.tree_map_only( @@ -3115,7 +3115,7 @@ def create_scalar(): # passed in as arguments. In this case, we need to lift them, which is handled by speculate_subgraph. # We then need to create proxies for this + the inputs. - lifted_args = tuple(arg for arg in body_lifted_freevars.keys()) + lifted_args = tuple(arg for arg in body_lifted_freevars) proxy_args = (body_node, lifted_args) @@ -3443,7 +3443,7 @@ def is_strict_for(v: VariableTracker): # However, the bwd_freevars got from speculate_subgraph use the Proxies in the bwd_graph, # we need to convert them to Proxies in the fwd_graph and then generate new fwd_graph output. fwd_proxy_of_bwd_freevars = [] - for k in bwd_freevars.keys(): + for k in bwd_freevars: if k in fwd_freevars: fwd_proxy_of_bwd_freevars.append(fwd_freevars[k]) else: diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index bdb37da3ccce1..ecad58920d7c2 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -63,7 +63,7 @@ def call_function( # See also: module `torch._dynamo.polyfills.itertools` if self.value is itertools.product: - if any(kw != "repeat" for kw in kwargs.keys()): + if any(kw != "repeat" for kw in kwargs): unimplemented_v2( gb_type="Unsupported kwargs for itertools.product", context=f"call_function {self} {args} {kwargs}", @@ -72,7 +72,7 @@ def call_function( hints=[*graph_break_hints.USER_ERROR], ) - if "repeat" in kwargs.keys(): + if "repeat" in kwargs: r = kwargs["repeat"].as_python_constant() else: r = 1 @@ -103,7 +103,7 @@ def call_function( mutation_type=ValueMutationNew(), ) elif self.value is itertools.groupby: - if any(kw != "key" for kw in kwargs.keys()): + if any(kw != "key" for kw in kwargs): unimplemented_v2( gb_type="Unsupported kwargs for itertools.groupby", context=f"call_function {self} {args} {kwargs}", diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 794fdf607220a..f6ba0b1a5ffbc 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -763,7 +763,7 @@ def gen_source(source, name): f"{len(args)} args and {len(kwargs)} kwargs", ) result = [] - for name in module.keys(): + for name in module: result.append(ConstantVariable.create(name)) return ListIteratorVariable(result, mutation_type=ValueMutationNew()) elif name == "values": From 5bda7afa05f62550f9c162afb4e55dac08733f5c Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 11:45:31 +0000 Subject: [PATCH 198/651] [9/N] Fix unused loop variables in 
tests (#167290) This PR fixes unused loop variables in tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167290 Approved by: https://github.com/mlazos --- test/distributed/fsdp/test_fsdp_misc.py | 1 + test/inductor/test_torchinductor.py | 13 +++++++------ test/test_binary_ufuncs.py | 21 ++++++++++----------- test/test_mps.py | 2 +- torch/_inductor/codegen/simd.py | 7 ++++--- torch/_inductor/comms.py | 4 ++-- torch/_inductor/distributed_autotune.py | 2 +- 7 files changed, 26 insertions(+), 24 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 99a1c3ad1707c..83a03489ada95 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -479,6 +479,7 @@ def test_fsdp_optimizer_overlap(self): for (n, p), (n_prev, p_prev) in zip( fsdp_overlap.named_parameters(), fsdp_overlap_prev_params ): + self.assertEqual(n, n_prev) self.assertNotEqual( p, p_prev, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fe9fa5a5e3a4c..801f983ae9080 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2938,17 +2938,18 @@ def fn2(x, y): self.common(fn2, (torch.randn(size1), torch.randn(size2))) def test_views2(self): - def fn1(x): - return (x.view(size2) + 1,) - - def fn2(x): - return ((x * 2).view(size2) + 1,) - for size1, size2 in [ ([2, 2, 2, 2], [4, -1]), ([10, 1, 10, 1, 10], [-1, 100]), ([10 * 5, 20], [10, -1, 20]), ]: + + def fn1(x): + return (x.view(size2) + 1,) + + def fn2(x): + return ((x * 2).view(size2) + 1,) + self.common(fn1, (torch.randn(size1),)) self.common(fn2, (torch.randn(size1),)) diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 406242964d1c9..56a4202cded3f 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1876,20 +1876,19 @@ def _scalar_helper(python_op, torch_op): expected = python_op(a, b) - for op in (operator.truediv, torch.true_divide): - actual_scalar = torch_op(a, b) + actual_scalar = torch_op(a, b) - a_t = torch.tensor(a, device=device) - b_t = torch.tensor(b, device=device) + a_t = torch.tensor(a, device=device) + b_t = torch.tensor(b, device=device) - actual_tensor = torch_op(a_t, b_t) - actual_first_tensor = torch_op(a_t, b) - actual_second_tensor = torch_op(a, b_t) + actual_tensor = torch_op(a_t, b_t) + actual_first_tensor = torch_op(a_t, b) + actual_second_tensor = torch_op(a, b_t) - self.assertEqual(actual_scalar, expected) - self.assertEqual(actual_tensor.item(), expected) - self.assertEqual(actual_first_tensor, actual_tensor) - self.assertEqual(actual_second_tensor, actual_tensor) + self.assertEqual(actual_scalar, expected) + self.assertEqual(actual_tensor.item(), expected) + self.assertEqual(actual_first_tensor, actual_tensor) + self.assertEqual(actual_second_tensor, actual_tensor) _scalar_helper(operator.truediv, operator.truediv) _scalar_helper(operator.truediv, torch.true_divide) diff --git a/test/test_mps.py b/test/test_mps.py index 867429432cfe0..765ec3c52e036 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -4910,7 +4910,7 @@ def helper(shape): input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool()) input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool()) - for i, cpu_x in enumerate(input_xs): + for cpu_x in input_xs: x = cpu_x.detach().clone().to('mps') y = torch.any(x) ref_y = torch.any(cpu_x) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 
f3b5de1f0ab46..2ad02ca97a54b 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -2304,13 +2304,12 @@ def generate_combo_kernel_code( for node_group in partitions: if len(node_group) == 0: continue - fused_node_lists = [node.get_nodes() for node in node_group] kernel = ComboKernel( enable_autotune=enable_autotune, mixed_sizes=mixed_sizes, ) - for pn, nodes in zip(node_group, fused_node_lists): + for pn in node_group: self.codegen_node_schedule_with_kernel( node_schedule_map[pn][0], kernel.create_sub_kernel(subkernel_map[pn]), @@ -2565,8 +2564,10 @@ def get_nd_tilings( all_var_ranges = [*dep.ranges.items()] pointwise_vars_numel = sympy.S.One sizevars = V.graph.sizevars - for pointwise_end_idx, (var, numel) in enumerate(all_var_ranges): + pointwise_end_idx = 0 + for idx, (_var, numel) in enumerate(all_var_ranges): pointwise_vars_numel *= numel + pointwise_end_idx = idx if sizevars.statically_known_geq( pointwise_vars_numel, pointwise_numel ): diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index 29efcb4a44493..ba2571f266244 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -305,7 +305,7 @@ def coll_exposed_communication_time( continue if contains_wait(snode): has_wait_for_collectives_found = False - for coll in collectives_found: + for _coll in collectives_found: if is_corresponding_collective_wait(collective_snode, snode): has_wait_for_collectives_found = True break @@ -1891,7 +1891,7 @@ def _sink_waits_iterative_internal( ) # pyrefly: ignore[no-matching-overload] -max(0, info.comm_time - info.comp_time - c_runtime) - for gc, (gc_comm_time, gc_comp_time) in group_colls.items(): + for gc_comm_time, gc_comp_time in group_colls.values(): exposed_delta += max(0, gc_comm_time - gc_comp_time) - max( 0, gc_comm_time - gc_comp_time + c_runtime ) diff --git a/torch/_inductor/distributed_autotune.py b/torch/_inductor/distributed_autotune.py index af2d5bb9e9f11..ec53d25efcd5b 100644 --- a/torch/_inductor/distributed_autotune.py +++ b/torch/_inductor/distributed_autotune.py @@ -204,7 +204,7 @@ def _sync(autotune_results: list[_SerializedChoice]) -> Sequence[_SerializedChoi choices_by_index: list[_SerializedChoice] = [None] * node_count # type: ignore[list-item] check_count = 0 - for i, other_results in enumerate(all_states): + for other_results in all_states: for choice in other_results: assert isinstance(choice, _SerializedChoice) assert choices_by_index[choice.index] is None From aded2ebb90fe8ed6835fec65bf3c122548abb66f Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 13:50:33 +0000 Subject: [PATCH 199/651] [3/N] Add return types of Python functions (#167287) This PR adds return types to some Python functions. 
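For reference, a minimal self-contained sketch of the annotation patterns this change applies (the names below are illustrative only and are not taken from the diff): procedures that only mutate state get `-> None`, predicates get `-> bool`, and helpers that unconditionally raise get `-> NoReturn`.

```python
from typing import NoReturn

_registry: dict[str, int] = {}

def register(name: str, value: int) -> None:
    # Mutates module-level state; has no meaningful return value.
    _registry[name] = value

def is_registered(name: str) -> bool:
    # Predicate: returns a truth value.
    return name in _registry

def unsupported(op: str) -> NoReturn:
    # Always raises, so the annotated return type is NoReturn.
    raise NotImplementedError(f"{op} is not supported")
```

Annotating these explicitly lets type checkers flag call sites that try to use the result of a function that returns nothing.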
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167287 Approved by: https://github.com/mlazos --- torch/jit/__init__.py | 2 +- torch/jit/_builtins.py | 4 +- torch/jit/_check.py | 6 +- torch/jit/_decomposition_utils.py | 2 +- torch/jit/_decompositions.py | 2 +- torch/jit/_freeze.py | 2 +- torch/jit/_fuser.py | 2 +- torch/jit/_ir_utils.py | 2 +- torch/jit/_monkeytype_config.py | 6 +- torch/jit/_recursive.py | 18 ++--- torch/jit/_serialization.py | 4 +- torch/jit/_state.py | 14 ++-- torch/jit/annotations.py | 10 +-- torch/jit/frontend.py | 12 ++-- torch/jit/mobile/__init__.py | 2 +- torch/jit/supported_ops.py | 2 +- torch/jit/unsupported_tensor_ops.py | 2 +- torch/nested/_internal/nested_int.py | 2 +- torch/nested/_internal/nested_tensor.py | 4 +- torch/nested/_internal/ops.py | 10 +-- torch/nested/_internal/sdpa.py | 4 +- torch/onnx/_internal/exporter/_building.py | 2 +- .../_internal/exporter/_capture_strategies.py | 2 +- torch/onnx/_internal/exporter/_core.py | 2 +- .../onnx/_internal/exporter/_onnx_program.py | 4 +- torch/onnx/_internal/exporter/_reporting.py | 4 +- torch/onnx/_internal/exporter/_schemas.py | 2 +- torch/onnx/_internal/exporter/_tensors.py | 2 +- torch/onnx/_internal/fx/_pass.py | 4 +- .../_internal/fx/passes/type_promotion.py | 24 +++---- .../torchscript_exporter/symbolic_opset9.py | 14 ++-- .../_internal/torchscript_exporter/utils.py | 34 +++++----- .../torchscript_exporter/verification.py | 10 +-- torch/onnx/errors.py | 4 +- torch/utils/_config_module.py | 4 +- torch/utils/_sympy/reference.py | 14 ++-- torch/utils/benchmark/utils/timer.py | 2 +- torch/utils/checkpoint.py | 35 +++++----- torch/utils/data/_utils/fetch.py | 8 ++- torch/utils/data/dataloader.py | 30 ++++---- .../data/datapipes/dataframe/dataframes.py | 68 +++++++++---------- torch/utils/data/datapipes/iter/grouping.py | 10 +-- torch/utils/data/datapipes/iter/sharding.py | 11 +-- torch/utils/data/datapipes/utils/common.py | 24 ++++--- 44 files changed, 218 insertions(+), 208 deletions(-) diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 9decaeecc86d0..c277d2e3ab71a 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -279,7 +279,7 @@ def _hide_source_ranges() -> Iterator[None]: torch._C.Graph.set_global_print_source_ranges(old_enable_source_ranges) # type: ignore[attr-defined] -def enable_onednn_fusion(enabled: bool): +def enable_onednn_fusion(enabled: bool) -> None: """Enable or disables onednn JIT fusion based on the parameter `enabled`.""" torch._C._jit_set_llga_enabled(enabled) diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index 2aa2fae3fde51..f61fadf375bf1 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -162,7 +162,7 @@ def _get_builtin_table(): return _builtin_table _builtin_table = {} - def register_all(mod): + def register_all(mod) -> None: for name in dir(mod): v = getattr(mod, name) if ( @@ -196,7 +196,7 @@ def register_all(mod): return _builtin_table -def _register_builtin(fn, op): +def _register_builtin(fn, op) -> None: _get_builtin_table()[id(fn)] = op diff --git a/torch/jit/_check.py b/torch/jit/_check.py index 261a2ce554b5f..36440769f063f 100644 --- a/torch/jit/_check.py +++ b/torch/jit/_check.py @@ -116,7 +116,7 @@ def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool: return True - def visit_Assign(self, node): + def visit_Assign(self, node) -> None: """Store assignment state when assigning to a Call Node. 
If we're visiting a Call Node (the right-hand side of an @@ -139,7 +139,7 @@ def visit_Assign(self, node): self.generic_visit(node) self.visiting_class_level_ann = False - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node) -> None: """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method. It checks if it conforms to our attribute annotation rules.""" @@ -194,7 +194,7 @@ def visit_AnnAssign(self, node): stacklevel=2, ) - def visit_Call(self, node): + def visit_Call(self, node) -> None: """Determine if a Call node is 'torch.jit.annotate' in __init__. Visit a Call node in an ``nn.Module``'s ``__init__`` diff --git a/torch/jit/_decomposition_utils.py b/torch/jit/_decomposition_utils.py index 3a4b4ceff2cf3..48f24f0d85d6c 100644 --- a/torch/jit/_decomposition_utils.py +++ b/torch/jit/_decomposition_utils.py @@ -3,7 +3,7 @@ from torch._ops import OpOverload, OpOverloadPacket -def _register_decomposition(op: OpOverload, graph: torch._C.Graph): +def _register_decomposition(op: OpOverload, graph: torch._C.Graph) -> None: assert not isinstance(op, OpOverloadPacket), ( f"Must pass specific op overload, not overload packet, found {op}" ) diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index c855606865adb..bb628f82a8ef0 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -20,7 +20,7 @@ _P = ParamSpec("_P") -def check_decomposition_has_type_annotations(f): +def check_decomposition_has_type_annotations(f) -> None: inspect_empty = inspect._empty # type: ignore[attr-defined] sig = inspect.signature(f) for param in sig.parameters.values(): diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py index b61a2dd6207d1..7fd8ac1a772bb 100644 --- a/torch/jit/_freeze.py +++ b/torch/jit/_freeze.py @@ -125,7 +125,7 @@ def forward(self, input): def run_frozen_optimizations( mod, optimize_numerics: bool = True, preserved_methods: Optional[list[str]] = None -): +) -> None: r""" Run a series of optimizations looking for patterns that occur in frozen graphs. 
diff --git a/torch/jit/_fuser.py b/torch/jit/_fuser.py index dc5dd80362971..a6a2f1cce67e8 100644 --- a/torch/jit/_fuser.py +++ b/torch/jit/_fuser.py @@ -83,7 +83,7 @@ def fuser(name): last_executed_optimized_graph = torch._C._last_executed_optimized_graph -def _get_differentiable_graph_node(node, diff_node): +def _get_differentiable_graph_node(node, diff_node) -> None: if node.kind() == "prim::DifferentiableGraph": diff_node.append(node) else: diff --git a/torch/jit/_ir_utils.py b/torch/jit/_ir_utils.py index d7f03ee3bc868..7a775717de07f 100644 --- a/torch/jit/_ir_utils.py +++ b/torch/jit/_ir_utils.py @@ -9,7 +9,7 @@ def __init__( self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block], - ): + ) -> None: self.insert_point = insert_point self.g = insert_point_graph self.guard = None diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index a15d140dc7944..0f348590ea397 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -85,7 +85,7 @@ def get_qualified_name(func): class JitTypeTraceStoreLogger(CallTraceStoreLogger): """A JitTypeCallTraceLogger that stores logged traces in a CallTraceStore.""" - def __init__(self, store: CallTraceStore): + def __init__(self, store: CallTraceStore) -> None: super().__init__(store) def log(self, trace: CallTrace) -> None: @@ -100,7 +100,7 @@ def __init__(self) -> None: # value is list of all CallTrace self.trace_records: dict[str, list] = defaultdict(list) - def add(self, traces: Iterable[CallTrace]): + def add(self, traces: Iterable[CallTrace]) -> None: for t in traces: qualified_name = get_qualified_name(t.func) self.trace_records[qualified_name].append(t) @@ -145,7 +145,7 @@ def get_args_types(self, qualified_name: str) -> dict: return self.consolidate_types(qualified_name) class JitTypeTraceConfig(monkeytype.config.Config): - def __init__(self, s: JitTypeTraceStore): + def __init__(self, s: JitTypeTraceStore) -> None: super().__init__() self.s = s diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 343871b1f94a2..75355cbd4b8e0 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -152,7 +152,7 @@ def _get_valid_constant(attr, v, owner_type): class SourceContext(torch._C._jit_tree_views.SourceRangeFactory): - def __init__(self, source, filename, file_lineno, leading_whitespace_len): + def __init__(self, source, filename, file_lineno, leading_whitespace_len) -> None: super().__init__(source, filename, file_lineno, leading_whitespace_len) @@ -454,7 +454,7 @@ def get_or_create_concrete_type(self, nn_module): def create_methods_and_properties_from_stubs( concrete_type, method_stubs, property_stubs -): +) -> None: method_defs = [m.def_ for m in method_stubs] method_rcbs = [m.resolution_callback for m in method_stubs] method_defaults = [get_default_args(m.original_method) for m in method_stubs] @@ -467,7 +467,7 @@ def create_methods_and_properties_from_stubs( ) -def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs): +def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs) -> None: hook_defs = [h.def_ for h in hook_stubs] hook_rcbs = [h.resolution_callback for h in hook_stubs] @@ -571,7 +571,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module) ignored_properties = jit_ignored_properties(nn_module) - def init_fn(script_module): + def init_fn(script_module) -> None: # Initialize the ScriptModule: # 1. 
Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. for name in concrete_type.get_attributes(): @@ -725,7 +725,7 @@ def script_model_defines_attr(script_model, attr): return script_attr != default_attr -def add_python_attr_to_scripted_model(script_model, orig, attr): +def add_python_attr_to_scripted_model(script_model, orig, attr) -> None: if hasattr(orig, attr) and script_model_defines_attr(script_model, attr): setattr(script_model, attr, getattr(orig, attr)) @@ -777,7 +777,7 @@ def get_overload_name_mapping(overload_info): return overload_name_mappings -def _check_no_signature(func): +def _check_no_signature(func) -> None: signature = torch.jit.annotations.get_signature( func, None, fake_range(), inspect.ismethod(func) ) @@ -807,7 +807,7 @@ def make_stubs_for_overloads(overload_info): return overload_stubs -def check_module_initialized(mod): +def check_module_initialized(mod) -> None: assert isinstance(mod, torch.nn.Module) if not hasattr(mod, "_parameters"): raise RuntimeError( @@ -1002,7 +1002,7 @@ def wrap_cpp_class(cpp_class): def wrap_cpp_module(cpp_module): """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules.""" - def init_fn(script_module): + def init_fn(script_module) -> None: for name, cpp_module in torch._C.ModuleDict(script_module._c).items(): setattr(script_module, name, wrap_cpp_module(cpp_module)) script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type( @@ -1037,7 +1037,7 @@ def lazy_bind(concrete_type, unbound_method): """ def lazy_binding_method(cpp_module, *args): - def init_fn(script_module): + def init_fn(script_module) -> None: orig_class = concrete_type.py_class # Copy @ignored/@unused methods from the original module to the new one. diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index 02004a1122013..9cffe107c25be 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -18,7 +18,7 @@ from torch.serialization import validate_cuda_device -def save(m, f, _extra_files=None): +def save(m, f, _extra_files=None) -> None: r""" Save an offline version of this module for use in a separate process. @@ -213,7 +213,7 @@ def jit_module_from_flatbuffer(f): return wrap_cpp_module(torch._C._load_jit_module_from_bytes(f.read())) -def save_jit_module_to_flatbuffer(m, f, _extra_files=None): +def save_jit_module_to_flatbuffer(m, f, _extra_files=None) -> None: r""" Save an offline version of this module for use in a separate process. diff --git a/torch/jit/_state.py b/torch/jit/_state.py index f48dd80a0b36f..2ebbee553ab11 100644 --- a/torch/jit/_state.py +++ b/torch/jit/_state.py @@ -41,18 +41,18 @@ def parse_env(self, name, default, true_message, false_message): return False raise ValueError(f"Unknown setting of {name}. 
Try using 0 or 1.") - def __bool__(self): + def __bool__(self) -> bool: return self.enabled _enabled = EnabledProxy() -def disable(): +def disable() -> None: _enabled.enabled = False -def enable(): +def enable() -> None: _enabled.enabled = True @@ -67,7 +67,7 @@ def enable(): _name_to_pyclass: dict[str, type[Any]] = {} -def _add_script_class(python_class, script_class): +def _add_script_class(python_class, script_class) -> None: _script_classes[python_class] = script_class _name_to_pyclass[script_class.qualified_name()] = python_class @@ -83,7 +83,7 @@ def _get_python_class(qualified_name): return _name_to_pyclass.get(qualified_name) -def _clear_class_state(): +def _clear_class_state() -> None: _script_classes.clear() _name_to_pyclass.clear() @@ -108,7 +108,7 @@ def _try_get_jit_cached_overloads(key): return None -def _set_jit_overload_cache(key, compiled_fns): +def _set_jit_overload_cache(key, compiled_fns) -> None: _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns] @@ -122,7 +122,7 @@ def _try_get_jit_cached_function(key): return None -def _set_jit_function_cache(key, value): +def _set_jit_function_cache(key, value) -> None: # only free functions currently supported assert isinstance(value, torch.jit.ScriptFunction) _jit_caching_layer[key] = value.qualified_name diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index f1ede0bd2450d..cf1be7bac8f9d 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -68,7 +68,7 @@ class Module: - def __init__(self, name, members): + def __init__(self, name, members) -> None: self.name = name self.members = members @@ -95,7 +95,7 @@ class EvalEnv: "Await": _Await, } - def __init__(self, rcb): + def __init__(self, rcb) -> None: self.rcb = rcb if torch.distributed.rpc.is_available(): # pyrefly: ignore [unsupported-operation] @@ -178,7 +178,7 @@ def get_param_names(fn, n_args): return [str(i) for i in range(n_args)] -def check_fn(fn, loc): +def check_fn(fn, loc) -> None: # Make sure the function definition is not a class instantiation try: source = dedent("".join(get_source_lines_and_file(fn)[0])) @@ -368,7 +368,7 @@ def get_enum_value_type(e: type[enum.Enum], loc): return res -def is_tensor(ann): +def is_tensor(ann) -> bool: if issubclass(ann, torch.Tensor): return True @@ -397,7 +397,7 @@ def is_tensor(ann): return False -def _fake_rcb(inp): +def _fake_rcb(inp) -> None: return None diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 5b1db800a7838..9f686a5a626f3 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -147,7 +147,7 @@ def is_reserved_name(name): class FrontendError(Exception): - def __init__(self, source_range, msg): + def __init__(self, source_range, msg) -> None: self.source_range = source_range self.msg = msg @@ -155,7 +155,7 @@ def __init__(self, source_range, msg): # call stack when the FrontendError was raised self.error_report = torch._C.ErrorReport(self.source_range) - def __str__(self): + def __str__(self) -> str: return self.msg + self.error_report.what().lstrip() @@ -164,7 +164,7 @@ class NotSupportedError(FrontendError): class UnsupportedNodeError(NotSupportedError): - def __init__(self, ctx, offending_node, reason=""): + def __init__(self, ctx, offending_node, reason="") -> None: # If we don't have a specific token, we default to length of 1 node_type = type(offending_node) range_len = len(node_start_tokens.get(node_type, " ")) @@ -229,7 +229,7 @@ def get_class_properties(cls, self_name): def get_class_assigns(ctx, cls_ast): assigns = [] - def 
maybe_build_assign(builder, entry): + def maybe_build_assign(builder, entry) -> None: nonlocal assigns try: assigns.append(builder(ctx, entry)) @@ -385,7 +385,7 @@ def _forward(self): # TODO: more robust handling of recognizing ignore context manager -def is_torch_jit_ignore_context_manager(stmt): +def is_torch_jit_ignore_context_manager(stmt) -> bool: # checks if the statement is torch.jit.ignore context manager if isinstance(stmt.items[0].context_expr, ast.Call): # extract torch part @@ -535,7 +535,7 @@ def process_ins_outs(args): outputs.append(OutputType(var_name, var_ann)) return inputs, outputs - def create_unique_name_ext(ctx, stmt): + def create_unique_name_ext(ctx, stmt) -> str: # extension will be based on the full path filename plus # the line number of original context manager fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename) diff --git a/torch/jit/mobile/__init__.py b/torch/jit/mobile/__init__.py index 32c2f5b321ee3..608d1c2f7798d 100644 --- a/torch/jit/mobile/__init__.py +++ b/torch/jit/mobile/__init__.py @@ -56,7 +56,7 @@ def _load_for_lite_interpreter(f, map_location=None): class LiteScriptModule: - def __init__(self, cpp_module): + def __init__(self, cpp_module) -> None: self._c = cpp_module super().__init__() diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 6cbca07966da5..8a258280ea352 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -57,7 +57,7 @@ def _emit_schema(mod, name, schema, arg_start=0, padding=4): def _get_tensor_ops(): - def is_tensor_method(schema): + def is_tensor_method(schema) -> bool: if len(schema.arguments) == 0: return False self = schema.arguments[0] diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index 162e4c5320685..12bca0fce337d 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -5,7 +5,7 @@ import torch.jit -def execWrapper(code, glob, loc): +def execWrapper(code, glob, loc) -> None: exec(code, glob, loc) diff --git a/torch/nested/_internal/nested_int.py b/torch/nested/_internal/nested_int.py index 59090b331d501..b347258b5f463 100644 --- a/torch/nested/_internal/nested_int.py +++ b/torch/nested/_internal/nested_int.py @@ -35,7 +35,7 @@ def _ge(lhs: Any, rhs: Any) -> bool: class NestedIntNode: - def __init__(self, t_id: int, coeff: int): + def __init__(self, t_id: int, coeff: int) -> None: self.t_id = t_id self.coeff = coeff diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 8d446a7bd518d..cf4e3fecf4e6c 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -131,7 +131,7 @@ def __new__( return r - def __init__(self, values, offsets, *, lengths=None, **kwargs): + def __init__(self, values, offsets, *, lengths=None, **kwargs) -> None: super().__init__() self._values = values @@ -243,7 +243,7 @@ def _is_contiguous_or_false(self): self._values, memory_format=torch.contiguous_format ) - def __repr__(self): # type: ignore[override] + def __repr__(self) -> str: # type: ignore[override] # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( f", requires_grad={self.requires_grad}" if self.requires_grad else "" diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 69c324ab726ec..e2126ca5632f9 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -400,7 +400,7 @@ def jagged_torch_function(func, *args, **kwargs): # Handle flatten() here because it's 
CompositeImplicit. if func.__name__ == "flatten": - def _flatten_sig(input, start_dim=0, end_dim=-1): + def _flatten_sig(input, start_dim=0, end_dim=-1) -> None: pass _, new_kwargs = normalize_function( # type: ignore[misc] @@ -466,7 +466,7 @@ def _flatten_sig(input, start_dim=0, end_dim=-1): # Handle nested-specific input validation for CompositeImplicit rms_norm if func.__name__ == "rms_norm": - def _rms_norm_sig(input, normalized_shape, weight=None, eps=None): + def _rms_norm_sig(input, normalized_shape, weight=None, eps=None) -> None: pass _, new_kwargs = normalize_function( # type: ignore[misc] @@ -532,7 +532,7 @@ def prim_layout_default(func, *args, **kwargs): [torch.ops.aten.size.default], "self: jt_all", ) -def tensor_attr_unsupported_getter(func, *args, **kwargs): +def tensor_attr_unsupported_getter(func, *args, **kwargs) -> None: if func is torch.ops.aten.size.default: raise RuntimeError( "NestedTensor does not support directly calling torch.ops.aten.size; " @@ -1138,7 +1138,7 @@ def unbind_int(func, *args, **kwargs): lengths = inp.lengths() ragged_idx = inp._ragged_idx - def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None): + def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None) -> None: # This torch._check are needed for torch.compile # symbolic shapes processing. # offsets and lengths are symbolic variables during compilation, @@ -2615,7 +2615,7 @@ def _nested_select_backward_default(func, *args, **kwargs): @register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any") -def record_stream_default(func, *args, **kwargs): +def record_stream_default(func, *args, **kwargs) -> None: inp = args[0] stream = args[1] # ensure all components live until stream computation completes diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index fe385dc5c766f..631a5991879aa 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -31,7 +31,7 @@ def _validate_sdpa_input( dropout_p=0.0, is_causal=False, scale=None, -): +) -> None: if ( not isinstance(query, NestedTensor) or not isinstance(key, NestedTensor) @@ -364,7 +364,7 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> tuple[torch.Tensor, in return cumulative_seqlen, max_seqlen, n_elem -def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor): +def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor) -> bool: # This function checks if a nested tensor is valid for # use with the flash-attention and efficient_attention kernels without # needing to call contiguous on the nested tensor input. diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 2dbcf8f083877..4536e33087eb8 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -537,7 +537,7 @@ class OpRecorder(evaluator.Evaluator): def __init__( self, opset: onnxscript.values.Opset, constant_farm: dict[Any, ir.Value] - ): + ) -> None: self.nodes: list[ir.Node] = [] self.opset = opset self.functions: dict[ diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 63421ff5bb947..8d1f04a8a80a7 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -92,7 +92,7 @@ def __init__( dump: bool = False, artifacts_dir: str | os.PathLike = ".", timestamp: str | None = None, - ): + ) -> None: """Initialize the strategy. 
Args: diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index b618943c3f21b..f1f1ac6c67e40 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -109,7 +109,7 @@ def torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType: class TorchTensor(ir.Tensor): - def __init__(self, tensor: torch.Tensor, name: str | None = None): + def __init__(self, tensor: torch.Tensor, name: str | None = None) -> None: # Pass the tensor as the raw data to ir.Tensor's constructor if tensor.dtype == torch.float4_e2m1fn_x2: # Change the shape to the unpacked shape diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 942638598047f..b2d4101fdc9a1 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -211,7 +211,7 @@ class ONNXProgram: def __init__( self, model: ir.Model, exported_program: torch.export.ExportedProgram | None - ): + ) -> None: """Initialize the ONNX program with the specified model and exported program. Args: model: The ONNX model. @@ -327,7 +327,7 @@ def save( include_initializers: bool = True, keep_initializers_as_inputs: bool = False, external_data: bool | None = None, - ): + ) -> None: """Save the ONNX model to the specified destination. When ``external_data`` is ``True`` or the model is larger than 2GB, diff --git a/torch/onnx/_internal/exporter/_reporting.py b/torch/onnx/_internal/exporter/_reporting.py index e2e02e089c5d1..dc9cabeb677c4 100644 --- a/torch/onnx/_internal/exporter/_reporting.py +++ b/torch/onnx/_internal/exporter/_reporting.py @@ -149,7 +149,7 @@ def create_torch_export_error_report( *, export_status: ExportStatus, profile_result: str | None, -): +) -> None: with open(filename, "w", encoding="utf-8") as f: f.write("# PyTorch ONNX Conversion Error Report\n\n") f.write(_format_export_status(export_status)) @@ -175,7 +175,7 @@ def create_onnx_export_report( model: ir.Model | None = None, registry: _registration.ONNXRegistry | None = None, verification_result: str | None = None, -): +) -> None: with open(filename, "w", encoding="utf-8") as f: f.write("# PyTorch ONNX Conversion Report\n\n") f.write(_format_export_status(export_status)) diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py index 0ed3791c46fc7..89991b030509b 100644 --- a/torch/onnx/_internal/exporter/_schemas.py +++ b/torch/onnx/_internal/exporter/_schemas.py @@ -21,7 +21,7 @@ # A special value to indicate that the default value is not specified class _Empty: - def __repr__(self): + def __repr__(self) -> str: return "_EMPTY_DEFAULT" diff --git a/torch/onnx/_internal/exporter/_tensors.py b/torch/onnx/_internal/exporter/_tensors.py index 2a6c74120d568..8f0706bf98638 100644 --- a/torch/onnx/_internal/exporter/_tensors.py +++ b/torch/onnx/_internal/exporter/_tensors.py @@ -18,7 +18,7 @@ def __init__( type: ir.TypeProtocol | None = None, doc_string: str | None = None, const_value: ir.TensorProtocol | None = None, - ): + ) -> None: super().__init__( name=name, shape=shape, diff --git a/torch/onnx/_internal/fx/_pass.py b/torch/onnx/_internal/fx/_pass.py index b1fad573f2902..95b7892fec4df 100644 --- a/torch/onnx/_internal/fx/_pass.py +++ b/torch/onnx/_internal/fx/_pass.py @@ -66,7 +66,7 @@ def _patch_difflib_sequence_matcher_init(): """ original_init = difflib.SequenceMatcher.__init__ - def patched_init(self, isjunk=None, a="", b="", autojunk=True): + def 
patched_init(self, isjunk=None, a="", b="", autojunk=True) -> None: original_init(self, isjunk, a, b, autojunk=False) difflib.SequenceMatcher.__init__ = patched_init # type: ignore[assignment] @@ -192,7 +192,7 @@ class Transform(abc.ABC): def __init__( self, module: torch.fx.GraphModule, - ): + ) -> None: """Initialize the transform. Args: diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 0dea1aa15317e..3d4e919a3b2fb 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -63,7 +63,7 @@ class TypePromotionSnapshot: class TypePromotionRule(abc.ABC): """Base class for type promotion rule per 'torch.ops.{namespace}.{op_name}'.""" - def __init__(self, namespace: str, op_name: str): + def __init__(self, namespace: str, op_name: str) -> None: self.namespace = namespace self.op_name = op_name @@ -74,7 +74,7 @@ def __init__(self, namespace: str, op_name: str): def __hash__(self) -> int: ... @abc.abstractmethod - def __repr__(self): ... + def __repr__(self) -> str: ... @abc.abstractmethod def __eq__(self, other: object) -> bool: ... @@ -128,7 +128,7 @@ def __init__( promote_args_positions: Sequence[int], promote_kwargs_names: Sequence[str], promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND, - ): + ) -> None: """Constructs a TypePromotionRule for elementwise operators. Args: @@ -143,7 +143,7 @@ def __init__( self.promote_kwargs_names = promote_kwargs_names self.promotion_kind = promotion_kind - def __repr__(self): + def __repr__(self) -> str: return ( f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', " f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})" @@ -216,7 +216,7 @@ class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule): Rule depends on the value of the `rounding_mode` argument. """ - def __init__(self): + def __init__(self) -> None: super().__init__( "aten", "div", @@ -252,7 +252,7 @@ def __init__( namespace: str, op_name: str, promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND, - ): + ) -> None: """Constructs a TypePromotionRule for reduction operators. Args: @@ -263,7 +263,7 @@ def __init__( super().__init__(namespace, op_name) self.promotion_kind = promotion_kind - def __repr__(self): + def __repr__(self) -> str: return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})" # pyrefly: ignore [bad-override] @@ -311,7 +311,7 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule): The result dtype is always uint8 if `dtype` kwarg is uint8, otherwise torch.bool. """ - def __init__(self, op_name: str): + def __init__(self, op_name: str) -> None: super().__init__( "aten", op_name, @@ -1205,7 +1205,7 @@ def _parse_type_promotion_rule_from_refs_op( class TypePromotionTable: """Type promotion table for torch.ops.""" - def __init__(self): + def __init__(self) -> None: self._rule_table = {} for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET: self.add_rule(rule) @@ -1262,7 +1262,7 @@ class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode): op overload for a given op overload packet for different set of args and kwargs. 
""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.traced_ops = [] @@ -1331,7 +1331,7 @@ def __init__( self, module: torch.fx.GraphModule, type_promotion_table: TypePromotionTable, - ): + ) -> None: super().__init__(module) self.type_promotion_table = type_promotion_table @@ -1603,7 +1603,7 @@ def __init__( self, module: torch.fx.GraphModule, type_promotion_table: TypePromotionTable | None = None, - ): + ) -> None: super().__init__(module) self.interpreter = _TypePromotionInterpreter( module, type_promotion_table or TypePromotionTable() diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py index e1b34469fbf20..7ae1c5a082e1c 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py @@ -810,7 +810,7 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = @_onnx_symbolic("aten::cumsum") @symbolic_helper.parse_args("v", "i", "none") -def cumsum(g: jit_utils.GraphContext, input, dim, dtype): +def cumsum(g: jit_utils.GraphContext, input, dim, dtype) -> None: symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input) @@ -3332,7 +3332,9 @@ def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse): @_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") -def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts): +def _unique2( + g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts +) -> None: symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input) @@ -6289,7 +6291,7 @@ def broadcast_tensors(g: jit_utils.GraphContext, self): @_onnx_symbolic("aten::is_pinned") -def is_pinned(g: jit_utils.GraphContext, self, device=None): +def is_pinned(g: jit_utils.GraphContext, self, device=None) -> None: # Unused by ONNX. return None @@ -6357,7 +6359,7 @@ def prim_layout(g: jit_utils.GraphContext, self): @_onnx_symbolic("prim::ListConstruct") -def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs): +def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: return None @@ -6374,12 +6376,12 @@ def prim_list_unpack( @_onnx_symbolic("prim::TupleConstruct") -def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs): +def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: return None @_onnx_symbolic("prim::Uninitialized") -def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs): +def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs) -> None: return None diff --git a/torch/onnx/_internal/torchscript_exporter/utils.py b/torch/onnx/_internal/torchscript_exporter/utils.py index d66962f690ea1..050b60c292684 100644 --- a/torch/onnx/_internal/torchscript_exporter/utils.py +++ b/torch/onnx/_internal/torchscript_exporter/utils.py @@ -571,7 +571,7 @@ def forward(self, x): return None -def _is_constant_tensor_list(node): +def _is_constant_tensor_list(node) -> bool | None: if node.kind() != "prim::Constant": return False output_type = node.output().type() @@ -585,7 +585,7 @@ def _is_constant_tensor_list(node): # get generated in constant prop. 
So we split them back into prim::ListConstructs -def _split_tensor_list_constants(g, block): +def _split_tensor_list_constants(g, block) -> None: for node in block.nodes(): for subblock in node.blocks(): _split_tensor_list_constants(g, subblock) @@ -722,7 +722,7 @@ def _optimize_graph( return graph -def warn_on_static_input_change(input_states): +def warn_on_static_input_change(input_states) -> None: """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph. We accept dictionaries and strings as ONNX inputs, but they should be only for @@ -932,7 +932,7 @@ def _get_param_count_list(method_graph, args_params): return param_count_list -def _check_flatten_did_not_remove(original, jit_flattened): +def _check_flatten_did_not_remove(original, jit_flattened) -> None: """torch.jit._flatten removes None. Check if it did so in this case.""" def flatten(x): @@ -1286,13 +1286,13 @@ def _setup_trace_module_map( model: torch.nn.Module | torch.jit.ScriptModule, export_modules_as_functions: bool | Collection[type[torch.nn.Module]], ) -> set[str]: - def __register_attribute_hook(): + def __register_attribute_hook() -> None: attr_name = "_onnx_attrs" - def _track_module_attributes_forward_pre_hook(module, input): + def _track_module_attributes_forward_pre_hook(module, input) -> None: setattr(module, attr_name, _get_module_attributes(module)) - def _track_module_attributes_forward_hook(module, input, output): + def _track_module_attributes_forward_hook(module, input, output) -> None: tracing_state = _C._get_tracing_state() if not tracing_state: return @@ -1359,7 +1359,7 @@ def _find_typename(v): return module_typenames -def _reset_trace_module_map(): +def _reset_trace_module_map() -> None: torch.jit._trace._trace_module_map = None _C._jit_pass_onnx_clear_scope_records() @@ -1388,7 +1388,7 @@ def _get_module_attributes(module): return attrs -def _trigger_symbolic_function_registration(): +def _trigger_symbolic_function_registration() -> None: """Trigger the registration of symbolic functions for all supported opsets.""" from torch.onnx._internal.torchscript_exporter import ( # noqa: F401 @@ -1599,7 +1599,7 @@ def _export( return torch_out -def _apply_friendly_debug_names(graph, params): +def _apply_friendly_debug_names(graph, params) -> None: for n in graph.nodes(): for v in n.inputs(): old_name = v.debugName() @@ -1611,8 +1611,8 @@ def _apply_friendly_debug_names(graph, params): params[new_name] = params.pop(old_name) -def _set_input_and_output_names(graph, input_names, output_names): - def set_names(node_list, name_list, descriptor): +def _set_input_and_output_names(graph, input_names, output_names) -> None: + def set_names(node_list, name_list, descriptor) -> None: if name_list is None: return if len(name_list) > len(node_list): @@ -1681,7 +1681,7 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int: def _should_aten_fallback( name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes -): +) -> bool: # For all builds, if domain=="aten" and operator_export_type==ONNX_ATEN, # an aten::ATen operator is created regardless of symbolics existence @@ -1822,7 +1822,7 @@ def _run_symbolic_function( raise -def _verify_custom_op_name(symbolic_name: str): +def _verify_custom_op_name(symbolic_name: str) -> None: if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name): raise errors.OnnxExporterError( f"Failed to register operator {symbolic_name}. 
" @@ -1842,7 +1842,7 @@ def register_custom_op_symbolic( symbolic_name: str, symbolic_fn: Callable, opset_version: int, -): +) -> None: """Registers a symbolic function for a custom operator. When the user registers symbolic for custom/contrib ops, @@ -1868,7 +1868,7 @@ def register_custom_op_symbolic( registration.custom_onnx_symbolic(symbolic_name, opset_version)(symbolic_fn) -def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): +def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int) -> None: """Unregisters ``symbolic_name``. See "Custom Operators" in the module documentation for an example usage. @@ -1886,7 +1886,7 @@ def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int): registration.registry.unregister(symbolic_name, opset_version) -def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): +def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names) -> None: """Ensures dynamic axes argument is follows the expected format.""" if len(dynamic_axes) == 0: return diff --git a/torch/onnx/_internal/torchscript_exporter/verification.py b/torch/onnx/_internal/torchscript_exporter/verification.py index 32885d1f63774..33fa18a3fd472 100644 --- a/torch/onnx/_internal/torchscript_exporter/verification.py +++ b/torch/onnx/_internal/torchscript_exporter/verification.py @@ -209,7 +209,7 @@ def _compare_onnx_pytorch_outputs_in_np( onnx_outs: _OutputsType, pt_outs: _OutputsType, options: VerificationOptions, -): +) -> None: assert len(onnx_outs) == len(pt_outs), ( f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" ) @@ -261,7 +261,7 @@ def _compare_onnx_pytorch_outputs( onnx_outs: _OutputsType, pt_outs: Any, options: VerificationOptions, -): +) -> None: """ Compare ONNX and PyTorch outputs. @@ -383,7 +383,7 @@ def _compare_onnx_pytorch_model( input_kwargs: _InputKwargsType | None, additional_test_inputs: Sequence[_InputArgsType] | None, options: VerificationOptions, -): +) -> None: """Compare outputs from ONNX model runs with outputs from PyTorch model runs. Args: @@ -401,7 +401,7 @@ def _compare_onnx_pytorch_model( """ onnx_session = _onnx_backend_session(onnx_model_f, options.backend) - def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): + def compare_onnx_pytorch_model_with_input(input_args, input_kwargs) -> None: pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) # TODO: remove this and treat mutating model separately. See #77679 pt_model_copy = _try_clone_model(pt_model) @@ -443,7 +443,7 @@ def verify( use_external_data: bool = False, additional_test_inputs: Sequence[_InputArgsType] | None = None, options: VerificationOptions | None = None, -): +) -> None: """Verify model export to ONNX against original PyTorch model. .. 
deprecated:: 2.7 diff --git a/torch/onnx/errors.py b/torch/onnx/errors.py index d5483dc67e3b1..3645e01d7a7a2 100644 --- a/torch/onnx/errors.py +++ b/torch/onnx/errors.py @@ -30,7 +30,7 @@ class UnsupportedOperatorError(OnnxExporterError): # NOTE: This is legacy and is only used by the torchscript exporter # Clean up when the torchscript exporter is removed - def __init__(self, name: str, version: int, supported_version: int | None): + def __init__(self, name: str, version: int, supported_version: int | None) -> None: if supported_version is not None: msg = ( f"Exporting the operator '{name}' to ONNX opset version {version} " @@ -57,7 +57,7 @@ class SymbolicValueError(OnnxExporterError): # NOTE: This is legacy and is only used by the torchscript exporter # Clean up when the torchscript exporter is removed - def __init__(self, msg: str, value: _C.Value): + def __init__(self, msg: str, value: _C.Value) -> None: message = ( f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the " f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] " diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 33546eb01b6a0..ca298219560e8 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -299,7 +299,7 @@ class _ConfigEntry: hide: bool = False alias: Optional[str] = None - def __init__(self, config: _Config): + def __init__(self, config: _Config) -> None: self.default = config.default self.value_type = ( config.value_type if config.value_type is not None else type(self.default) @@ -792,7 +792,7 @@ class SubConfigProxy: `config.triton.cudagraphs` maps to _config["triton.cudagraphs"] """ - def __init__(self, config: object, prefix: str): + def __init__(self, config: object, prefix: str) -> None: # `super().__setattr__` to bypass custom `__setattr__` super().__setattr__("_config", config) super().__setattr__("_prefix", prefix) diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index c3a3878f3c8c1..e9b4a91429a4d 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import math import operator -from typing import Union +from typing import NoReturn, Union import sympy @@ -139,7 +139,7 @@ def floordiv(a, b): return FloorDiv(a, b) @staticmethod - def truncdiv(a, b): + def truncdiv(a, b) -> NoReturn: raise NotImplementedError("TODO: truncdiv") @staticmethod @@ -257,11 +257,11 @@ def to_dtype(x, dtype): raise NotImplementedError(f"to_dtype {dtype} NYI") @staticmethod - def exp(x): + def exp(x) -> NoReturn: raise AssertionError("exp is not valid shape sympy expr") @staticmethod - def log(x): + def log(x) -> NoReturn: raise AssertionError("log is not valid shape sympy expr") @staticmethod @@ -448,7 +448,7 @@ def to_dtype(x, dtype): return _to_dtype(x, dtype) @staticmethod - def mod(x, y): + def mod(x, y) -> NoReturn: # TODO: https://github.com/pytorch/pytorch/pull/133654 raise NotImplementedError( "no C-style modulus operation available from frontend atm" @@ -484,7 +484,7 @@ def floordiv(a, b): return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor") @staticmethod - def truncdiv(a, b): + def truncdiv(a, b) -> NoReturn: raise NotImplementedError( "no C-style truncdiv operation available from frontend atm" ) @@ -575,7 +575,7 @@ def round_to_int(a, dtype): return torch.ops.aten.round.default(a) @staticmethod - def round_decimal(a, b): + def round_decimal(a, b) -> NoReturn: raise NotImplementedError( "round decimal doesn't support 
Tensor second argument atm" ) diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index 3dc17edeb7964..09dbb4b5a0863 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -188,7 +188,7 @@ def __init__( env: Optional[str] = None, num_threads: int = 1, language: Union[Language, str] = Language.PYTHON, - ): + ) -> None: if not isinstance(stmt, str): raise ValueError("Currently only a `str` stmt is supported.") diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index d9802c06e9444..9b10c4d192d4e 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -14,6 +14,7 @@ from torch.utils._pytree import tree_map from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode from torch.utils._python_dispatch import TorchDispatchMode +from typing import NoReturn __all__ = [ "checkpoint", @@ -107,7 +108,7 @@ class DefaultDeviceType: _default_device_type = "cuda" @staticmethod - def set_device_type(device: str = "cuda"): + def set_device_type(device: str = "cuda") -> None: """ Set the default device type for checkpointing. @@ -130,7 +131,7 @@ def get_device_type() -> str: def _infer_device_type(*args): device_types = [] - def add_device_types(arg): + def add_device_types(arg) -> None: nonlocal device_types if isinstance(arg, torch.Tensor) and arg.device.type != "cpu": device_types.append(arg.device.type) @@ -166,7 +167,7 @@ def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: # the conditionals short-circuit. fwd_device_ids = [] - def add_device_ids(arg): + def add_device_ids(arg) -> None: nonlocal fwd_device_ids if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: fwd_device_ids.append(arg.get_device()) @@ -601,7 +602,7 @@ def forward(input): return run_function(end + 1, len(functions) - 1, functions)(input) -def _internal_assert(cond): +def _internal_assert(cond) -> None: if not cond: raise AssertionError( "Something went unexpectedly wrong in activation checkpoint. " @@ -779,7 +780,7 @@ class _Handle: class _Holder: - def __init__(self): + def __init__(self) -> None: self.handles: Dict[int, Optional[_Handle]] = {} @@ -817,12 +818,12 @@ def get_args(saved_tensors): ctx.save_for_backward(*tensors) @staticmethod - def backward(ctx, *grad_outputs): + def backward(ctx, *grad_outputs) -> NoReturn: raise AssertionError("Did not expect to backward on this graph") class _CheckpointFrame: - def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): + def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn) -> None: self.recompute_fn = recompute_fn self.input_saver = None self.weak_holders: List[ReferenceType] = [] @@ -847,7 +848,7 @@ def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): self.forward_completed = False self.ignore_saved_mismatch = False - def check_recomputed_tensors_match(self, gid): + def check_recomputed_tensors_match(self, gid) -> None: if self.ignore_saved_mismatch: # TODO: we can probably make this check stricter by checking that # the metadata of the first tensors still match. 
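The checkpoint.py hunks above apply the convention used throughout this annotation patch: callables that only perform side effects gain `-> None`, while stubs whose sole behavior is to raise are annotated `NoReturn`. A minimal, self-contained sketch of the distinction — the class below is illustrative and not part of the diff, though the raised message mirrors the one in the patch:

    from typing import NoReturn

    class _Sketch:
        def __init__(self, recompute_fn) -> None:
            # Initializers return nothing, so `-> None` is the accurate annotation.
            self.recompute_fn = recompute_fn

        def backward(self, *grad_outputs) -> NoReturn:
            # This method can only raise; `NoReturn` tells type checkers that
            # any code following a call to it is unreachable.
            raise AssertionError("Did not expect to backward on this graph")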
@@ -999,7 +1000,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' class CaptureLogs: - def __init__(self): + def __init__(self) -> None: self.logs = None self.tbs = None @@ -1016,7 +1017,7 @@ def logging_mode(): capture_logs_fwd = CaptureLogs() capture_logs_recompute = CaptureLogs() - def unpack_error_cb(e: CheckpointError): + def unpack_error_cb(e: CheckpointError) -> NoReturn: def get_str_tb(label, capture_logs): out = "" total_len = len(capture_logs.logs) @@ -1071,7 +1072,7 @@ class _StopRecomputationError(Exception): class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): - def __init__(self, target_frame_ref: ReferenceType, gid: int): + def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None: def pack_hook(x): x = x.detach() if x.requires_grad else x target_frame = target_frame_ref() @@ -1132,7 +1133,7 @@ def _run_fn_with_dynamo_disabled(fn, *args, **kwargs): class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): - def __init__(self, frame): + def __init__(self, frame) -> None: def pack_hook(x): # See Rule 4 above holder = _Holder() @@ -1196,7 +1197,7 @@ def _is_compiling(func, args, kwargs): class _VersionWrapper: # Check that cached tensors are not mutated. - def __init__(self, val): + def __init__(self, val) -> None: self.val: Union[torch.Tensor, Any] = val self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None @@ -1251,7 +1252,7 @@ class SelectiveCheckpointContext: >>> context_fn=context_fn, >>> ) """ - def __init__(self, *, is_recompute): + def __init__(self, *, is_recompute) -> None: self.is_recompute = is_recompute @@ -1301,7 +1302,7 @@ def _policy_from_bool(b): class _CachingTorchDispatchMode(TorchDispatchMode): # Used together with _CachedTorchDispatchMode to implement SAC. - def __init__(self, policy_fn, storage): + def __init__(self, policy_fn, storage) -> None: self.policy_fn = policy_fn self.storage = storage @@ -1337,7 +1338,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): class _CachedTorchDispatchMode(TorchDispatchMode): # Used together with _CachedTorchDispatchMode to implement SAC. - def __init__(self, policy_fn, storage, allow_cache_entry_mutation): + def __init__(self, policy_fn, storage, allow_cache_entry_mutation) -> None: self.policy_fn = policy_fn self.storage = storage self.allow_cache_entry_mutation = allow_cache_entry_mutation @@ -1542,7 +1543,7 @@ def _checkpoint_without_reentrant_generator( had_device_in_fwd = True fwd_devices, fwd_device_states = get_device_states(*args) - def recompute_fn(*inputs): + def recompute_fn(*inputs) -> None: kwargs, *args = inputs # This will be called later during recomputation. This wrapping enables # the necessary global state to be captured. diff --git a/torch/utils/data/_utils/fetch.py b/torch/utils/data/_utils/fetch.py index 3fa6c49404f67..9bcd0ec5b3073 100644 --- a/torch/utils/data/_utils/fetch.py +++ b/torch/utils/data/_utils/fetch.py @@ -4,20 +4,22 @@ This logic is shared in both single- and multi-processing data loading. 
""" +from typing import NoReturn + class _BaseDatasetFetcher: - def __init__(self, dataset, auto_collation, collate_fn, drop_last): + def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None: self.dataset = dataset self.auto_collation = auto_collation self.collate_fn = collate_fn self.drop_last = drop_last - def fetch(self, possibly_batched_index): + def fetch(self, possibly_batched_index) -> NoReturn: raise NotImplementedError class _IterableDatasetFetcher(_BaseDatasetFetcher): - def __init__(self, dataset, auto_collation, collate_fn, drop_last): + def __init__(self, dataset, auto_collation, collate_fn, drop_last) -> None: super().__init__(dataset, auto_collation, collate_fn, drop_last) self.dataset_iter = iter(dataset) self.ended = False diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 19400eb4a21a7..467e8c655d2bc 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -17,7 +17,7 @@ import threading import warnings from collections.abc import Callable -from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import Self import torch @@ -108,7 +108,7 @@ def _get_distributed_settings(): return 1, 0 -def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): +def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id) -> None: global_worker_id = worker_id info = torch.utils.data.get_worker_info() if info is None: @@ -436,7 +436,7 @@ def multiprocessing_context(self): return self.__multiprocessing_context @multiprocessing_context.setter - def multiprocessing_context(self, multiprocessing_context): + def multiprocessing_context(self, multiprocessing_context) -> None: if multiprocessing_context is not None: if self.num_workers > 0: if isinstance(multiprocessing_context, str): @@ -468,7 +468,7 @@ def multiprocessing_context(self, multiprocessing_context): self.__multiprocessing_context = multiprocessing_context - def __setattr__(self, attr, val): + def __setattr__(self, attr, val) -> None: if self.__initialized and attr in ( "batch_size", "batch_sampler", @@ -546,7 +546,7 @@ def __len__(self) -> int: else: return len(self._index_sampler) - def check_worker_number_rationality(self): + def check_worker_number_rationality(self) -> None: # This function check whether the dataloader's worker number is rational based on # current system's resource. 
Current rule is that if the number of workers this # Dataloader will create is bigger than the number of logical cpus that is allowed to @@ -714,7 +714,7 @@ def __init__(self, loader: DataLoader) -> None: def __iter__(self) -> Self: return self - def _reset(self, loader, first_iter=False): + def _reset(self, loader, first_iter=False) -> None: self._sampler_iter = iter(self._index_sampler) self._num_yielded = 0 self._IterableDataset_len_called = loader._IterableDataset_len_called @@ -729,7 +729,7 @@ def _reset(self, loader, first_iter=False): def _next_index(self): return next(self._sampler_iter) # may raise StopIteration - def _next_data(self): + def _next_data(self) -> NoReturn: raise NotImplementedError def __next__(self) -> Any: @@ -770,7 +770,7 @@ def __getstate__(self): class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): - def __init__(self, loader): + def __init__(self, loader) -> None: super().__init__(loader) if self._timeout != 0: raise AssertionError("_SingleProcessDataLoaderIter requires timeout == 0") @@ -1113,7 +1113,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # processing indices already in `index_queue` if we are already shutting # down. - def __init__(self, loader): + def __init__(self, loader) -> None: super().__init__(loader) self._prefetch_factor = loader.prefetch_factor @@ -1235,7 +1235,7 @@ def __init__(self, loader): self._worker_pids_set = True self._reset(loader, first_iter=True) - def _reset(self, loader, first_iter=False): + def _reset(self, loader, first_iter=False) -> None: super()._reset(loader, first_iter) self._send_idx = 0 # idx of the next task to be sent to workers self._rcvd_idx = 0 # idx of the next task to be returned in __next__ @@ -1529,7 +1529,7 @@ def _next_data(self): self._rcvd_idx += 1 return self._process_data(data, worker_id) - def _try_put_index(self): + def _try_put_index(self) -> None: max_tasks = self._prefetch_factor * self._num_workers if self._tasks_outstanding >= max_tasks: raise AssertionError( @@ -1568,7 +1568,7 @@ def _process_data(self, data, worker_idx): data.reraise() return data - def _mark_worker_as_unavailable(self, worker_id, shutdown=False): + def _mark_worker_as_unavailable(self, worker_id, shutdown=False) -> None: # Mark a worker as having finished its work e.g., due to # exhausting an `IterableDataset`. This should be used only when this # `_MultiProcessingDataLoaderIter` is going to continue running. @@ -1604,7 +1604,7 @@ def _mark_worker_as_unavailable(self, worker_id, shutdown=False): "_workers_done_event state does not match shutdown flag" ) - def _shutdown_workers(self): + def _shutdown_workers(self) -> None: # Called when shutting down this `_MultiProcessingDataLoaderIter`. # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on # the logic of this function. 
@@ -1678,12 +1678,12 @@ def _shutdown_workers(self): # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter` @staticmethod - def _clean_up_worker(w): + def _clean_up_worker(w) -> None: try: w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL) finally: if w.is_alive(): w.terminate() - def __del__(self): + def __del__(self) -> None: self._shutdown_workers() diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index 8908721bccd77..e8b03ff3b2afa 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, Optional +from typing import Any, NoReturn, Optional from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.dataframe.structures import DataChunkDF @@ -33,7 +33,7 @@ ] -def disable_capture(): +def disable_capture() -> None: CaptureControl.disabled = True @@ -42,7 +42,7 @@ class CaptureControl: class DataFrameTracedOps(DFIterDataPipe): - def __init__(self, source_datapipe, output_var): + def __init__(self, source_datapipe, output_var) -> None: self.source_datapipe = source_datapipe self.output_var = output_var @@ -72,10 +72,10 @@ def __iter__(self): class Capture: # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures - def __init__(self, schema_df=None): + def __init__(self, schema_df=None) -> None: self.ctx = {"operations": [], "variables": [], "schema_df": schema_df} - def __str__(self): + def __str__(self) -> str: return self._ops_str() def _ops_str(self): @@ -113,7 +113,7 @@ def __getattr__(self, attrname): def __getitem__(self, key): return CaptureGetItem(self, key, ctx=self.ctx) - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: # pyrefly: ignore [missing-attribute] self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx)) @@ -147,7 +147,7 @@ def _is_context_empty(self): # pyrefly: ignore [bad-argument-type] return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0 - def apply_ops_2(self, dataframe): + def apply_ops_2(self, dataframe) -> None: # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) # pyrefly: ignore [unsupported-operation] self.ctx["variables"][0].calculated_value = dataframe @@ -190,7 +190,7 @@ def __call__(self, *args, **kwargs): class CaptureF(Capture): - def __init__(self, ctx=None, **kwargs): + def __init__(self, ctx=None, **kwargs) -> None: if ctx is None: self.ctx = {"operations": [], "variables": []} else: @@ -199,7 +199,7 @@ def __init__(self, ctx=None, **kwargs): class CaptureA(CaptureF): - def __str__(self): + def __str__(self) -> str: return f"{self.kwargs['name']}" def execute(self): @@ -208,7 +208,7 @@ def execute(self): class CaptureLikeMock: - def __init__(self, name): + def __init__(self, name) -> None: import unittest.mock as mock # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead. 
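Many of the `Capture*` hunks here annotate `__str__` as `-> str` in addition to marking `__init__` as `-> None`. The return annotation is more than documentation: it lets a checker such as mypy (or pyrefly, which this file already references in its ignore comments) flag a branch that falls through and implicitly returns `None`, a bug that would otherwise only surface at runtime as `TypeError: __str__ returned non-string`. A small hypothetical sketch, not taken from the diff:

    class CaptureDemo:
        def __init__(self, name: str) -> None:
            self.name = name

        def __str__(self) -> str:
            # With `-> str` declared, a code path that forgets to return is
            # reported statically instead of failing later at runtime.
            return f"CaptureDemo({self.name!r})"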
@@ -227,7 +227,7 @@ def __exit__(self, *exc_info): class CaptureCall(Capture): - def __init__(self, callable, ctx=None, **kwargs): + def __init__(self, callable, ctx=None, **kwargs) -> None: if ctx is None: self.ctx = {"operations": [], "variables": []} else: @@ -235,7 +235,7 @@ def __init__(self, callable, ctx=None, **kwargs): self.kwargs = kwargs self.callable = callable - def __str__(self): + def __str__(self) -> str: return "{callable}({args},{kwargs})".format( callable=self.callable, **self.kwargs ) @@ -253,12 +253,12 @@ def execute(self): class CaptureVariableAssign(CaptureF): - def __str__(self): + def __str__(self) -> str: variable = self.kwargs["variable"] value = self.kwargs["value"] return f"{variable} = {value}" - def execute(self): + def execute(self) -> None: self.kwargs["variable"].calculated_value = self.kwargs["value"].execute() @@ -266,7 +266,7 @@ class CaptureVariable(Capture): # TODO(VitalyFedyunin): This should be atomic and thread safe names_idx = 0 - def __init__(self, value, ctx): + def __init__(self, value, ctx) -> None: if CaptureControl.disabled: raise RuntimeError("Attempting to create capture variable with capture off") self.ctx = ctx @@ -275,7 +275,7 @@ def __init__(self, value, ctx): CaptureVariable.names_idx += 1 self.ctx["variables"].append(self) - def __str__(self): + def __str__(self) -> str: return self.name def execute(self): @@ -292,12 +292,12 @@ def apply_ops(self, dataframe): class CaptureGetItem(Capture): - def __init__(self, left, key, ctx): + def __init__(self, left, key, ctx) -> None: self.ctx = ctx self.left = left self.key = key - def __str__(self): + def __str__(self) -> str: return f"{self.left}[{get_val(self.key)}]" def execute(self): @@ -306,28 +306,28 @@ def execute(self): class CaptureSetItem(Capture): - def __init__(self, left, key, value, ctx): + def __init__(self, left, key, value, ctx) -> None: self.ctx = ctx self.left = left self.key = key self.value = value - def __str__(self): + def __str__(self) -> str: return f"{self.left}[{get_val(self.key)}] = {self.value}" - def execute(self): + def execute(self) -> None: left = self.left.execute() value = self.value.execute() left[self.key] = value class CaptureAdd(Capture): - def __init__(self, left, right, ctx): + def __init__(self, left, right, ctx) -> None: self.ctx = ctx self.left = left self.right = right - def __str__(self): + def __str__(self) -> str: return f"{self.left} + {self.right}" def execute(self): @@ -335,12 +335,12 @@ def execute(self): class CaptureMul(Capture): - def __init__(self, left, right, ctx): + def __init__(self, left, right, ctx) -> None: self.ctx = ctx self.left = left self.right = right - def __str__(self): + def __str__(self) -> str: return f"{self.left} * {self.right}" def execute(self): @@ -348,12 +348,12 @@ def execute(self): class CaptureSub(Capture): - def __init__(self, left, right, ctx): + def __init__(self, left, right, ctx) -> None: self.ctx = ctx self.left = left self.right = right - def __str__(self): + def __str__(self) -> str: return f"{self.left} - {self.right}" def execute(self): @@ -361,12 +361,12 @@ def execute(self): class CaptureGetAttr(Capture): - def __init__(self, src, name, ctx): + def __init__(self, src, name, ctx) -> None: self.ctx = ctx self.src = src self.name = name - def __str__(self): + def __str__(self) -> str: return f"{self.src}.{self.name}" def execute(self): @@ -384,7 +384,7 @@ def get_val(capture): class CaptureInitial(CaptureVariable): - def __init__(self, schema_df=None): + def __init__(self, schema_df=None) -> None: # 
pyrefly: ignore [bad-assignment] new_ctx: dict[str, list[Any]] = { "operations": [], @@ -441,7 +441,7 @@ def shuffle(self, *args, **kwargs): def filter(self, *args, **kwargs): return self._dataframes_filter(*args, **kwargs) - def collate(self, *args, **kwargs): + def collate(self, *args, **kwargs) -> NoReturn: raise RuntimeError("Can't collate unbatched DataFrames stream") def __getattr__(self, attrname): # ? @@ -458,13 +458,13 @@ class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: i # TODO(VitalyFedyunin): Must implement all special functions of datapipes - def set_shuffle_settings(self, *args, **kwargs): + def set_shuffle_settings(self, *args, **kwargs) -> None: pass - def is_shardable(self): + def is_shardable(self) -> bool: return False - def __init__(self, source_datapipe, schema_df=None): + def __init__(self, source_datapipe, schema_df=None) -> None: self.source_datapipe = source_datapipe if schema_df is None: schema_df = next(iter(self.source_datapipe)) diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index a289bdb5e0949..16ae0965f3cff 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from collections import defaultdict from collections.abc import Callable, Iterator, Sized -from typing import Any, Optional, TypeVar +from typing import Any, NoReturn, Optional, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe @@ -18,7 +18,7 @@ _T_co = TypeVar("_T_co", covariant=True) -def __getattr__(name: str): +def __getattr__(name: str) -> NoReturn: raise AttributeError(f"module {__name__} has no attribute {name}") @@ -110,7 +110,7 @@ class UnBatcherIterDataPipe(IterDataPipe): [0, 1, 2, 3, 4, 5, 6] """ - def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1): + def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1) -> None: self.datapipe = datapipe self.unbatch_level = unbatch_level @@ -202,7 +202,7 @@ def __init__( group_size: Optional[int] = None, guaranteed_group_size: Optional[int] = None, drop_remaining: bool = False, - ): + ) -> None: _check_unpickable_fn(group_key_fn) # pyrefly: ignore [invalid-type-var] self.datapipe = datapipe @@ -322,5 +322,5 @@ def __setstate__(self, state): self.curr_buffer_size = 0 self.buffer_elements = defaultdict(list) - def __del__(self): + def __del__(self) -> None: self.buffer_elements.clear() diff --git a/torch/utils/data/datapipes/iter/sharding.py b/torch/utils/data/datapipes/iter/sharding.py index 0e381c87a4a58..494ea0106a041 100644 --- a/torch/utils/data/datapipes/iter/sharding.py +++ b/torch/utils/data/datapipes/iter/sharding.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs from collections.abc import Sized from enum import IntEnum +from typing import NoReturn from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -24,7 +25,7 @@ def apply_sharding( num_of_instances: int, instance_id: int, sharding_group: SHARDING_PRIORITIES, - ): + ) -> NoReturn: raise NotImplementedError @@ -40,7 +41,9 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe): source_datapipe: Iterable DataPipe that will be sharded """ - def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter=None): + def __init__( + self, source_datapipe: IterDataPipe, sharding_group_filter=None + ) -> None: 
self.source_datapipe = source_datapipe self.sharding_group_filter = sharding_group_filter self.groups: dict[int, tuple[int, int]] = {} @@ -68,7 +71,7 @@ def apply_sharding( self.groups[sharding_group] = (num_of_instances, instance_id) self._update_num_of_instances() - def _update_num_of_instances(self): + def _update_num_of_instances(self) -> None: sorted_sharding_groups = [ self.groups[key] for key in sorted(self.groups.keys()) @@ -89,7 +92,7 @@ def __iter__(self): if i % self.num_of_instances == self.instance_id: yield item - def __len__(self): + def __len__(self) -> int: if isinstance(self.source_datapipe, Sized): return len(self.source_datapipe) // self.num_of_instances + ( 1 diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 003ca568fcaf6..7f27c2f37fc93 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable, Iterable from io import IOBase -from typing import Any, Optional, Union +from typing import Any, NoReturn, Optional, Union from torch.utils._import_utils import dill_available @@ -25,7 +25,9 @@ DILL_AVAILABLE = dill_available() -def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]): +def validate_input_col( + fn: Callable, input_col: Optional[Union[int, tuple, list]] +) -> None: """ Check that function used in a callable datapipe works with the input column. @@ -131,7 +133,7 @@ def _is_local_fn(fn): return False -def _check_unpickable_fn(fn: Callable): +def _check_unpickable_fn(fn: Callable) -> None: """ Check function is pickable or not. @@ -186,7 +188,7 @@ def get_file_pathnames_from_root( non_deterministic: bool = False, ) -> Iterable[str]: # print out an error message and raise the error out - def onerror(err: OSError): + def onerror(err: OSError) -> NoReturn: warnings.warn(err.filename + " : " + err.strerror, stacklevel=2) raise err @@ -235,7 +237,7 @@ def get_file_binaries_from_pathnames( yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding)) -def validate_pathname_binary_tuple(data: tuple[str, IOBase]): +def validate_pathname_binary_tuple(data: tuple[str, IOBase]) -> None: if not isinstance(data, tuple): raise TypeError( f"pathname binary data should be tuple type, but it is type {type(data)}" @@ -326,7 +328,7 @@ class StreamWrapper: session_streams: dict[Any, int] = {} debug_unclosed_streams: bool = False - def __init__(self, file_obj, parent_stream=None, name=None): + def __init__(self, file_obj, parent_stream=None, name=None) -> None: self.file_obj = file_obj self.child_counter = 0 self.parent_stream = parent_stream @@ -344,7 +346,7 @@ def __init__(self, file_obj, parent_stream=None, name=None): StreamWrapper.session_streams[self] = 1 @classmethod - def close_streams(cls, v, depth=0): + def close_streams(cls, v, depth=0) -> None: """Traverse structure and attempts to close all found StreamWrappers on best effort basis.""" if depth > 10: return @@ -363,7 +365,7 @@ def __getattr__(self, name): file_obj = self.__dict__["file_obj"] return getattr(file_obj, name) - def close(self, *args, **kwargs): + def close(self, *args, **kwargs) -> None: if self.closed: return if StreamWrapper.debug_unclosed_streams: @@ -381,7 +383,7 @@ def close(self, *args, **kwargs): pass self.closed = True - def autoclose(self): + def autoclose(self) -> None: """Automatically close stream when all child streams are closed or if there are none.""" self.close_on_last_child = True if 
self.child_counter == 0: @@ -392,7 +394,7 @@ def __dir__(self): attrs += dir(self.file_obj) return list(set(attrs)) - def __del__(self): + def __del__(self) -> None: if not self.closed: self.close() @@ -402,7 +404,7 @@ def __iter__(self): def __next__(self): return next(self.file_obj) - def __repr__(self): + def __repr__(self) -> str: if self.name is None: return f"StreamWrapper<{self.file_obj!r}>" else: From edd611f3b0655aa366276a47e07ab54c1d8b9d60 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Fri, 7 Nov 2025 14:05:16 +0000 Subject: [PATCH 200/651] [CI] Upgrade Ubuntu 24.04 for XPU CI tests (#162475) As the title Pull Request resolved: https://github.com/pytorch/pytorch/pull/162475 Approved by: https://github.com/EikanWang, https://github.com/atalman --- .ci/docker/build.sh | 4 +-- .ci/docker/common/install_xpu.sh | 25 +++++++++---------- .github/workflows/docker-builds.yml | 4 +-- .../inductor-perf-test-nightly-xpu.yml | 8 +++--- .github/workflows/pull.yml | 8 +++--- .github/workflows/xpu.yml | 20 +++++++-------- .../ATen/native/mkldnn/xpu/detail/QConv.cpp | 2 +- caffe2/CMakeLists.txt | 1 + 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 7d55884fbe431..5609b9e30dc2b 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -207,9 +207,9 @@ case "$tag" in NINJA_VERSION=1.9.0 TRITON=yes ;; - pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks) + pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + GCC_VERSION=13 VISION=yes XPU_VERSION=2025.2 NINJA_VERSION=1.9.0 diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index 0b150872f93ce..22b7af890c1f6 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -9,7 +9,7 @@ set -xe function install_ubuntu() { . /etc/os-release - if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then + if [[ ! 
" jammy noble " =~ " ${VERSION_CODENAME} " ]]; then echo "Ubuntu version ${VERSION_CODENAME} not supported" exit fi @@ -35,25 +35,24 @@ function install_ubuntu() { # The xpu-smi packages apt-get install -y flex bison xpu-smi - if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then - # Compute and Media Runtimes + # Compute and Media Runtimes + if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then apt-get install -y \ - intel-opencl-icd intel-level-zero-gpu level-zero \ - intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \ - libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ + intel-opencl-icd libze-intel-gpu1 libze1 \ + intel-media-va-driver-non-free libmfx-gen1 libvpl2 \ + libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ - mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo - # Development Packages - apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev - else # rolling driver + mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc + else # jammy apt-get install -y \ intel-opencl-icd libze-intel-gpu1 libze1 \ intel-media-va-driver-non-free libmfx-gen1 libvpl2 \ libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc - apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev fi + # Development Packages + apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev # Install Intel Support Packages apt-get install -y ${XPU_PACKAGES} @@ -66,7 +65,7 @@ function install_ubuntu() { function install_rhel() { . /etc/os-release if [[ "${ID}" == "rhel" ]]; then - if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then + if [[ ! 
" 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then echo "RHEL version ${VERSION_ID} not supported" exit fi @@ -147,7 +146,7 @@ function install_sles() { XPU_DRIVER_VERSION="" if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then # Use GPU driver LTS releases - XPU_DRIVER_VERSION="/lts/2350" + XPU_DRIVER_VERSION="/lts/2523" fi # Default use Intel® oneAPI Deep Learning Essentials 2025.1 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 0aa176cd1c676..6d3a5c321a1eb 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -68,8 +68,8 @@ jobs: pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-xpu-n-1-py3, - pytorch-linux-jammy-xpu-n-py3, - pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks, + pytorch-linux-noble-xpu-n-py3, + pytorch-linux-noble-xpu-n-py3-inductor-benchmarks, pytorch-linux-jammy-py3-clang18-asan, pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-linter, diff --git a/.github/workflows/inductor-perf-test-nightly-xpu.yml b/.github/workflows/inductor-perf-test-nightly-xpu.yml index c2db8c310e368..28b10996bf38a 100644 --- a/.github/workflows/inductor-perf-test-nightly-xpu.yml +++ b/.github/workflows/inductor-perf-test-nightly-xpu.yml @@ -83,8 +83,8 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-xpu-n-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks + build-environment: linux-noble-xpu-n-py3.10 + docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks runner: linux.c7i.12xlarge test-matrix: | { include: [ @@ -117,7 +117,7 @@ jobs: uses: ./.github/workflows/_xpu-test.yml needs: xpu-n-py3_10-inductor-benchmark-build with: - build-environment: linux-jammy-xpu-n-py3.10 + build-environment: linux-noble-xpu-n-py3.10 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }} test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }} @@ -137,7 +137,7 @@ jobs: uses: ./.github/workflows/_xpu-test.yml needs: xpu-n-py3_10-inductor-benchmark-build with: - build-environment: linux-jammy-xpu-n-py3.10 + build-environment: linux-noble-xpu-n-py3.10 dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }} test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }} diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index e3af55e736503..e5fd10c70db61 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -342,16 +342,16 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }} secrets: inherit - linux-jammy-xpu-n-py3_10-build: - name: linux-jammy-xpu-n-py3.10 + linux-noble-xpu-n-py3_10-build: + name: linux-noble-xpu-n-py3.10 uses: ./.github/workflows/_linux-build.yml needs: 
get-label-type with: # This should sync with the build in xpu.yml but xpu uses a larger runner # sync-tag: linux-xpu-n-build runner_prefix: ${{ needs.get-label-type.outputs.label-type }} - build-environment: linux-jammy-xpu-n-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 + build-environment: linux-noble-xpu-n-py3.10 + docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" }, diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 36f603f70fde7..d9a1ba13d2b59 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -47,15 +47,15 @@ jobs: ]} secrets: inherit - linux-jammy-xpu-n-py3_10-build: - name: linux-jammy-xpu-n-py3.10 + linux-noble-xpu-n-py3_10-build: + name: linux-noble-xpu-n-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: sync-tag: linux-xpu-n-build runner_prefix: ${{ needs.get-label-type.outputs.label-type }} - build-environment: linux-jammy-xpu-n-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3 + build-environment: linux-noble-xpu-n-py3.10 + docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3 runner: linux.c7i.12xlarge test-matrix: | { include: [ @@ -74,17 +74,17 @@ jobs: ]} secrets: inherit - linux-jammy-xpu-n-py3_10-test: - name: linux-jammy-xpu-n-py3.10 + linux-noble-xpu-n-py3_10-test: + name: linux-noble-xpu-n-py3.10 uses: ./.github/workflows/_xpu-test.yml - needs: linux-jammy-xpu-n-py3_10-build + needs: linux-noble-xpu-n-py3_10-build permissions: id-token: write contents: read with: - build-environment: linux-jammy-xpu-n-py3.10 - docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }} + build-environment: linux-noble-xpu-n-py3.10 + docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }} secrets: inherit windows-xpu-n-1-build: diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp index 282f42f37a364..4d6cb1b81fac3 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/QConv.cpp @@ -133,7 +133,7 @@ at::Tensor quantized_convolution( // supported in conv. mask_weight = weight_zero_points.numel() > 1 ? 
1 : 0; if (groups > 1 && weight_zero_points.numel() > 1) - mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel) + mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel) dnnl::primitive_attr pattr; bool src_need_zp = (act_zero_point != 0); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index e1cc43350b2b6..d5c585c1e1f0b 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1941,6 +1941,7 @@ if(BUILD_TEST) foreach(test_src ${Caffe2_XPU_TEST_SRCS}) get_filename_component(test_name ${test_src} NAME_WE) add_executable(${test_name} "${test_src}") + torch_compile_options(${test_name}) target_link_libraries(${test_name} torch_library gtest_main) target_include_directories(${test_name} PRIVATE $) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) From 5bfce8f345356f2ade92c790c8e5d2ff9bbac555 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Thu, 6 Nov 2025 09:24:02 -0800 Subject: [PATCH 201/651] Unit test for torch.compile bmm dtype (#167140) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167140 Approved by: https://github.com/atalman, https://github.com/mlazos --- test/inductor/test_max_autotune.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 2f753b7ae0e69..e1d184e952596 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1913,6 +1913,29 @@ def mm_transpose_relu(a, b): # Check that contiguous transform was used FileCheck().check("contiguous_mm").run(code[0]) + @unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI") + @unittest.skipIf(TEST_WITH_ROCM, "out_dtype override only available on NVIDIA") + def test_bmm_out_dtype(self): + def f(a, b): + return torch.bmm(a, b, out_dtype=torch.float32) + + a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16) + with config.patch( + max_autotune=True, + max_autotune_gemm_backends="TRITON", + ): + compiled_f = torch.compile(f) + with self.assertRaisesRegex( + torch._inductor.exc.InductorError, + r"LoweringException: NoValidChoicesError: No choices to select", + ): + out, code = run_and_get_code(compiled_f, a, b) + + compiled_f = torch.compile(f) + out, code = run_and_get_code(compiled_f, a, b) + FileCheck().check("extern_kernels.bmm_dtype").run(code[0]) + def test_triton_template_generated_code_cache_key(self): generate_and_load_args = len( inspect.signature( From 192034c41b779c05d7e82d675226e45badcd216c Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 7 Nov 2025 23:10:20 +0800 Subject: [PATCH 202/651] [easy][dynamo][pytree] simplify pytree polyfill module by move out the guard-if (#167221) Move the guard-if in `polyfills.pytree` to `polyfills.loader` and dedent the code in the if-branch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167221 Approved by: https://github.com/Lucaskabela --- torch/_dynamo/polyfills/loader.py | 6 +- torch/_dynamo/polyfills/pytree.py | 1159 +++++++++++++++-------------- 2 files changed, 595 insertions(+), 570 deletions(-) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index d348a422ff576..31479e9d86ce6 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -4,6 +4,8 @@ import importlib from typing import TYPE_CHECKING +import torch.utils._pytree as python_pytree + from .. 
import polyfills, trace_rules @@ -19,12 +21,14 @@ "itertools", "operator", "os", - "pytree", "struct", "sys", "fx", "tensor", ) +if python_pytree._cxx_pytree_dynamo_traceable: + POLYFILLED_MODULE_NAMES += ("pytree",) + POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) for submodule in POLYFILLED_MODULE_NAMES diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index b4de3200e2960..c01026ef30211 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -7,9 +7,12 @@ from collections import deque from dataclasses import dataclass, field from typing import Any, TYPE_CHECKING, TypeVar -from typing_extensions import TypeIs -import torch.utils._pytree as python_pytree +import optree +import optree._C +import optree.utils + +import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES from ..decorators import substitute_in_graph @@ -18,7 +21,9 @@ if TYPE_CHECKING: import builtins from collections.abc import Callable, Iterable, Mapping - from typing_extensions import Self + from typing_extensions import Self, TypeIs + + from torch.utils._cxx_pytree import PyTree __all__: list[str] = [] @@ -29,417 +34,472 @@ _VT = TypeVar("_VT") -if python_pytree._cxx_pytree_dynamo_traceable: - import optree - import optree._C - import optree.utils - - import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 +@substitute_in_graph( + optree._C.is_dict_insertion_ordered, + can_constant_fold_through=True, +) +def _(*args: Any, **kwargs: Any) -> bool: + # In namespace 'torch', the dictionary is always traversed in insertion order. + # This function returns True. + raise ValueError( + "Should not be called directly " + "because the original function will be called in the constant fold path." + ) - if TYPE_CHECKING: - from torch.utils._cxx_pytree import PyTree - @substitute_in_graph( - optree._C.is_dict_insertion_ordered, - can_constant_fold_through=True, +__name = "" +for __name in ( + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", + "namedtuple_fields", + "structseq_fields", +): + __func = getattr(optree, __name) + globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( + __func.__python_implementation__ ) - def _(*args: Any, **kwargs: Any) -> bool: - # In namespace 'torch', the dictionary is always traversed in insertion order. - # This function returns True. - raise ValueError( - "Should not be called directly " - "because the original function will be called in the constant fold path." 
- ) + __all__ += [__name] # noqa: PLE0604 + del __func +del __name + + +@substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_is_leaf( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> bool: + if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): + return True + if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined] + return True + return False + + +@substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] +def tree_iter( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> Iterable[Any]: + stack = [tree] + while stack: + node = stack.pop() + if tree_is_leaf( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ): + yield node + continue - __name = "" - for __name in ( - "is_namedtuple", - "is_namedtuple_class", - "is_namedtuple_instance", - "is_structseq", - "is_structseq_class", - "is_structseq_instance", - "namedtuple_fields", - "structseq_fields", - ): - __func = getattr(optree, __name) - globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( - __func.__python_implementation__ + children, *_ = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, ) - __all__ += [__name] # noqa: PLE0604 - del __func - del __name - - @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] - def tree_is_leaf( - tree: PyTree, - /, - is_leaf: Callable[[PyTree], bool] | None = None, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> bool: - if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): - return True - if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined] - return True - return False - - @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] - def tree_iter( - tree: PyTree, - /, - is_leaf: Callable[[PyTree], bool] | None = None, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> Iterable[Any]: - stack = [tree] - while stack: - node = stack.pop() - if tree_is_leaf( - node, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ): - yield node - continue + stack.extend(reversed(children)) - children, *_ = optree.tree_flatten_one_level( - node, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - stack.extend(reversed(children)) - - __all__ += ["tree_iter"] - - @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] - def tree_leaves( - tree: PyTree, - /, - is_leaf: Callable[[PyTree], bool] | None = None, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> list[Any]: - return list( - tree_iter( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) + +__all__ += ["tree_iter"] + + +@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_leaves( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> list[Any]: + return list( + tree_iter( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, ) + ) - 
__all__ += ["tree_leaves"] - - class _Asterisk(str): - __slots__ = () - - def __new__(cls) -> Self: - return super().__new__(cls, "*") - - def __repr__(self) -> str: - return "*" # no quotes - - _asterisk = _Asterisk() - del _Asterisk - - @dataclass(frozen=True) - class PyTreeSpec: - """Analog for :class:`optree.PyTreeSpec` in Python.""" - - _children: tuple[PyTreeSpec, ...] - _type: builtins.type | None - _metadata: Any - _entries: tuple[Any, ...] - _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None - none_is_leaf: bool - namespace: str - - num_nodes: int = field(init=False) - num_leaves: int = field(init=False) - num_children: int = field(init=False) - - def __post_init__(self) -> None: - if self._type is None: - assert len(self._children) == 0 - assert self._metadata is None - assert self._entries == () - assert self._unflatten_func is None - num_nodes = 1 - num_leaves = 1 - num_children = 0 - else: - assert callable(self._unflatten_func) - num_nodes = sum((spec.num_nodes for spec in self._children), start=1) - num_leaves = sum(spec.num_leaves for spec in self._children) - num_children = len(self._children) - - object.__setattr__(self, "num_nodes", num_nodes) - object.__setattr__(self, "num_leaves", num_leaves) - object.__setattr__(self, "num_children", num_children) - - def __repr__(self) -> str: - def helper(treespec: PyTreeSpec) -> str: - if treespec.is_leaf(): - assert treespec.type is None - return _asterisk - - assert treespec.type is not None - assert callable(treespec._unflatten_func) - children_representations = [ - helper(subspec) for subspec in treespec._children - ] - if ( - treespec.type in BUILTIN_TYPES - or (treespec.type is type(None) and not self.none_is_leaf) - or optree.is_namedtuple_class(treespec.type) - or optree.is_structseq_class(treespec.type) - ): - # pyrefly: ignore [bad-return] - return treespec._unflatten_func( - treespec._metadata, - children_representations, - ) - return ( - f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], " - f"[{', '.join(children_representations)}])" - ) - inner = [ - str(helper(self)), - *(["NoneIsLeaf"] if self.none_is_leaf else []), - f"namespace={self.namespace!r}", - ] - return f"PyTreeSpec({', '.join(inner)})" +__all__ += ["tree_leaves"] - def __len__(self) -> int: - return self.num_leaves - @property - def type(self) -> builtins.type | None: - return self._type +class _Asterisk(str): + __slots__ = () - def is_leaf(self) -> bool: - return self.num_nodes == 1 and self.num_leaves == 1 + def __new__(cls) -> Self: + return super().__new__(cls, "*") - def children(self) -> list[PyTreeSpec]: - return list(self._children) + def __repr__(self) -> str: + return "*" # no quotes - def child(self, index: int) -> PyTreeSpec: - return self._children[index] - def entries(self) -> list[Any]: - return list(self._entries) +_asterisk = _Asterisk() +del _Asterisk - def entry(self, index: int) -> Any: - return self._entries[index] - def flatten_up_to(self, tree: PyTree) -> list[PyTree]: - def helper( - treespec: PyTreeSpec, - node: PyTree, - subtrees: list[PyTree], - ) -> None: - if treespec.is_leaf(): - subtrees.append(node) - return +@dataclass(frozen=True) +class PyTreeSpec: + """Analog for :class:`optree.PyTreeSpec` in Python.""" - node_type = type(node) - if treespec.type not in BUILTIN_TYPES: - # Always require custom node types to match exactly - if node_type != treespec.type: - raise ValueError( - f"Type mismatch; " - f"expected {treespec.type!r}, but got {node_type!r}.", - ) + _children: 
tuple[PyTreeSpec, ...] + _type: builtins.type | None + _metadata: Any + _entries: tuple[Any, ...] + _unflatten_func: Callable[[Any | None, Iterable[PyTree]], PyTree] | None + none_is_leaf: bool + namespace: str + num_nodes: int = field(init=False) + num_leaves: int = field(init=False) + num_children: int = field(init=False) + + def __post_init__(self) -> None: + if self._type is None: + assert len(self._children) == 0 + assert self._metadata is None + assert self._entries == () + assert self._unflatten_func is None + num_nodes = 1 + num_leaves = 1 + num_children = 0 + else: + assert callable(self._unflatten_func) + num_nodes = sum((spec.num_nodes for spec in self._children), start=1) + num_leaves = sum(spec.num_leaves for spec in self._children) + num_children = len(self._children) + + object.__setattr__(self, "num_nodes", num_nodes) + object.__setattr__(self, "num_leaves", num_leaves) + object.__setattr__(self, "num_children", num_children) + + def __repr__(self) -> str: + def helper(treespec: PyTreeSpec) -> str: + if treespec.is_leaf(): + assert treespec.type is None + return _asterisk + + assert treespec.type is not None + assert callable(treespec._unflatten_func) + children_representations = [ + helper(subspec) for subspec in treespec._children + ] + if ( + treespec.type in BUILTIN_TYPES + or (treespec.type is type(None) and not self.none_is_leaf) + or optree.is_namedtuple_class(treespec.type) + or optree.is_structseq_class(treespec.type) + ): + # pyrefly: ignore [bad-return] + return treespec._unflatten_func( + treespec._metadata, + children_representations, + ) + return ( + f"CustomTreeNode({treespec.type.__name__}[{treespec._metadata!r}], " + f"[{', '.join(children_representations)}])" + ) + + inner = [ + str(helper(self)), + *(["NoneIsLeaf"] if self.none_is_leaf else []), + f"namespace={self.namespace!r}", + ] + return f"PyTreeSpec({', '.join(inner)})" + + def __len__(self) -> int: + return self.num_leaves + + @property + def type(self) -> builtins.type | None: + return self._type + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def children(self) -> list[PyTreeSpec]: + return list(self._children) + + def child(self, index: int) -> PyTreeSpec: + return self._children[index] + + def entries(self) -> list[Any]: + return list(self._entries) + + def entry(self, index: int) -> Any: + return self._entries[index] + + def flatten_up_to(self, tree: PyTree) -> list[PyTree]: + def helper( + treespec: PyTreeSpec, + node: PyTree, + subtrees: list[PyTree], + ) -> None: + if treespec.is_leaf(): + subtrees.append(node) + return + + node_type = type(node) + if treespec.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != treespec.type: + raise ValueError( + f"Type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + + children, metadata, *_ = optree.tree_flatten_one_level( + node, + none_is_leaf=self.none_is_leaf, + namespace=self.namespace, + ) + if len(children) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(children)}.", + ) + if metadata != treespec._metadata: + raise ValueError( + f"Node context mismatch for custom node type {treespec.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = ( + treespec.type in STANDARD_DICT_TYPES + and node_type in STANDARD_DICT_TYPES + ) + if not both_standard_dict and node_type != 
treespec.type: + raise ValueError( + f"Node type mismatch; " + f"expected {treespec.type!r}, but got {node_type!r}.", + ) + if len(node) != treespec.num_children: + raise ValueError( + f"Node arity mismatch; " + f"expected {treespec.num_children}, but got {len(node)}.", + ) + + if both_standard_dict: + # dictionary types are compatible with each other + expected_keys = treespec.entries() + got_key_set = set(node) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + children = [node[key] for key in expected_keys] + else: + # node_type is treespec.type children, metadata, *_ = optree.tree_flatten_one_level( node, none_is_leaf=self.none_is_leaf, namespace=self.namespace, ) - if len(children) != treespec.num_children: + if ( + node_type is not deque # ignore mismatch of `maxlen` for deque + ) and metadata != treespec._metadata: raise ValueError( - f"Node arity mismatch; " - f"expected {treespec.num_children}, but got {len(children)}.", - ) - if metadata != treespec._metadata: - raise ValueError( - f"Node context mismatch for custom node type {treespec.type!r}.", - ) - else: - # For builtin dictionary types, we allow some flexibility - # Otherwise, we require exact matches - both_standard_dict = ( - treespec.type in STANDARD_DICT_TYPES - and node_type in STANDARD_DICT_TYPES - ) - if not both_standard_dict and node_type != treespec.type: - raise ValueError( - f"Node type mismatch; " - f"expected {treespec.type!r}, but got {node_type!r}.", - ) - if len(node) != treespec.num_children: - raise ValueError( - f"Node arity mismatch; " - f"expected {treespec.num_children}, but got {len(node)}.", + f"Node metadata mismatch for node type {treespec.type!r}; " + f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch ) - if both_standard_dict: - # dictionary types are compatible with each other - expected_keys = treespec.entries() - got_key_set = set(node) - expected_key_set = set(expected_keys) - if got_key_set != expected_key_set: - missing_keys = expected_key_set.difference(got_key_set) - extra_keys = got_key_set.difference(expected_key_set) - message = "" - if missing_keys: - message += f"; missing key(s): {missing_keys}" - if extra_keys: - message += f"; extra key(s): {extra_keys}" - raise ValueError(f"Node keys mismatch{message}.") - children = [node[key] for key in expected_keys] - else: - # node_type is treespec.type - children, metadata, *_ = optree.tree_flatten_one_level( - node, - none_is_leaf=self.none_is_leaf, - namespace=self.namespace, - ) - if ( - node_type - is not deque # ignore mismatch of `maxlen` for deque - ) and metadata != treespec._metadata: - raise ValueError( - f"Node metadata mismatch for node type {treespec.type!r}; " - f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch - ) - - for subtree, subspec in zip(children, treespec._children): - helper(subspec, subtree, subtrees) - - subtrees: list[PyTree] = [] - helper(self, tree, subtrees) - return subtrees - - def unflatten(self, leaves: Iterable[Any]) -> PyTree: - if not isinstance(leaves, (list, tuple)): - leaves = list(leaves) - if len(leaves) != self.num_leaves: - raise ValueError( - f"treespec.unflatten(leaves): `leaves` has length 
{len(leaves)} " - f"but the spec refers to a pytree that holds {self.num_leaves} " - f"items ({self}).", - ) - if self.is_leaf(): - return leaves[0] - - # Recursively unflatten the children - start = 0 - end = 0 - subtrees = [] - for subspec in self._children: - end += subspec.num_leaves - subtrees.append(subspec.unflatten(leaves[start:end])) - start = end + for subtree, subspec in zip(children, treespec._children): + helper(subspec, subtree, subtrees) - assert callable(self._unflatten_func) - return self._unflatten_func(self._metadata, subtrees) + subtrees: list[PyTree] = [] + helper(self, tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + # Recursively unflatten the children + start = 0 + end = 0 + subtrees = [] + for subspec in self._children: + end += subspec.num_leaves + subtrees.append(subspec.unflatten(leaves[start:end])) + start = end + + assert callable(self._unflatten_func) + return self._unflatten_func(self._metadata, subtrees) + + +def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: + return isinstance(obj, PyTreeSpec) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_leaf, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_leaf( + *, + none_is_leaf: bool = False, + namespace: str = "", # unused +) -> PyTreeSpec: + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace="", + ) - def _is_pytreespec_instance(obj: Any, /) -> TypeIs[PyTreeSpec]: - return isinstance(obj, PyTreeSpec) - @substitute_in_graph( # type: ignore[arg-type] - optree.treespec_leaf, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. - can_constant_fold_through=False, +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_tuple, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. 
+ can_constant_fold_through=False, +) +def treespec_tuple( + iterable: Iterable[PyTreeSpec] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTreeSpec: + children = tuple(iterable) + if any(not _is_pytreespec_instance(child) for child in children): + raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.") + if any(child.none_is_leaf != none_is_leaf for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {children!r}.", + ) + if any(child.namespace not in (namespace, "") for child in children): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {children!r}.", + ) + handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined] + assert handler is not None + return PyTreeSpec( + tuple(children), + tuple, + None, + tuple(range(len(children))), + handler.unflatten_func, + none_is_leaf=none_is_leaf, + namespace=namespace, ) - def treespec_leaf( - *, - none_is_leaf: bool = False, - namespace: str = "", # unused - ) -> PyTreeSpec: - return PyTreeSpec( - (), - None, - None, - (), - None, - none_is_leaf=none_is_leaf, - namespace="", + + +@substitute_in_graph( # type: ignore[arg-type] + optree.treespec_dict, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def treespec_dict( + mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = "", + **kwargs: PyTreeSpec, +) -> PyTreeSpec: + dct = dict(mapping, **kwargs) + if any(not _is_pytreespec_instance(child) for child in dct.values()): + raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.") + if any(child.none_is_leaf != none_is_leaf for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `none_is_leaf` value " + f"as the parent; expected {none_is_leaf}, got: {dct!r}.", + ) + if any(child.namespace not in (namespace, "") for child in dct.values()): + raise ValueError( + "All children PyTreeSpecs must have the same `namespace` value " + f"as the parent; expected {namespace!r}, got: {dct!r}.", ) - @substitute_in_graph( # type: ignore[arg-type] - optree.treespec_tuple, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. 
- can_constant_fold_through=False, + ( + children, + metadata, + entries, + unflatten_func, + ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated] + dct, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, ) - def treespec_tuple( - iterable: Iterable[PyTreeSpec] = (), - /, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> PyTreeSpec: - children = tuple(iterable) - if any(not _is_pytreespec_instance(child) for child in children): - raise ValueError(f"Expected a tuple of PyTreeSpecs, got: {children!r}.") - if any(child.none_is_leaf != none_is_leaf for child in children): - raise ValueError( - "All children PyTreeSpecs must have the same `none_is_leaf` value " - f"as the parent; expected {none_is_leaf}, got: {children!r}.", - ) - if any(child.namespace not in (namespace, "") for child in children): - raise ValueError( - "All children PyTreeSpecs must have the same `namespace` value " - f"as the parent; expected {namespace!r}, got: {children!r}.", - ) - handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined] - assert handler is not None - return PyTreeSpec( - tuple(children), - tuple, - None, - tuple(range(len(children))), - handler.unflatten_func, + return PyTreeSpec( + tuple(children), # type: ignore[arg-type] + dict, + metadata, + entries, + unflatten_func, # type: ignore[arg-type] + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_flatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_flatten( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[Any], PyTreeSpec]: + def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: + if tree_is_leaf( + node, + is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, - ) - - @substitute_in_graph( # type: ignore[arg-type] - optree.treespec_dict, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. 
- can_constant_fold_through=False, - ) - def treespec_dict( - mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), - /, - *, - none_is_leaf: bool = False, - namespace: str = "", - **kwargs: PyTreeSpec, - ) -> PyTreeSpec: - dct = dict(mapping, **kwargs) - if any(not _is_pytreespec_instance(child) for child in dct.values()): - raise ValueError(f"Expected a dictionary of TreeSpecs, got: {dct!r}.") - if any(child.none_is_leaf != none_is_leaf for child in dct.values()): - raise ValueError( - "All children PyTreeSpecs must have the same `none_is_leaf` value " - f"as the parent; expected {none_is_leaf}, got: {dct!r}.", - ) - if any(child.namespace not in (namespace, "") for child in dct.values()): - raise ValueError( - "All children PyTreeSpecs must have the same `namespace` value " - f"as the parent; expected {namespace!r}, got: {dct!r}.", + ): + leaves.append(node) + return PyTreeSpec( + (), + None, + None, + (), + None, + none_is_leaf=none_is_leaf, + namespace=namespace, ) ( @@ -447,206 +507,167 @@ def treespec_dict( metadata, entries, unflatten_func, - ) = optree.tree_flatten_one_level( # type: ignore[assignment,var-annotated] - dct, # type: ignore[arg-type] + ) = optree.tree_flatten_one_level( + node, + is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, ) + + # Recursively flatten the children + subspecs = tuple(helper(child, leaves) for child in children) return PyTreeSpec( - tuple(children), # type: ignore[arg-type] - dict, + subspecs, + type(node), metadata, entries, unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, + ) # type: ignore[arg-type] + + leaves: list[Any] = [] + treespec = helper(tree, leaves) + return leaves, treespec + + +__all__ += ["tree_flatten"] + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_structure, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_structure( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTreeSpec: + return tree_flatten( # type: ignore[return-value] + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + )[1] + + +__all__ += ["tree_structure"] + + +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_unflatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " + f"PyTreeSpec but got item of type {type(treespec)}." ) - - @substitute_in_graph( # type: ignore[arg-type] - optree.tree_flatten, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. 
- can_constant_fold_through=False, - ) - def tree_flatten( - tree: PyTree, - /, - is_leaf: Callable[[PyTree], bool] | None = None, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> tuple[list[Any], PyTreeSpec]: - def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: - if tree_is_leaf( - node, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ): - leaves.append(node) - return PyTreeSpec( - (), - None, - None, - (), - None, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - - ( - children, - metadata, - entries, - unflatten_func, - ) = optree.tree_flatten_one_level( - node, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - - # Recursively flatten the children - subspecs = tuple(helper(child, leaves) for child in children) - return PyTreeSpec( - subspecs, - type(node), - metadata, - entries, - unflatten_func, # type: ignore[arg-type] - none_is_leaf=none_is_leaf, - namespace=namespace, - ) # type: ignore[arg-type] - - leaves: list[Any] = [] - treespec = helper(tree, leaves) - return leaves, treespec - - __all__ += ["tree_flatten"] - - @substitute_in_graph( # type: ignore[arg-type] - optree.tree_structure, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. - can_constant_fold_through=False, - ) - def tree_structure( - tree: PyTree, - /, - is_leaf: Callable[[PyTree], bool] | None = None, - *, - none_is_leaf: bool = False, - namespace: str = "", - ) -> PyTreeSpec: - return tree_flatten( # type: ignore[return-value] - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - )[1] - - __all__ += ["tree_structure"] - - @substitute_in_graph( # type: ignore[arg-type] - optree.tree_unflatten, - # We need to disable constant folding here because we want the function to reference the - # PyTreeSpec class defined above, not the one in the C++ module. - can_constant_fold_through=False, - ) - def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: - if not _is_pytreespec_instance(treespec): - raise TypeError( - f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " - f"PyTreeSpec but got item of type {type(treespec)}." 
- ) - return treespec.unflatten(leaves) - - __all__ += ["tree_unflatten"] - - @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] - def tree_map( - func: Callable[..., Any], - tree: PyTree, - /, - *rests: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - none_is_leaf: bool = False, - namespace: str = "", - ) -> PyTree: - leaves, treespec = tree_flatten( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - return treespec.unflatten(map(func, *flat_args)) - - __all__ += ["tree_map"] - - @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] - def tree_map_( - func: Callable[..., Any], - tree: PyTree, - /, - *rests: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - none_is_leaf: bool = False, - namespace: str = "", - ) -> PyTree: - leaves, treespec = tree_flatten( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable - return tree - - __all__ += ["tree_map_"] - - _none_registration = optree.register_pytree_node.get(type(None)) - assert _none_registration is not None - - @substitute_in_graph( # type: ignore[arg-type] - _none_registration.unflatten_func, - can_constant_fold_through=True, - skip_signature_check=True, - ) - def none_unflatten(_: None, children: Iterable[_T], /) -> None: - if len(list(children)) != 0: - raise ValueError("Expected no children.") - return None - - with optree.dict_insertion_ordered(False, namespace="torch"): - _dict_registration = optree.register_pytree_node.get(dict) - assert _dict_registration is not None - - @substitute_in_graph( # type: ignore[arg-type] - _dict_registration.flatten_func, - can_constant_fold_through=True, - skip_signature_check=True, + return treespec.unflatten(leaves) + + +__all__ += ["tree_unflatten"] + + +@substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_map( + func: Callable[..., Any], + tree: PyTree, + /, + *rests: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTree: + leaves, treespec = tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, ) - def dict_flatten( - dct: dict[_KT, _VT], / - ) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: - sorted_keys = optree.utils.total_order_sorted(dct) - values = [dct[key] for key in sorted_keys] - original_keys = list(dct) - return values, (original_keys, sorted_keys), tuple(sorted_keys) - - @substitute_in_graph( # type: ignore[arg-type] - _dict_registration.unflatten_func, - can_constant_fold_through=True, - skip_signature_check=True, + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + +__all__ += ["tree_map"] + + +@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_map_( + func: Callable[..., Any], + tree: PyTree, + /, + *rests: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTree: + leaves, treespec = tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, ) - def dict_unflatten( - metadata: tuple[list[_KT], 
list[_KT]], - values: Iterable[_VT], - /, - ) -> dict[_KT, _VT]: - original_keys, sorted_keys = metadata - d = dict.fromkeys(original_keys) - d.update(zip(sorted_keys, values)) - return d # type: ignore[return-value] + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable + return tree + + +__all__ += ["tree_map_"] + +_none_registration = optree.register_pytree_node.get(type(None)) +assert _none_registration is not None + + +@substitute_in_graph( # type: ignore[arg-type] + _none_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def none_unflatten(_: None, children: Iterable[_T], /) -> None: + if len(list(children)) != 0: + raise ValueError("Expected no children.") + return None + + +with optree.dict_insertion_ordered(False, namespace="torch"): + _dict_registration = optree.register_pytree_node.get(dict) + assert _dict_registration is not None + + +@substitute_in_graph( # type: ignore[arg-type] + _dict_registration.flatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def dict_flatten( + dct: dict[_KT, _VT], / +) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: + sorted_keys = optree.utils.total_order_sorted(dct) + values = [dct[key] for key in sorted_keys] + original_keys = list(dct) + return values, (original_keys, sorted_keys), tuple(sorted_keys) + + +@substitute_in_graph( # type: ignore[arg-type] + _dict_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, +) +def dict_unflatten( + metadata: tuple[list[_KT], list[_KT]], + values: Iterable[_VT], + /, +) -> dict[_KT, _VT]: + original_keys, sorted_keys = metadata + d = dict.fromkeys(original_keys) + d.update(zip(sorted_keys, values)) + return d # type: ignore[return-value] From 285748e838b411433030785805eba8b5f7bb043e Mon Sep 17 00:00:00 2001 From: chenlang Date: Fri, 7 Nov 2025 16:01:30 +0000 Subject: [PATCH 203/651] fix the cpp_builder error under riscv (#167071) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **fix the cpp_builder error under riscv** `g++: error: ‘-march=native’: ISA string must begin with rv32 or rv64` (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] File "/usr/local/lib64/python3.11/site-packages/torch/_inductor/cpp_builder.py", line 1718, in build (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] run_compile_cmd(build_cmd, cwd=_build_tmp_dir) (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] File "/usr/local/lib64/python3.11/site-packages/torch/_inductor/cpp_builder.py", line 401, in run_compile_cmd (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] _run_compile_cmd(cmd_line, cwd) (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] File "/usr/local/lib64/python3.11/site-packages/torch/_inductor/cpp_builder.py", line 396, in _run_compile_cmd (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] raise exc.CppCompileError(cmd, output) from e (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] torch._inductor.exc.InductorError: CppCompileError: C++ compile error (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] Command: (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] g++ /tmp/tmpv8qz53jp/header.hpp -D TORCH_INDUCTOR_CPP_WRAPPER -D STANDALONE_TORCH_HEADER -D C10_USING_CUSTOM_GENERATED_MACROS -fPIC -O3 
-DNDEBUG -fno-trapping-math -funsafe-math-optimizations -ffinite-math-only -fno-signed-zeros -fno-math-errno -fexcess-precision=fast -fno-finite-math-only -fno-unsafe-math-optimizations -ffp-contract=off -fno-tree-loop-vectorize -march=native -Wall -std=c++17 -Wno-unused-variable -Wno-unknown-pragmas -fopenmp -I/usr/include/python3.11 -I/usr/local/lib64/python3.11/site-packages/torch/include -I/usr/local/lib64/python3.11/site-packages/torch/include/torch/csrc/api/include -D_GLIBCXX_USE_CXX11_ABI=1 -E -P -o /tmp/tmpv8qz53jp/header.i (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] Output: (EngineCore_DP0 pid=14414) ERROR 11-04 18:36:01 [core.py:779] g++: error: ‘-march=native’: ISA string must begin with rv32 or rv64 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167071 Approved by: https://github.com/malfet --- torch/_inductor/cpp_builder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 8e072178099c6..9b2444fb5ef19 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -913,6 +913,10 @@ def _get_optimization_cflags( if not config.is_fbcode(): if platform.machine() == "ppc64le": cflags.append("mcpu=native") + elif platform.machine() == "riscv64": + cflags.append("march=rv64gc") + elif platform.machine() == "riscv32": + cflags.append("march=rv32gc") else: cflags.append("march=native") From 694592ac1ed5e86c3cf9bdd5f503d1ba0dfccf4c Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Fri, 7 Nov 2025 16:12:47 +0000 Subject: [PATCH 204/651] Move enrich_profiler_metadata config import out of gm.recompile() (#167114) Fixes T243967987 Move `enrich_profiler_metadata` from `torch._dynamo.config` to `torch.fx.experimental._config`. We cannot import anything inside recompile(); doing so caused a performance regression internally. We move the config so we can import it at the top of `graph_module.py` without causing any circular import. We also cannot delete the old config right now because some internal unit tests rely on copies of the old `graph_module.py` file. But I think we should be able to delete the old config soon after this PR lands.
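For reference, a minimal usage sketch of the relocated flag (illustrative only; it assumes the standard config-module `patch` helper installed by `install_config_module`, used here as a context manager just as the updated tests use it as a decorator):

```python
import torch
import torch.fx.experimental._config as fx_config


def f(x):
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(f)

# Toggle the flag at its new location; recompile() now reads it from
# torch.fx.experimental._config instead of torch._dynamo.config.
with fx_config.patch("enrich_profiler_metadata", True):
    gm.recompile()
```

The old `torch._dynamo.config.enrich_profiler_metadata` entry remains only as a deprecated placeholder until the internal copies are cleaned up.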
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167114 Approved by: https://github.com/angelayi --- test/test_cuda.py | 2 +- test/test_fx.py | 6 +++--- torch/_dynamo/config.py | 7 ++----- torch/fx/experimental/_config.py | 8 +++++++- torch/fx/graph_module.py | 10 ++++++---- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 1d3ff12b4b6ea..5842b0eda7422 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -7486,7 +7486,7 @@ def collect_frames( return fx_frames @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_fx_memory_profiler_augmentation(self): """Test that memory snapshots are augmented with FX debug information.""" diff --git a/test/test_fx.py b/test/test_fx.py index 0b177a96ae0b0..71299ddb2400d 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4251,7 +4251,7 @@ def fn(a, b, c, d): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_profiler_stack_trace_augmentation(self): """ Test that map_recorded_events_to_aten_ops_with_stack_trace correctly @@ -4307,7 +4307,7 @@ def forward(self, x): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_profiler_multiple_modules(self): """ Test that multiple compiled modules under the same profiler session @@ -4351,7 +4351,7 @@ def forward(self, x): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @skipIfRocm - @torch._dynamo.config.patch("enrich_profiler_metadata", True) + @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) def test_profiler_nested_graph_modules(self): """ Test that nested graph modules (e.g., graph modules calling subgraphs) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 8682ac1cb3a44..0355c4670344e 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -740,11 +740,8 @@ def default_debug_dir_root() -> str: # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None -# Experimental: If True, graph module will register fx metadata during recompile() -enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] - default=False, - env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", -) +# Deprecated! Please use the config in torch/fx/experimental/_config instead. +enrich_profiler_metadata: bool = False if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/fx/experimental/_config.py b/torch/fx/experimental/_config.py index ce4296b6410c9..a537978db3834 100644 --- a/torch/fx/experimental/_config.py +++ b/torch/fx/experimental/_config.py @@ -2,6 +2,8 @@ import sys from typing import Optional +from torch.utils._config_module import Config, install_config_module + # [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors. no_data_dependent_graph_break = ( @@ -100,7 +102,11 @@ # Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. 
skip_dtype_check_in_meta_registrations = False -from torch.utils._config_module import install_config_module +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) install_config_module(sys.modules[__name__]) diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8360c96630d6c..ab33d7bf321c9 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -20,6 +20,7 @@ from torch.package import Importer, PackageExporter, PackageImporter, sys_importer from ._compatibility import compatibility +from .experimental import _config as fx_experimental_config from .graph import ( _BoxedCodeGen, _custom_builtins, @@ -858,14 +859,15 @@ def recompile(self) -> PythonCode: called after editing the contained ``graph``, otherwise the generated code of this ``GraphModule`` will be out of date. """ + # Do not import anything inside recompile, it might slow down the + # function and cause perf regression. Import outside of the method instead. if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - from torch._dynamo import config as dynamo_config - python_code = self._graph.python_code( - root_module="self", record_func=dynamo_config.enrich_profiler_metadata + root_module="self", + record_func=fx_experimental_config.enrich_profiler_metadata, ) self._code = python_code.src self._lineno_map = python_code._lineno_map @@ -874,7 +876,7 @@ def recompile(self) -> PythonCode: cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - if dynamo_config.enrich_profiler_metadata: + if fx_experimental_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation node_metadata: dict[int, dict[str, Any]] = {} for i, node in enumerate(self._graph.nodes): From 12860892f825d5d8d73c9ea18549dc008f7977a0 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 16:45:23 +0000 Subject: [PATCH 205/651] Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167182)" This reverts commit 77b70970f70d53de71b9703ad4c3199d714c535a. 
Reverted https://github.com/pytorch/pytorch/pull/167182 on behalf of https://github.com/NikhilAPatel due to breaks local source build ([comment](https://github.com/pytorch/pytorch/pull/167182#issuecomment-3503598156)) --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 - setup.py | 34 -- test/inductor/test_cutedsl_grouped_mm.py | 154 -------- torch/_inductor/config.py | 4 - torch/_inductor/kernel/mm_common.py | 7 - torch/_inductor/kernel/mm_grouped.py | 90 ++--- .../templates/cutedsl_mm_grouped.py.jinja | 333 ------------------ .../_inductor/template_heuristics/cutedsl.py | 141 -------- torch/_inductor/utils.py | 78 ---- 10 files changed, 33 insertions(+), 811 deletions(-) delete mode 100644 test/inductor/test_cutedsl_grouped_mm.py delete mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja delete mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9ae2578758939..26996b5a32d56 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index 3b4323051073a..d1b3b17445dac 100644 --- a/.gitignore +++ b/.gitignore @@ -127,7 +127,6 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py -torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index dd8a52cbeb7c7..31e78d0245d93 100644 --- a/setup.py +++ b/setup.py @@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") -def mirror_inductor_external_kernels() -> None: - """ - Copy external kernels into Inductor so they are importable. - """ - paths = [ - ( - CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", - CWD - / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", - ), - ] - for new_path, orig_path in paths: - # Create the dirs involved in new_path if they don't exist - if not new_path.exists(): - new_path.parent.mkdir(parents=True, exist_ok=True) - - # Copy the files from the orig location to the new location - if orig_path.is_file(): - shutil.copyfile(orig_path, new_path) - continue - if orig_path.is_dir(): - if new_path.exists(): - # copytree fails if the tree exists already, so remove it. 
- shutil.rmtree(new_path) - shutil.copytree(orig_path, new_path) - continue - raise RuntimeError( - "Check the file paths in `mirror_inductor_external_kernels()`" - ) - - # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1647,8 +1616,6 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() - mirror_inductor_external_kernels() - ( ext_modules, cmdclass, @@ -1682,7 +1649,6 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", - "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py deleted file mode 100644 index c26def3a54099..0000000000000 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ /dev/null @@ -1,154 +0,0 @@ -# Owner(s): ["module: inductor"] - - -import unittest - -import torch -from torch import Tensor -from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch -from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch._inductor.utils import ensure_cute_available -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -@unittest.skipIf( - not (ensure_cute_available() and is_datacenter_blackwell_arch()), - "CuTeDSL library or Blackwell device not available", -) -@instantiate_parametrized_tests -class TestCuTeDSLGroupedGemm(InductorTestCase): - def _get_inputs( - self, - group_size: int, - M_hint: int, - K: int, - N: int, - device: str, - dtype: torch.dtype, - alignment: int = 16, - ) -> tuple[Tensor, Tensor, Tensor]: - # --- Random, tile-aligned M sizes --- - M_sizes = ( - torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) - * alignment - ) - - M_total = torch.sum(M_sizes).item() - - # --- Construct input tensors --- - A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 - B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 - - # --- Build offsets (no leading zero, strictly increasing) --- - offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) - - return (A, B, offsets) - - @parametrize("group_size", (2, 8)) - @parametrize("M_hint", (256, 1024)) - @parametrize("K", (64, 128)) - @parametrize("N", (128, 256)) - def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): - device = "cuda" - dtype = torch.bfloat16 - - A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # Eager execution - c_eager = grouped_gemm_fn(A, B, offsets) - - # Test with Cute backend - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) - @parametrize("layout_B", ("contiguous", "broadcasted")) - def 
test_grouped_gemm_assorted_layouts( - self, - layout_A: str, - layout_B: str, - ): - device = "cuda" - dtype = torch.bfloat16 - - G, K, N = 8, 64, 128 - M_sizes = [128] * G - sum_M = sum(M_sizes) - offsets = torch.tensor( - [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device - ) - - A_base = torch.randn(sum_M, K, device=device, dtype=dtype) - A = A_base - - if layout_A == "offset": - # allocate bigger buffer than needed, use nonzero storage offset - storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) - offset = 128 # skip first 128 elements - A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) - elif layout_A == "padded": - # simulate row pitch > K (row_stride = K + pad) - row_pitch = K + 8 - storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) - A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) - elif layout_A == "view": - A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) - A = A_storage.view(sum_M, K) - assert A._base is not None - assert A.shape == (sum_M, K) - - B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 - - if layout_B == "broadcasted": - # Broadcast B across groups (zero stride along G) - B = B[0].expand(G, K, N) - assert B.stride(0) == 0 - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # --- eager --- - c_eager = grouped_gemm_fn(A, B, offsets) - - # --- compiled (CUTE backend) --- - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e00cacb59abe6..3eaa840961fa8 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -550,10 +550,6 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] -cutedsl_enable_autotuning: bool = ( - os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" -) - # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index eb22b95af2afc..b95073e769f31 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from functools import partial -from pathlib import Path from typing import Any import torch @@ -14,7 +12,6 @@ from .. 
import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import load_template log = logging.getLogger(__name__) @@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True - - -_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" -load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index c81ec607661bc..881c14fd43d0d 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,13 +1,11 @@ # mypy: allow-untyped-defs import logging -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters -from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl -from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -24,13 +22,11 @@ get_num_sms, has_free_symbols, use_aten_gemm_kernels, - use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, - load_kernel_template, persistent_grouped_mm_grid, ) @@ -517,11 +513,6 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) -cutedsl_grouped_mm_template = CuteDSLTemplate( - name="grouped_gemm_cutedsl", - source=load_kernel_template("cutedsl_mm_grouped"), -) - def grouped_mm_args( mat1: TensorBox, @@ -723,44 +714,43 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
- if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False - if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -798,22 +788,6 @@ def _tuned_grouped_mm_common( **config.kwargs, ) - if use_blackwell_cutedsl_grouped_mm( - mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result - ): - for config in get_groupgemm_configs(): - kwargs = dict( - ACC_DTYPE="cutlass.Float32", - ) - - cutedsl_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - **asdict(config), - ) - input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja deleted file mode 100644 index 989f297c5f80f..0000000000000 --- a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja +++ /dev/null @@ -1,333 +0,0 @@ -import functools -from torch._inductor.runtime.runtime_utils import ceildiv -from cutlass.utils import TensorMapUpdateMode -{{gen_defines()}} -# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- -from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( - GroupedGemmKernel, -) - - -# Note about caching: -# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor -# maintains its own local caching system. 
At this stage, all compile-time -# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel -# name itself ({{kernel_name}}) are permanently baked into the file, so they -# do not need to be included in any cache key. -# -# The caching mechanism is split into two levels: -# -# 1. prep_cache -# Caches the compiled executor for build_group_ptrs_from_bases(). This -# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, -# and can therefore be safely reused across runs with different group -# partitioning (`offs`). -# -# 2. gemm_cache -# Caches the compiled Grouped GEMM executor. Its key extends the prep -# cache key with hardware- and grid-specific parameters: -# (prep_cache_key, max_active_clusters, total_num_clusters). -# This is necessary because different `offs` tensors can change the -# per-group problem sizes and thus alter `total_num_clusters`, which in -# turn changes the grid shape and persistent scheduler configuration. -# Kernels compiled for one grid cannot be safely reused for another. -# -# -# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, -# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, -# despite depending only on the GPU type. We cache this function to mitigate -# redundant recompiles even when shape/stride/dtype cache misses force kernel -# regeneration. A follow-up study will investigate the root cause. - -prep_cache = {} -gemm_cache = {} - - -@functools.lru_cache -def get_hardware_info(): - hw = cutlass.utils.HardwareInfo() - sm_count = hw.get_max_active_clusters(1) - max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) - - return (sm_count, max_active_clusters) - - -def get_prep_cache_key(input_a, input_b, output): - """ - Returns a tuple key for caching the preprocessing kernel executor based on kernel name, - shapes, strides, and dtypes of input/output tensors. - """ - return ( - tuple(input_a.shape), - tuple(input_a.stride()), - input_a.dtype, - tuple(input_b.shape), - tuple(input_b.stride()), - input_b.dtype, - tuple(output.shape), - tuple(output.stride()), - output.dtype, - ) - - -def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): - """ - Returns a tuple key for caching the gemm kernel executor by extending the - prep cache key with hardware- and grid-specific parameters. 
- """ - return ( - prep_cache_key, - max_active_clusters, - total_num_clusters, - ) - - -@cute.kernel -def build_group_ptrs_from_bases_kernel( - base_A_u64: cutlass.Int64, # device addr of input_a (bytes) - base_B_u64: cutlass.Int64, # device addr of input_b (bytes) - base_C_u64: cutlass.Int64, # device addr of Output (bytes) - offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Int32, # bytes - # -------- STRIDES (in ELEMENTS) -------- - stride_A_m_elems: cutlass.Constexpr, # A.stride(0) - stride_A_k_elems: cutlass.Constexpr, # A.stride(1) - stride_B0_elems: cutlass.Constexpr, # B.stride(0) - stride_Bk_elems: cutlass.Constexpr, # B.stride(1) - stride_Bn_elems: cutlass.Constexpr, # B.stride(2) - stride_C_m_elems: cutlass.Constexpr, # C.stride(0) - stride_C_n_elems: cutlass.Constexpr, # C.stride(1) - # -------- OUTPUTS -------- - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) - out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) - out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] -): - tidx, _, _ = cute.arch.thread_idx() - g = tidx - - m_beg_i32 = 0 - if g > 0: - m_beg_i32 = offs[g - 1] - m_end_i32 = offs[g] - m_g_i32 = m_end_i32 - m_beg_i32 - - a_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) - ) - c_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) - ) - b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) - - # ---- pointers ---- - out_ptrs[g, 0] = base_A_u64 + a_byte_off - out_ptrs[g, 1] = base_B_u64 + b_byte_off - out_ptrs[g, 2] = base_C_u64 + c_byte_off - - # ---- (m, n, k, 1) ---- - out_problem[g, 0] = m_g_i32 - out_problem[g, 1] = N - out_problem[g, 2] = K - out_problem[g, 3] = cutlass.Int32(1) - - # ---- strides ---- - out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) - out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) - out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) - out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) - out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) - out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) - - -@cute.jit -def launch_build_group_ptrs_from_bases( - base_A_u64: cutlass.Int64, - base_B_u64: cutlass.Int64, - base_C_u64: cutlass.Int64, - offs: cute.Tensor, - G: cutlass.Constexpr, - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Constexpr, - stride_A_m_elems: cutlass.Constexpr, - stride_A_k_elems: cutlass.Constexpr, - stride_B0_elems: cutlass.Constexpr, - stride_Bk_elems: cutlass.Constexpr, - stride_Bn_elems: cutlass.Constexpr, - stride_C_m_elems: cutlass.Constexpr, - stride_C_n_elems: cutlass.Constexpr, - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 - out_problem: cute.Tensor, # [G,4] cutlass.Int32 - out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 - stream: cuda.CUstream, -): - build_group_ptrs_from_bases_kernel( - base_A_u64, - base_B_u64, - base_C_u64, - offs, - K, - N, - sizeof_element, - stride_A_m_elems, - stride_A_k_elems, - stride_B0_elems, - stride_Bk_elems, - stride_Bn_elems, - stride_C_m_elems, - stride_C_n_elems, - out_ptrs, - out_problem, - out_strides_abc, - ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) - - -{{def_kernel("input_a", "input_b", "input_a_offs")}} - stream = cuda.CUstream(stream) - - input_b = input_b.transpose(1, 2) - - sumM, K = input_a.shape - G, N, Kb = input_b.shape - - dev = 
input_a.device - - base_A_u64 = int(input_a.data_ptr()) - base_B_u64 = int(input_b.data_ptr()) - base_C_u64 = int({{get_output()}}.data_ptr()) - - ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) - probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) - strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) - ptrs = from_dlpack(ptrs_t) - probs = from_dlpack(probs_t) - strides = from_dlpack(strides_t) - - prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) - prep_executor = prep_cache.get(prep_cache_key) - - if prep_executor is None: - sizeof_element = int(input_a.element_size()) - sA_m, sA_k = map(int, input_a.stride()) - sB_0, sB_n, sB_k = map(int, input_b.stride()) - sC_m, sC_n = map(int, {{get_output()}}.stride()) - - prep_executor = cute.compile( - launch_build_group_ptrs_from_bases, - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - G=int(G), - K=int(K), - N=int(N), - sizeof_element=sizeof_element, - stride_A_m_elems=sA_m, - stride_A_k_elems=sA_k, - stride_B0_elems=sB_0, - stride_Bk_elems=sB_k, - stride_Bn_elems=sB_n, - stride_C_m_elems=sC_m, - stride_C_n_elems=sC_n, - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - prep_cache[prep_cache_key] = prep_executor - - prep_executor( - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - # --- Tensormap workspace per SM --- - num_tensormap_buffers, max_active_clusters = get_hardware_info() - tensormap_shape = ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ) - tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) - tensormap_workspace = from_dlpack(tensormap_workspace_t) - - # --- Total clusters --- - def compute_total_num_clusters( - problem_sizes_mnkl, - cluster_tile_shape_mn, - ): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) - ) - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - # Compute cluster tile shape - def compute_cluster_tile_shape( - mma_tiler_mn, - cluster_shape_mn, - use_2cta_instrs, - ): - cta_tile_shape_mn = list(mma_tiler_mn) - if use_2cta_instrs: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cluster_tile_shape_mn = compute_cluster_tile_shape( - (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) - ) - - total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) - - gemm_cache_key = get_gemm_cache_key( - prep_cache_key, max_active_clusters, total_num_clusters - ) - gemm_executor = gemm_cache.get(gemm_cache_key) - - if gemm_executor is None: - grouped_gemm = GroupedGemmKernel( - acc_dtype=ACC_DTYPE, - use_2cta_instrs=USE_2_CTA, - mma_tiler_mn=(TILE_M, TILE_N), - cluster_shape_mn=(CLUSTER_M, CLUSTER_N), - tensormap_update_mode=TENSORMAP_UPDATE_MODE, - ) - - gemm_executor = cute.compile( - grouped_gemm, - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - G, - probs, - strides, - ptrs, - total_num_clusters, - tensormap_workspace, - 
max_active_clusters, - stream, - ) - - gemm_cache[gemm_cache_key] = gemm_executor - - gemm_executor( - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - probs, - strides, - ptrs, - tensormap_workspace, - stream, - ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py deleted file mode 100644 index db337b9d8a271..0000000000000 --- a/torch/_inductor/template_heuristics/cutedsl.py +++ /dev/null @@ -1,141 +0,0 @@ -from dataclasses import dataclass -from enum import auto, Enum -from itertools import product - -import torch._inductor.config as config - - -class TensorMapUpdateMode(Enum): - """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" - - SMEM = auto() - GMEM = auto() - - -@dataclass(frozen=True) -class CuTeGemmConfig: - TILE_M: int = 128 - TILE_N: int = 192 - CLUSTER_M: int = 2 - CLUSTER_N: int = 1 - USE_2_CTA: bool = False - TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM - - -def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - For information regarding valid config sets, see: - https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py - """ - - # Tile_n is always the same regardless of 2cta - tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] - - # Valid clusters - clusters_no_2cta = [ - (1, 1), - (1, 2), - (1, 4), - (1, 8), - (1, 16), - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - clusters_2cta = [ - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - - configs: list[CuTeGemmConfig] = [] - - for use_2cta, cluster_set, tile_m_range in [ - (False, clusters_no_2cta, [64, 128]), - (True, clusters_2cta, [128, 256]), - ]: - for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( - [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], - tile_m_range, - tile_n_vals, - cluster_set, - ): - configs.append( - CuTeGemmConfig( - tile_m, - tile_n, - cluster_m, - cluster_n, - USE_2_CTA=use_2cta, - TENSORMAP_UPDATE_MODE=tensormap_update_mode, - ) - ) - - return configs - - -def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
- """ - - config_tuples = [ - (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), - (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), - (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), - (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), - (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), - (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - ] - - return [CuTeGemmConfig(*args) for args in config_tuples] - - -def get_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - - Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures - or unstable results. By default, autotuning is disabled and we return only - a single baseline config. - """ - if ( - config.cutedsl_enable_autotuning - and config.max_autotune_gemm_search_space == "EXHAUSTIVE" - ): - return get_exhaustive_groupgemm_configs() - elif config.cutedsl_enable_autotuning: - return get_default_groupgemm_configs() - else: - return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 2b7a9541aa875..11be081db1be7 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1911,84 +1911,6 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() -@functools.lru_cache(maxsize=1) -def ensure_cute_available() -> bool: - """Check if CuTeDSL is importable; cache the result for reuse. - - Call ensure_cute_available.cache_clear() after installing CuTeDSL - in the same interpreter to retry the import. - """ - try: - return importlib.util.find_spec("cutlass.cute") is not None - except ImportError: - return False - - -def use_blackwell_cutedsl_grouped_mm( - mat_a: Any, - mat_b: Any, - layout: Layout, - a_is_2d: bool, - b_is_2d: bool, - offs: Optional[Any], - bias: Optional[Any], - scale_result: Optional[Any], -) -> bool: - """ - Returns True if we can use the blackwell kernel for grouped mm. - Required conditions: - 1. CuTeDSL backend is enabled - 2. CuTeDSL is available - 3. We are on a blackwell arch - 4. The dtype is bf16 - 5. Max autotune or max autotune gemm is enabled - 6. A, B, and the output are 16B aligned - 7. We are not using dynamic shapes - 8. A is 2d - 9. B is 3d - 10. Offsets are provided - 11. 
Bias and Scale are not provided - """ - if not ensure_cute_available(): - return False - - if not _use_autotune_backend("CUTEDSL"): - return False - - from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch - - if not is_gpu(layout.device.type): - return False - - if not is_datacenter_blackwell_arch(): - return False - - layout_dtypes = [torch.bfloat16] - if not _use_template_for_gpu(layout, layout_dtypes): - return False - - if not (config.max_autotune or config.max_autotune_gemm): - return False - - # Checks for 16B ptr and stride alignment - if not can_use_tma(mat_a, mat_b, output_layout=layout): - return False - - if any(is_dynamic(x) for x in [mat_a, mat_b]): - return False - - if not a_is_2d or b_is_2d: - return False - - if offs is None: - return False - - if bias is not None or scale_result is not None: - return False - - return True - - def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From 86db4de10f7c81ee2d36bbdbfb9cf1573a01b95e Mon Sep 17 00:00:00 2001 From: Sanket Jayant Purandare Date: Fri, 7 Nov 2025 17:11:14 +0000 Subject: [PATCH 206/651] [PP] PP Runtime Features for supporting Graph Based execution (#167277) Allow overriding UNSHARD, RESHARD and REDUCE_GRAD actions. Enable running pp backward without torch.grad.is_enabled(). Pull Request resolved: https://github.com/pytorch/pytorch/pull/167277 Approved by: https://github.com/wconstab --- torch/distributed/pipelining/schedules.py | 29 ++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index d84857ef474af..abc007a8166db 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -1485,6 +1485,7 @@ def __init__( output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, use_full_backward: Optional[bool] = None, scale_grads: bool = True, + backward_requires_autograd: bool = True, ): # Init parent super().__init__( @@ -1517,6 +1518,11 @@ def __init__( # This will be set during init of derived schedules self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + # When using a custom backward function, we may or may not need autograd to be used + # for the backward pass. This flag is used to determine whether or torch.is_grad_enabled() + # check should be performed before the step function. + self._backward_requires_autograd = backward_requires_autograd + if use_full_backward is not None: logger.warning( "Deprecation warning: 'use_full_backward' is no longer supported. " @@ -1609,7 +1615,11 @@ def step( losses: a list to store the losses for each microbatch. return_outputs: whether to return the outputs from the last stage. """ - if self._has_backward and not torch.is_grad_enabled(): + if ( + self._has_backward + and self._backward_requires_autograd + and not torch.is_grad_enabled() + ): raise RuntimeError( "step() requires gradients to be enabled for backward computation; " "it should not be used under torch.no_grad() context. " @@ -1891,7 +1901,7 @@ def register_custom_function( Args: computation_type: The computation type for which to register the custom function custom_function: The function to execute when this computation type is encountered. 
- Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None + Must have signature: (action: _Action, ctx: _PipelineContext) -> None """ # Ensure that the computation type is valid if computation_type not in ( @@ -1900,10 +1910,13 @@ def register_custom_function( BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, + UNSHARD, + RESHARD, + REDUCE_GRAD, ): raise ValueError( f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \ -BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported." + BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported." ) # Check if computation_type is already registered @@ -2296,6 +2309,7 @@ def __init__( loss_fn: Optional[Union[Callable, _Loss]] = None, output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, + backward_requires_autograd: bool = True, ): super().__init__( stages=stages, @@ -2303,6 +2317,7 @@ def __init__( loss_fn=loss_fn, output_merge_spec=output_merge_spec, scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, ) # 1. Create the pipeline_order (all ranks do this calculation) @@ -2510,6 +2525,7 @@ def __init__( kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, + backward_requires_autograd: bool = True, ): self.pp_group_size = stages[0].group_size super().__init__( @@ -2520,6 +2536,7 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, ) self.n_local_stages = len(stages) self.rank = stages[0].group_rank @@ -2622,6 +2639,7 @@ def __init__( kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, + backward_requires_autograd: bool = True, ): # TODO: we dont support input/weight backward split with torch.compile _check_torch_compile_compatibility(stages, self.__class__.__name__) @@ -2634,6 +2652,7 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, ) self.n_local_stages = len(stages) self.rank = stages[0].group_rank @@ -2819,6 +2838,7 @@ def __init__( kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, + backward_requires_autograd: bool = True, ): # TODO: we dont support input/weight backward split with torch.compile _check_torch_compile_compatibility(stages, self.__class__.__name__) @@ -2831,6 +2851,7 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, ) self.stage_index_to_group_rank = generate_stage_to_rank_mapping( self.pp_group_size, self._num_stages, style="v" @@ -2995,6 +3016,7 @@ def __init__( kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, scale_grads: bool = True, + backward_requires_autograd: bool = True, ): # TODO: we dont support input/weight backward split with torch.compile _check_torch_compile_compatibility(stages, self.__class__.__name__) @@ -3007,6 +3029,7 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, 
output_merge_spec=output_merge_spec, scale_grads=scale_grads, + backward_requires_autograd=backward_requires_autograd, ) self.stage_index_to_group_rank = generate_stage_to_rank_mapping( self.pp_group_size, self._num_stages, style="v" From ccc8c117dcf83a4096fde0b89e6e038f2605316f Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Wed, 5 Nov 2025 13:28:09 -0800 Subject: [PATCH 207/651] Codeowner/Labeler updates post-Blas-reorgs (#167130) Summary: Previous PRs have split out scaled/grouped Blas routines into their own files. This updates the codeowners and labeler to reflect those changes. Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton Pull Request resolved: https://github.com/pytorch/pytorch/pull/167130 Approved by: https://github.com/drisspg --- .github/labeler.yml | 9 ++++++--- CODEOWNERS | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index 7b47b9fefb5dc..246ddd8614396 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -138,7 +138,8 @@ - test/test_matmul_cuda.py - test/test_scaled_matmul_cuda.py - test/inductor/test_fp8.py -- aten/src/ATen/native/cuda/Blas.cpp +- aten/src/ATen/native/cuda/*Blas.cpp +- aten/src/ATen/cuda/CUDA*Blas.* - torch/**/*cublas* - torch/_inductor/kernel/mm.py - test/inductor/test_max_autotune.py @@ -148,7 +149,8 @@ - test/test_matmul_cuda.py - test/test_scaled_matmul_cuda.py - test/inductor/test_fp8.py -- aten/src/ATen/native/cuda/Blas.cpp +- aten/src/ATen/native/cuda/*Blas.cpp +- aten/src/ATen/cuda/CUDA*Blas.* - torch/**/*cublas* - torch/_inductor/kernel/mm.py - test/inductor/test_max_autotune.py @@ -158,7 +160,8 @@ - test/test_matmul_cuda.py - test/test_scaled_matmul_cuda.py - test/inductor/test_fp8.py -- aten/src/ATen/native/cuda/Blas.cpp +- aten/src/ATen/native/cuda/*Blas.cpp +- aten/src/ATen/cuda/CUDA*Blas.* - torch/_inductor/kernel/mm.py - test/inductor/test_max_autotune.py - third_party/fbgemm diff --git a/CODEOWNERS b/CODEOWNERS index cc249dc4f43a2..137031066090e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A /test/inductor/test_flex_attention.py @drisspg /test/inductor/test_flex_decoding.py @drisspg -# Low Precision GEMMs +# Low Precision & Grouped GEMMs /aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58 +/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58 +/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58 /aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58 /aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58 +/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58 +/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58 /test/test_scaled_matmul_cuda.py @drisspg @slayton58 From b62935d1a53f5963f74119bdec64fdcded2bbcb6 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 7 Nov 2025 06:00:42 -0800 Subject: [PATCH 208/651] fix alpha beta in decomp (#167317) fix for https://github.com/pytorch/pytorch/issues/167313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167317 Approved by: https://github.com/zou3519 ghstack dependencies: #161404 --- test/inductor/test_pattern_matcher.py | 37 ++++++++++++++++++++++++++ torch/_inductor/fx_passes/post_grad.py | 22 +++++++++++---- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 4b8c866b9c291..4d0ad99978888 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -1217,6 
+1217,43 @@ def fn2(inp, a, b): _, (code) = run_and_get_code(fn2, args[0], args[1], args[2]) FileCheck().check_not("extern_kernels.addmm(").run(code[0]) + def test_addmm_alpha_beta_with_pointwise(self): + # Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops + # See https://github.com/pytorch/pytorch/issues/167313 + x = torch.rand(2, device=GPU_TYPE) + a = torch.rand(2, 3, device=GPU_TYPE) + b = torch.rand(3, 2, device=GPU_TYPE) + + def f(x, a, b): + return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2)) + + fc = torch.compile(f) + + expected = f(x, a, b) + actual = fc(x, a, b) + + # The compiled version should produce the same result as eager + torch.testing.assert_close(actual, expected) + + # Verify that addmm is unfused (should not use extern_kernels.addmm) + # The pattern should be replaced with beta * x + alpha * (a @ b) + _, (code) = run_and_get_code(fc, x, a, b) + FileCheck().check_not("extern_kernels.addmm(").run(code[0]) + + # Test with alpha=1, beta=1 (default) - should also unfuse + def f_default(x, a, b): + return torch.nn.functional.relu(torch.addmm(x, a, b)) + + fc_default = torch.compile(f_default) + expected_default = f_default(x, a, b) + actual_default = fc_default(x, a, b) + + torch.testing.assert_close(actual_default, expected_default) + + # Should unfuse and not use extern_kernels.addmm + _, (code) = run_and_get_code(fc_default, x, a, b) + FileCheck().check_not("extern_kernels.addmm(").run(code[0]) + def test_serialized_patterns_up_to_date(self): import torch.utils._pytree as pytree from torch._inductor.fx_passes import joint_graph diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 9808c6944e13c..9c7c01c785f4e 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1516,17 +1516,29 @@ def should_prefer_unfused_addmm(match): @register_graph_pattern( - CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()), + CallFunction( + aten.addmm, + KeywordArg("inp"), + Arg(), + Arg(), + beta=KeywordArg("beta"), + alpha=KeywordArg("alpha"), + ), # pyrefly: ignore [bad-argument-type] pass_dict=pass_patterns[2], extra_check=should_prefer_unfused_addmm, ) -def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): - def repl(inp, x1, x2): - return x1 @ x2 + inp +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta): + def repl(inp, x1, x2, alpha, beta): + mm_result = x1 @ x2 + if alpha != 1: + mm_result = alpha * mm_result + if beta != 1: + inp = beta * inp + return inp + mm_result # pyrefly: ignore [bad-argument-type] - match.replace_by_example(repl, [inp, mat1, mat2]) + match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta]) def is_valid_addmm_fusion(match): From 724cd32b0cab03c54fd1f288f7555cfe0544c531 Mon Sep 17 00:00:00 2001 From: Malay Bag Date: Fri, 7 Nov 2025 17:48:17 +0000 Subject: [PATCH 209/651] [PT2 Compiler] Add flag in dynamo disable wrapper to indicate reursive disable (#165790) Summary: After torch._dynamo.disable is applied, wrapped method does not have any flag to indicate whether it was disabled recursively or not. 
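After this change the wrapper itself records which mode was used; a minimal sketch of querying it,
relying only on the `is_dynamo_disable_recursive` helper and the `_torchdynamo_disable_recursive`
attribute added in this patch:

```
import torch

def f(x):
    return torch.sin(x)

recursive_wrapper = torch._dynamo.disable(f)                  # recursive=True by default
flat_wrapper = torch._dynamo.disable(f, recursive=False)

# True  -> disabled with recursive=True
# False -> disabled with recursive=False
# None  -> not a dynamo-disable wrapper at all
assert torch._dynamo.is_dynamo_disable_recursive(recursive_wrapper) is True
assert torch._dynamo.is_dynamo_disable_recursive(flat_wrapper) is False
assert torch._dynamo.is_dynamo_disable_recursive(f) is None
```
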
This flag is needed if to preserve dynamo disable methods in torch.export-ed model Test Plan: ``` buck test mode/opt caffe2/test/dynamo:test_dynamo -- 'test_disable_recursive_flags' ```` https://www.internalfb.com/intern/testinfra/testrun/7599824674075603 Differential Revision: D84949143 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165790 Approved by: https://github.com/angelayi, https://github.com/williamwen42 --- test/dynamo/test_decorators.py | 46 ++++++++++++++++++++++++++++++++++ torch/_dynamo/__init__.py | 2 ++ torch/_dynamo/decorators.py | 11 ++++++++ torch/_dynamo/eval_frame.py | 2 ++ 4 files changed, 61 insertions(+) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 0eb21c9cef068..68a10360284dc 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -2109,6 +2109,52 @@ def outer_f2(x): with self.assertRaises(Unsupported): outer_f2(inp) + def test_disable_recursive_flags(self): + class SimpleLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer0 = torch.nn.Linear(4, 4) + + def forward(self, inp): + return self.layer0(torch.sigmoid(inp)) + + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer0 = SimpleLinear() + self.layer1 = torch.nn.Linear(4, 4) + + def forward(self, inp): + z = self.layer0(torch.sin(inp)) + return self.layer1(z) + + for recursive_flag in [True, False]: + model = SimpleModel() + other_model = SimpleModel() + + model.forward = torch._dynamo.disable( + model.forward, + recursive=recursive_flag, + ) + self.assertEqual( + torch._dynamo.is_dynamo_disable_recursive(model.forward), + recursive_flag, + ) + + other_model = torch._dynamo.disable(other_model, recursive=recursive_flag) + self.assertEqual( + torch._dynamo.is_dynamo_disable_recursive( + other_model.forward + if isinstance(other_model, torch.nn.Module) + else other_model + ), + recursive_flag, + ) + + # check the model is compilable + torch.compile(model) + torch.compile(other_model) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 28a77d20ea3b0..de097edf87752 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -32,6 +32,7 @@ error_on_graph_break, forbid_in_graph, graph_break, + is_dynamo_disable_recursive, mark_dynamic, mark_static, mark_static_address, @@ -87,6 +88,7 @@ "forbid_in_graph", "graph_break", "is_compiling", + "is_dynamo_disable_recursive", "list_backends", "lookup_backend", "mark_dynamic", diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 144f0ea7eeefa..87becc8b8b1b2 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -96,6 +96,7 @@ def wrap(fn: Callable[_P, _R]) -> Callable[_P, _R]: nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined] nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined] nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + nonrecursive_disable_wrapper._torchdynamo_disable_recursive = False # type: ignore[attr-defined] # pyrefly: ignore [bad-return] return nonrecursive_disable_wrapper @@ -1023,3 +1024,13 @@ def error_on_graph_break( The default value of torch.compile's `error_on_graph_break` setting is False. 
""" return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break) + + +def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> Optional[bool]: + """ + Check if a method is marked as `dynamo_disable` recursively. It returns: + - True if disable(recursive=True) + - False if disable(recursive=False) + - None if method is not a disable decorator + """ + return getattr(method, "_torchdynamo_disable_recursive", None) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 9ff4ae46523c3..e93e7ace7395e 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1155,6 +1155,8 @@ def _fn(*args: Any, **kwargs: Any) -> Any: # of decorators. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined] + _fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined] + return _fn def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]: From 3c2409c4653b75864ce1a82ba336aecad21e62ac Mon Sep 17 00:00:00 2001 From: Jan Wieczorek Date: Fri, 7 Nov 2025 17:52:50 +0000 Subject: [PATCH 210/651] Refactor recursive call of collect_temp_source (#166714) Recursive function call creates a reference cycle: closure <- function <- cell inside closure Capturing self (PyCodegen instance) in same closure prolongs it's life until next gc.collect() which might result in worse resource management After the introduction of e9209e0 OOM issues has been observed. Looking for reference cycles one has been uncovered that would result in the prolonging lifetime of tensors. As the result of that OOM issues might occur. Such a dependency chain has been uncovered: image At the end of it a reference cycle can be found that consists of a closure for function collect_temp_source, the function itself, and a cell object inside closure that would point to the function due to the recursive call. This issue can either be resolved by removing recurrency or removing PyCodegen instance from the closure. Another precaution that can be made is to explicitly empty f_locals dict. This way we cut the tensor from the chain leading to reference cycle. Fixes #166721 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166714 Approved by: https://github.com/Lucaskabela, https://github.com/Skylion007, https://github.com/jeromean, https://github.com/williamwen42, https://github.com/mlazos --- torch/_dynamo/codegen.py | 41 ++++++++++++++++++---------------- torch/_dynamo/convert_frame.py | 1 + 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 1861b20105265..cf76243b98ddc 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -15,7 +15,7 @@ import re import sys import types -from collections import Counter +from collections import Counter, deque from collections.abc import Callable, Iterable from typing import Any, Optional, TYPE_CHECKING, Union @@ -597,32 +597,35 @@ def make_call_generated_code(self, fn_name: str) -> None: graphargs = self.tx.output.graphargs - seen_sources: OrderedSet[Source] = OrderedSet() - - def collect_temp_source(source: Source) -> None: - if source in seen_sources: - # This source is used at least twice, so it can be reused - self.mark_source_temp(source) - # Dont trace source further. This prevents us from marking too - # many nodes as temp sources. 
- return - - seen_sources.add(source) - + def extract_nested_sources(source: Source) -> list[Source]: + nested_sources: list[Source] = [] if isinstance(source, ChainedSource): - collect_temp_source(source.base) - + nested_sources.append(source.base) if isinstance(source, DictGetItemSource) and isinstance( source.index, Source ): - collect_temp_source(source.index) + nested_sources.append(source.index) + return nested_sources + + def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None: + seen_sources: OrderedSet[Source] = OrderedSet() + while sources: + current_source = sources.popleft() + if current_source in seen_sources: + # This source is used at least twice, so it can be reused + codegen.mark_source_temp(current_source) + # Dont trace source further. This prevents us from marking too + # many nodes as temp sources. + continue + seen_sources.add(current_source) + sources.extend(extract_nested_sources(current_source)) # Collect all the sources that are used more than once, so that we can # generate tmp variables in the generated pre-graph bytecode. This # essentially implements CSE. - for arg in graphargs: - if arg.source is not None: - collect_temp_source(arg.source) + collect_temp_sources( + deque([arg.source for arg in graphargs if arg.source is not None]), self + ) cm_var = None if config.record_runtime_overhead: diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 4439c7dc09efe..8cf4ab8954d5a 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -828,6 +828,7 @@ def run_tracer() -> None: raise finally: tracer.output.call_cleanup_hooks() + tracer.f_locals = {} try: run_tracer() From 69784a0dbe11b40c9aca73e47827b7e36028a0c2 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 7 Nov 2025 23:51:24 +0800 Subject: [PATCH 211/651] [dynamo][pytree] add polyfills for optree path APIs (#167211) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167211 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167221 --- test/dynamo/test_error_messages.py | 27 ++-- torch/_dynamo/polyfills/pytree.py | 220 +++++++++++++++++++++++------ 2 files changed, 184 insertions(+), 63 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index df8364e78e40d..1cf841c4947e1 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -423,33 +423,28 @@ def test_optree_graph_break_message(self): @torch.compile(backend="eager") def fn(x): - d = {"a": 1} - optree.tree_flatten_with_path(d) - return torch.sin(x) + tree = {"a": x, "b": (x - 1, 2 * x)} + sin, cos = optree.tree_transpose_map( + lambda x: (torch.sin(x), torch.cos(x)), + tree, + ) + return sin, cos def post_munge(s): - s = re.sub( - r"optree\.\S*\.flatten_with_path", - "optree..flatten_with_path", - s, - ) - return re.sub( - r"qualname: \S*flatten_with_path", - "qualname: .flatten_with_path", - s, - ) + s = re.sub(r"optree\.\S*\.flatten", "optree..flatten", s) + return re.sub(r"qualname: \S*flatten", "qualname: .flatten", s) fn(torch.randn(4)) - self.assertEqual(len(counters["graph_break"]), 1) + self.assertGreaterEqual(len(counters["graph_break"]), 1) first_graph_break = next(iter(counters["graph_break"].keys())) self.assertExpectedInline( post_munge(first_graph_break), """\ Attempted to call function marked as skipped - Explanation: Dynamo cannot trace optree C/C++ function optree..flatten_with_path. + Explanation: Dynamo cannot trace optree C/C++ function optree..flatten. 
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py - Developer debug context: module: optree._C, qualname: .flatten_with_path, skip reason: + Developer debug context: module: optree._C, qualname: .flatten, skip reason: For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index c01026ef30211..c4ac9138a6290 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -11,6 +11,16 @@ import optree import optree._C import optree.utils +from optree import ( + is_namedtuple, + is_namedtuple_class, + is_namedtuple_instance, + is_structseq, + is_structseq_class, + is_structseq_instance, + namedtuple_fields, + structseq_fields, +) import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES @@ -26,7 +36,30 @@ from torch.utils._cxx_pytree import PyTree -__all__: list[str] = [] +__all__ = [ + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", + "namedtuple_fields", + "structseq_fields", + "treespec_leaf", + "treespec_tuple", + "treespec_dict", + "tree_is_leaf", + "tree_iter", + "tree_leaves", + "tree_flatten", + "tree_flatten_with_path", + "tree_structure", + "tree_unflatten", + "tree_map", + "tree_map_", + "tree_map_with_path", + "tree_map_with_path_", +] _T = TypeVar("_T") @@ -48,21 +81,20 @@ def _(*args: Any, **kwargs: Any) -> bool: __name = "" -for __name in ( - "is_namedtuple", - "is_namedtuple_class", - "is_namedtuple_instance", - "is_structseq", - "is_structseq_class", - "is_structseq_instance", - "namedtuple_fields", - "structseq_fields", +for __name, __func in ( + ("is_namedtuple", is_namedtuple), + ("is_namedtuple_class", is_namedtuple_class), + ("is_namedtuple_instance", is_namedtuple_instance), + ("is_structseq", is_structseq), + ("is_structseq_class", is_structseq_class), + ("is_structseq_instance", is_structseq_instance), + ("namedtuple_fields", namedtuple_fields), + ("structseq_fields", structseq_fields), ): - __func = getattr(optree, __name) - globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( - __func.__python_implementation__ - ) - __all__ += [__name] # noqa: PLE0604 + globals()[__name] = substitute_in_graph( + __func, # type: ignore[arg-type] + can_constant_fold_through=True, + )(__func.__python_implementation__) # type: ignore[attr-defined] del __func del __name @@ -78,7 +110,7 @@ def tree_is_leaf( ) -> bool: if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): return True - if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined] + if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: return True return False @@ -113,9 +145,6 @@ def tree_iter( stack.extend(reversed(children)) -__all__ += ["tree_iter"] - - @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] def tree_leaves( tree: PyTree, @@ -135,9 +164,6 @@ def tree_leaves( ) -__all__ += ["tree_leaves"] - - class _Asterisk(str): __slots__ = () @@ -168,7 +194,7 @@ class PyTreeSpec: num_leaves: int = field(init=False) num_children: int = field(init=False) - def __post_init__(self) -> None: + def __post_init__(self, /) -> None: if self._type is None: assert 
len(self._children) == 0 assert self._metadata is None @@ -187,7 +213,7 @@ def __post_init__(self) -> None: object.__setattr__(self, "num_leaves", num_leaves) object.__setattr__(self, "num_children", num_children) - def __repr__(self) -> str: + def __repr__(self, /) -> str: def helper(treespec: PyTreeSpec) -> str: if treespec.is_leaf(): assert treespec.type is None @@ -221,29 +247,78 @@ def helper(treespec: PyTreeSpec) -> str: ] return f"PyTreeSpec({', '.join(inner)})" - def __len__(self) -> int: + def __len__(self, /) -> int: return self.num_leaves @property - def type(self) -> builtins.type | None: + def type(self, /) -> builtins.type | None: return self._type - def is_leaf(self) -> bool: + def is_leaf(self, /) -> bool: return self.num_nodes == 1 and self.num_leaves == 1 - def children(self) -> list[PyTreeSpec]: + def paths(self, /) -> list[tuple[Any, ...]]: + def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None: + if treespec.is_leaf(): + paths.append(path_prefix) + return + + for entry, subspec in zip( + treespec._entries, + treespec._children, + strict=True, + ): + helper(subspec, path_prefix + [entry]) + + paths: list[list[Any]] = [] + helper(self, []) + return [tuple(path) for path in paths] + + def accessors(self, /) -> list[optree.PyTreeAccessor]: + def helper( + treespec: PyTreeSpec, + entry_path_prefix: list[optree.PyTreeEntry], + ) -> None: + if treespec.is_leaf(): + entry_paths.append(entry_path_prefix) + return + + node_type = treespec.type + assert node_type is not None + handler = optree.register_pytree_node.get( + node_type, namespace=treespec.namespace + ) + assert handler is not None + kind: optree.PyTreeKind = handler.kind + path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type + + for entry, subspec in zip( + treespec._entries, + treespec._children, + strict=True, + ): + helper( + subspec, + entry_path_prefix + [path_entry_type(entry, node_type, kind)], + ) + + entry_paths: list[list[optree.PyTreeEntry]] = [] + helper(self, []) + return [optree.PyTreeAccessor(path) for path in entry_paths] + + def children(self, /) -> list[PyTreeSpec]: return list(self._children) - def child(self, index: int) -> PyTreeSpec: + def child(self, index: int, /) -> PyTreeSpec: return self._children[index] - def entries(self) -> list[Any]: + def entries(self, /) -> list[Any]: return list(self._entries) - def entry(self, index: int) -> Any: + def entry(self, index: int, /) -> Any: return self._entries[index] - def flatten_up_to(self, tree: PyTree) -> list[PyTree]: + def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]: def helper( treespec: PyTreeSpec, node: PyTree, @@ -324,14 +399,14 @@ def helper( f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch ) - for subtree, subspec in zip(children, treespec._children): + for subtree, subspec in zip(children, treespec._children, strict=True): helper(subspec, subtree, subtrees) subtrees: list[PyTree] = [] helper(self, tree, subtrees) return subtrees - def unflatten(self, leaves: Iterable[Any]) -> PyTree: + def unflatten(self, leaves: Iterable[Any], /) -> PyTree: if not isinstance(leaves, (list, tuple)): leaves = list(leaves) if len(leaves) != self.num_leaves: @@ -408,7 +483,7 @@ def treespec_tuple( "All children PyTreeSpecs must have the same `namespace` value " f"as the parent; expected {namespace!r}, got: {children!r}.", ) - handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined] + handler = optree.register_pytree_node.get(tuple, 
namespace=namespace) assert handler is not None return PyTreeSpec( tuple(children), @@ -531,7 +606,27 @@ def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: return leaves, treespec -__all__ += ["tree_flatten"] +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_flatten_with_path, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_flatten_with_path( + tree: PyTree, + /, + is_leaf: Callable[[PyTree], bool] | None = None, + *, + none_is_leaf: bool = False, + namespace: str = "", +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + leaves, treespec = tree_flatten( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return treespec.paths(), leaves, treespec # type: ignore[return-value] @substitute_in_graph( # type: ignore[arg-type] @@ -556,9 +651,6 @@ def tree_structure( )[1] -__all__ += ["tree_structure"] - - @substitute_in_graph( # type: ignore[arg-type] optree.tree_unflatten, # We need to disable constant folding here because we want the function to reference the @@ -574,9 +666,6 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: return treespec.unflatten(leaves) -__all__ += ["tree_unflatten"] - - @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map( func: Callable[..., Any], @@ -597,9 +686,6 @@ def tree_map( return treespec.unflatten(map(func, *flat_args)) -__all__ += ["tree_map"] - - @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map_( func: Callable[..., Any], @@ -621,7 +707,47 @@ def tree_map_( return tree -__all__ += ["tree_map_"] +@substitute_in_graph(optree.tree_map_with_path, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + /, + *rests: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTree: + paths, leaves, treespec = tree_flatten_with_path( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, paths, *flat_args)) + + +@substitute_in_graph(optree.tree_map_with_path_, can_constant_fold_through=True) # type: ignore[arg-type] +def tree_map_with_path_( + func: Callable[..., Any], + tree: PyTree, + /, + *rests: PyTree, + is_leaf: Callable[[PyTree], bool] | None = None, + none_is_leaf: bool = False, + namespace: str = "", +) -> PyTree: + paths, leaves, treespec = tree_flatten_with_path( + tree, + is_leaf=is_leaf, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + # consume and exhaust the iterable + deque(map(func, paths, *flat_args), maxlen=0) + return tree + _none_registration = optree.register_pytree_node.get(type(None)) assert _none_registration is not None @@ -669,5 +795,5 @@ def dict_unflatten( ) -> dict[_KT, _VT]: original_keys, sorted_keys = metadata d = dict.fromkeys(original_keys) - d.update(zip(sorted_keys, values)) + d.update(zip(sorted_keys, values, strict=True)) return d # type: ignore[return-value] From 713e289ae7c68bb4406da9bdd224f463851ca426 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 7 Nov 2025 23:51:26 +0800 Subject: [PATCH 212/651] [dynamo][pytree] support more 
`optree` functions by polyfill the underlying CXX functions directly (#167292) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167292 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167221, #167211 --- test/dynamo/test_error_messages.py | 28 +++++-- torch/_dynamo/polyfills/pytree.py | 129 ++++++++++------------------- 2 files changed, 62 insertions(+), 95 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 1cf841c4947e1..995c733716f1b 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -422,29 +422,41 @@ def test_optree_graph_break_message(self): import optree @torch.compile(backend="eager") - def fn(x): + def fn1(x): tree = {"a": x, "b": (x - 1, 2 * x)} sin, cos = optree.tree_transpose_map( - lambda x: (torch.sin(x), torch.cos(x)), + lambda t: (torch.sin(t), torch.cos(t)), tree, ) return sin, cos - def post_munge(s): - s = re.sub(r"optree\.\S*\.flatten", "optree..flatten", s) - return re.sub(r"qualname: \S*flatten", "qualname: .flatten", s) + fn1(torch.randn(4)) + self.assertEqual(len(counters["graph_break"]), 0) + + @torch.compile(backend="eager") + def fn2(x): + spec = optree.treespec_deque([]) + return spec, x - fn(torch.randn(4)) + fn2(torch.randn(4)) self.assertGreaterEqual(len(counters["graph_break"]), 1) first_graph_break = next(iter(counters["graph_break"].keys())) + + def post_munge(string): + return re.sub( + r"(optree\.|qualname: )\S*(\.make_from_collection)", + r"\1\2", + string, + ) + self.assertExpectedInline( post_munge(first_graph_break), """\ Attempted to call function marked as skipped - Explanation: Dynamo cannot trace optree C/C++ function optree..flatten. + Explanation: Dynamo cannot trace optree C/C++ function optree..make_from_collection. Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py - Developer debug context: module: optree._C, qualname: .flatten, skip reason: + Developer debug context: module: optree._C, qualname: .make_from_collection, skip reason: For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index c4ac9138a6290..1c6283e8a038f 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -55,10 +55,6 @@ "tree_flatten_with_path", "tree_structure", "tree_unflatten", - "tree_map", - "tree_map_", - "tree_map_with_path", - "tree_map_with_path_", ] @@ -607,146 +603,105 @@ def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: @substitute_in_graph( # type: ignore[arg-type] - optree.tree_flatten_with_path, + optree._C.flatten, # We need to disable constant folding here because we want the function to reference the # PyTreeSpec class defined above, not the one in the C++ module. 
can_constant_fold_through=False, ) -def tree_flatten_with_path( +def _C_flatten( tree: PyTree, /, - is_leaf: Callable[[PyTree], bool] | None = None, - *, + leaf_predicate: Callable[[PyTree], bool] | None = None, none_is_leaf: bool = False, namespace: str = "", -) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: - leaves, treespec = tree_flatten( +) -> tuple[list[Any], PyTreeSpec]: + return tree_flatten( # type: ignore[return-value] tree, - is_leaf=is_leaf, + is_leaf=leaf_predicate, none_is_leaf=none_is_leaf, namespace=namespace, ) - return treespec.paths(), leaves, treespec # type: ignore[return-value] @substitute_in_graph( # type: ignore[arg-type] - optree.tree_structure, + optree.tree_flatten_with_path, # We need to disable constant folding here because we want the function to reference the # PyTreeSpec class defined above, not the one in the C++ module. can_constant_fold_through=False, ) -def tree_structure( +def tree_flatten_with_path( tree: PyTree, /, is_leaf: Callable[[PyTree], bool] | None = None, *, none_is_leaf: bool = False, namespace: str = "", -) -> PyTreeSpec: - return tree_flatten( # type: ignore[return-value] +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + leaves, treespec = tree_flatten( tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, - )[1] + ) + return treespec.paths(), leaves, treespec # type: ignore[return-value] @substitute_in_graph( # type: ignore[arg-type] - optree.tree_unflatten, + optree._C.flatten_with_path, # We need to disable constant folding here because we want the function to reference the # PyTreeSpec class defined above, not the one in the C++ module. can_constant_fold_through=False, ) -def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: - if not _is_pytreespec_instance(treespec): - raise TypeError( - f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " - f"PyTreeSpec but got item of type {type(treespec)}." 
- ) - return treespec.unflatten(leaves) - - -@substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] -def tree_map( - func: Callable[..., Any], +def _C_flatten_with_path( tree: PyTree, /, - *rests: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, + leaf_predicate: Callable[[PyTree], bool] | None = None, none_is_leaf: bool = False, namespace: str = "", -) -> PyTree: - leaves, treespec = tree_flatten( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - return treespec.unflatten(map(func, *flat_args)) - - -@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] -def tree_map_( - func: Callable[..., Any], - tree: PyTree, - /, - *rests: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - none_is_leaf: bool = False, - namespace: str = "", -) -> PyTree: - leaves, treespec = tree_flatten( +) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]: + return tree_flatten_with_path( # type: ignore[return-value] tree, - is_leaf=is_leaf, + is_leaf=leaf_predicate, none_is_leaf=none_is_leaf, namespace=namespace, ) - flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable - return tree -@substitute_in_graph(optree.tree_map_with_path, can_constant_fold_through=True) # type: ignore[arg-type] -def tree_map_with_path( - func: Callable[..., Any], +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_structure, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_structure( tree: PyTree, /, - *rests: PyTree, is_leaf: Callable[[PyTree], bool] | None = None, + *, none_is_leaf: bool = False, namespace: str = "", -) -> PyTree: - paths, leaves, treespec = tree_flatten_with_path( +) -> PyTreeSpec: + return tree_flatten( # type: ignore[return-value] tree, is_leaf=is_leaf, none_is_leaf=none_is_leaf, namespace=namespace, - ) - flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - return treespec.unflatten(map(func, paths, *flat_args)) + )[1] -@substitute_in_graph(optree.tree_map_with_path_, can_constant_fold_through=True) # type: ignore[arg-type] -def tree_map_with_path_( - func: Callable[..., Any], - tree: PyTree, - /, - *rests: PyTree, - is_leaf: Callable[[PyTree], bool] | None = None, - none_is_leaf: bool = False, - namespace: str = "", -) -> PyTree: - paths, leaves, treespec = tree_flatten_with_path( - tree, - is_leaf=is_leaf, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - # consume and exhaust the iterable - deque(map(func, paths, *flat_args), maxlen=0) - return tree +@substitute_in_graph( # type: ignore[arg-type] + optree.tree_unflatten, + # We need to disable constant folding here because we want the function to reference the + # PyTreeSpec class defined above, not the one in the C++ module. + can_constant_fold_through=False, +) +def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: + if not _is_pytreespec_instance(treespec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be instance of " + f"PyTreeSpec but got item of type {type(treespec)}." 
+ ) + return treespec.unflatten(leaves) _none_registration = optree.register_pytree_node.get(type(None)) From c62a17a2fb8880fbb7c9e8c198a532f8360a51c3 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 7 Nov 2025 18:09:37 +0000 Subject: [PATCH 213/651] [ez] Remove some unused vars in common_utils.py (#166453) I can't find where these are used Pull Request resolved: https://github.com/pytorch/pytorch/pull/166453 Approved by: https://github.com/malfet --- torch/testing/_internal/common_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 8f4f8efc4108e..d5afc413daed8 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -114,8 +114,6 @@ class ProfilingMode(Enum): PROFILING = 3 # Set by parse_cmd_line_args() if called -CI_FUNCTORCH_ROOT = "" -CI_PT_ROOT = "" CI_TEST_PREFIX = "" DISABLED_TESTS_FILE = "" GRAPH_EXECUTOR : Optional[ProfilingMode] = None @@ -959,8 +957,6 @@ def _get_test_report_path(): return os.path.join('test-reports', test_source) def parse_cmd_line_args(): - global CI_FUNCTORCH_ROOT - global CI_PT_ROOT global CI_TEST_PREFIX global DISABLED_TESTS_FILE global GRAPH_EXECUTOR @@ -1039,10 +1035,8 @@ def run_unittest_help(argv): set_rng_seed() -# CI Prefix path used only on CI environment + # CI Prefix path used only on CI environment CI_TEST_PREFIX = str(Path(os.getcwd())) - CI_PT_ROOT = str(Path(os.getcwd()).parent) - CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) def wait_for_process(p, timeout=None): try: From 22650c89fb5670ca9eebbd81fd8a7f57fc3d5303 Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Fri, 7 Nov 2025 18:11:45 +0000 Subject: [PATCH 214/651] [ROCm] Update skip_if_lt_x_gpu to work with MultiProcContinuous class (#167281) - Since MultiProcContinuous class spawns one process per GPU and runs UT in each of the processes, we need to ensure we are propagating the exit code associated with skip all the way to the main worker thread that spawned all the child processes. - This commit also updates several UTs that are meant for 4 GPUs but incorrectly calls skip_if_lt_x_gpu with 2 as an input. Examples: - test_replicate_with_fsdp.py - test_dtensor_resharding.py - test_state_dict.py - test_functional_api.py: Fix typo. 
multi-accelerator doesn't exit, replaced with multi-gpu - test_op_strategy.py: world_size was hardcoded - test_math_ops.py: UT written for 4 GPU, so skipping for anything less - test_schedule_multiproc.py: All UTs in this suite are required to run on 2+ GPUs, therefore, adding skips if less than 4 GPUs are supplied Fixes https://github.com/pytorch/pytorch/issues/166875 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167281 Approved by: https://github.com/jeffdaily --- .../_composable/test_replicate_with_fsdp.py | 10 +++++----- .../checkpoint/test_dtensor_resharding.py | 4 +++- test/distributed/checkpoint/test_state_dict.py | 2 +- .../pipelining/test_schedule_multiproc.py | 16 ++++++++++++++++ test/distributed/tensor/test_math_ops.py | 2 ++ test/distributed/tensor/test_op_strategy.py | 2 +- test/distributed/test_functional_api.py | 2 +- torch/testing/_internal/common_distributed.py | 18 ++++++++++++++++++ 8 files changed, 47 insertions(+), 9 deletions(-) diff --git a/test/distributed/_composable/test_replicate_with_fsdp.py b/test/distributed/_composable/test_replicate_with_fsdp.py index 7ec059a647ee5..1087d9c813c9e 100644 --- a/test/distributed/_composable/test_replicate_with_fsdp.py +++ b/test/distributed/_composable/test_replicate_with_fsdp.py @@ -76,7 +76,7 @@ def _init_pg(self): store=dist.FileStore(self.file_name, self.world_size), ) - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_replicate_transformer(self): """ This tests that replicate works on a transformer model with fully_shard and replicate layers @@ -126,7 +126,7 @@ def _test_replicate_transformer(self, sharding_strategy): for parameter in layer.parameters(): self.assertEqual(parameter.placements, (Shard(dim=0),)) - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_replicate_transformer_managed_modules(self): """ This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers, @@ -178,7 +178,7 @@ def test_replicate_transformer_managed_modules(self): replicate_model = replicate(replicate_model) self.assertEqual(len(_get_managed_modules((replicate_model,))), 21) - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_replicate_tp_device_mesh(self): """ This tests that a user can pass in a device mesh to replicate a module @@ -206,7 +206,7 @@ def test_replicate_tp_device_mesh(self): self.assertEqual(parameter.device_mesh.shape, (2,)) self.assertEqual(parameter.placements, (Replicate(),)) - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_train_replicate_fsdp(self): """ Tests that replicate_model has the same behavior as original model when training @@ -253,7 +253,7 @@ def test_train_replicate_fsdp(self): self.assertEqual(replicate_loss, loss) check_sharded_parity(self, model, replicate_model) - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_train_parity_2d_mlp(self): """ Verifies when a device mesh is passed in, the model has the same behavior as the original model when training diff --git a/test/distributed/checkpoint/test_dtensor_resharding.py b/test/distributed/checkpoint/test_dtensor_resharding.py index 306f61a597c25..233fb3e7e0f03 100644 --- a/test/distributed/checkpoint/test_dtensor_resharding.py +++ b/test/distributed/checkpoint/test_dtensor_resharding.py @@ -299,7 +299,7 @@ def test_dtensor_checkpoint_resharding_with_empty_shard(self): @with_comms @with_temp_dir - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_dtensor_checkpoint_with_uneven_shards(self) -> None: """ Saving a dtensor with uneven shards. 
@@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase): @with_comms @with_temp_dir + @skip_if_lt_x_gpu(4) def test_uneven_reshard_with_checkpointable_api(self) -> None: """ Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API. @@ -498,6 +499,7 @@ def test_uneven_reshard_with_checkpointable_api(self) -> None: @with_comms @with_temp_dir + @skip_if_lt_x_gpu(4) def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None: """ Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API. diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 095dc4bc3514a..1206f13213108 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -886,7 +886,7 @@ def test_setting_meta_device_model(self) -> None: self.assertEqual(cpu_model_value, meta_model_value) @with_comms - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) def test_setting_meta_device_model_broadcasting_and_memory(self) -> None: # This test verifies that we can set model state dict by a meta device model # With the correlated changes in state_dict, meta device model should be accepted diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index 5538e750d27eb..687cb113b48bf 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -39,6 +39,7 @@ from torch.testing._internal.common_distributed import ( MultiProcContinuousTest, requires_accelerator_dist_backend, + skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import ( check_leaked_tensors, @@ -231,6 +232,7 @@ def config(self) -> PipelineTestConfig: not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [_ScheduleForwardOnly]) + @skip_if_lt_x_gpu(4) def test_forward_only(self, ScheduleClass): mod, mod_ref, x, _, _ = setup_models_and_data(self.config) x_clone = x.clone() @@ -274,6 +276,7 @@ def test_forward_only(self, ScheduleClass): ScheduleInterleavedZeroBubble, ], ) + @skip_if_lt_x_gpu(4) def test_eval_inference_mode(self, ScheduleClass): num_microbatches = 4 if ScheduleClass in [ @@ -351,6 +354,7 @@ def test_eval_inference_mode(self, ScheduleClass): ScheduleInterleavedZeroBubble, ], ) + @skip_if_lt_x_gpu(4) def test_return_output(self, ScheduleClass): num_microbatches = 4 if ScheduleClass in [ @@ -406,6 +410,7 @@ def test_return_output(self, ScheduleClass): not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + @skip_if_lt_x_gpu(4) def test_multi_iter(self, ScheduleClass): mod, _, x, target, loss_fn = setup_models_and_data(self.config) chunks = 4 @@ -429,6 +434,7 @@ def test_multi_iter(self, ScheduleClass): not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + @skip_if_lt_x_gpu(4) def test_kwargs_with_tracer(self, ScheduleClass): mod = ModelWithKwargs(d_hid, splits=self.world_size) mod.to(self.device) @@ -481,6 +487,7 @@ def test_kwargs_with_tracer(self, ScheduleClass): not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) + @skip_if_lt_x_gpu(4) def test_grad_with_tracer(self, ScheduleClass): mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config) @@ -523,6 +530,7 @@ def 
test_grad_with_tracer(self, ScheduleClass): ) @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) @parametrize("shape_inference", [True, False]) + @skip_if_lt_x_gpu(4) def test_grad_with_manual(self, ScheduleClass, shape_inference): mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config) @@ -586,6 +594,7 @@ def test_grad_with_manual(self, ScheduleClass, shape_inference): ScheduleInterleavedZeroBubble, ], ) + @skip_if_lt_x_gpu(4) def test_grad_with_manual_interleaved(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size @@ -650,6 +659,7 @@ def test_grad_with_manual_interleaved(self, ScheduleClass): not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) + @skip_if_lt_x_gpu(4) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size @@ -736,6 +746,7 @@ def dw_runner(): "schedule_class", [ScheduleZBVZeroBubble, ScheduleDualPipeV], ) + @skip_if_lt_x_gpu(4) def test_v_shape_schedules(self, schedule_class): n_stages = 8 rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]} @@ -780,6 +791,7 @@ def test_v_shape_schedules(self, schedule_class): @skip_but_pass_in_sandcastle_if( not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) + @skip_if_lt_x_gpu(4) def test_custom_function_callback(self): """Test the custom function callback functionality with _PipelineScheduleRuntime.""" n_stages = 8 @@ -979,6 +991,7 @@ def overlap_callback(action: _Action, ctx: _PipelineContext): "ScheduleClass", [ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B], ) + @skip_if_lt_x_gpu(4) def test_zero_bubble_with_model_kwargs(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size @@ -1072,6 +1085,7 @@ def config(self) -> PipelineTestConfig: "schedule_class", [ScheduleVShaped, ScheduleUnbalanced], ) + @skip_if_lt_x_gpu(4) def test_non_symmetric_stage_ids(self, schedule_class): n_stages = schedule_class.n_stages rank_stages = schedule_class.rank_stages @@ -1121,6 +1135,7 @@ def test_non_symmetric_stage_ids(self, schedule_class): not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [ScheduleWithReorderedB]) + @skip_if_lt_x_gpu(4) def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): n_stages = 2 stages_per_rank = 1 @@ -1181,6 +1196,7 @@ def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass): not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs" ) @parametrize("ScheduleClass", [ScheduleWithW]) + @skip_if_lt_x_gpu(4) def test_schedule_with_native_zero_bubble(self, ScheduleClass): n_stages = ScheduleClass.n_stages num_microbatches = ScheduleClass.num_microbatches diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index f031085b23bd2..56321806477b9 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -26,6 +26,7 @@ RowwiseParallel, SequenceParallel, ) +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, @@ -764,6 +765,7 @@ def test_foreach_norm_different_mesh(self): self.assertEqual(grad1_norm.device_mesh, mesh_y) @with_comms + @skip_if_lt_x_gpu(4) def test_foreach_add_different_mesh(self): mesh_shape = (2, self.world_size // 2) 
mesh_2d = init_device_mesh( diff --git a/test/distributed/tensor/test_op_strategy.py b/test/distributed/tensor/test_op_strategy.py index da9c4b4174b5d..139f5fb61fac8 100644 --- a/test/distributed/tensor/test_op_strategy.py +++ b/test/distributed/tensor/test_op_strategy.py @@ -577,7 +577,7 @@ def mock_select_func(strategy, op_schema=None): self.assertEqual( comm_mode.get_comm_counts(), { - torch.ops.c10d_functional.all_gather_into_tensor: 4, + torch.ops.c10d_functional.all_gather_into_tensor: self.world_size, }, ) expected_cost = [ diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py index b5522fe2bef06..d4954b3e4f56d 100644 --- a/test/distributed/test_functional_api.py +++ b/test/distributed/test_functional_api.py @@ -485,7 +485,7 @@ def allred_mesh_dim(input): def exit_if_lt_x_accelerators(x): if torch.accelerator.is_available(): if torch.accelerator.device_count() < x: - sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code) + sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) def with_comms(func=None): diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 91f09adf9e816..e93c346a6645d 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -1771,6 +1771,22 @@ def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue) cls._run_test_given_id(test_id) completion_queue.put(test_id) except BaseException as ex: # noqa: B036 + if isinstance(ex, SystemExit): + # Get exit code from the process + exit_code = getattr(ex, "code", None) + + # Look up exit code in TEST_SKIPS to see if it is a valid skip + skip_entry = next( + (v for v in TEST_SKIPS.values() if v.exit_code == exit_code), + None, + ) + + # If we found an entry, we want to skip the test and the object back to the main process + if skip_entry: + completion_queue.put(unittest.SkipTest(skip_entry.message)) + # Skip exception handling below, move to main thread for processing the skip + continue + raised_exception = True # Send the exception and stack trace back to the dispatcher exc_info = sys.exc_info() @@ -1892,6 +1908,8 @@ def wrapper(self): # Wait for the workers to finish the test for i, completion_queue in enumerate(self.completion_queues): rv = completion_queue.get() + if isinstance(rv, unittest.SkipTest): + raise rv if isinstance(rv, BaseException): # Hit an exception, re-raise it in the main process. 
logger.warning( From e401a56b96adc7cf688362fc4a28ebb117dff171 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 7 Nov 2025 18:14:44 +0000 Subject: [PATCH 215/651] [ez] Remove some dead code from test artifact related files (#166966) Remove circle ci path since it's no longer used Remove function that is not used Pull Request resolved: https://github.com/pytorch/pytorch/pull/166966 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- tools/stats/upload_test_stats.py | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index b2b0869d48350..eb5ac9b6f474c 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -165,21 +165,6 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str, return flattened -def get_tests_for_circleci( - workflow_run_id: int, workflow_run_attempt: int -) -> list[dict[str, Any]]: - # Parse the reports and transform them to JSON - test_cases = [] - for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"): - test_cases.extend( - parse_xml_report( - "testcase", xml_report, workflow_run_id, workflow_run_attempt - ) - ) - - return test_cases - - def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]: """Group test cases by classname, file, and job_id. We perform the aggregation manually instead of using the `test-suite` XML tag because xmlrunner does @@ -258,21 +243,11 @@ def init_value(test_case: dict[str, Any]) -> dict[str, Any]: required=True, help="Head repository of the workflow", ) - parser.add_argument( - "--circleci", - action="store_true", - help="If this is being run through circleci", - ) args = parser.parse_args() print(f"Workflow id is: {args.workflow_run_id}") - if args.circleci: - test_cases = get_tests_for_circleci( - args.workflow_run_id, args.workflow_run_attempt - ) - else: - test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt) + test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt) # Flush stdout so that any errors in the upload show up last in the logs. sys.stdout.flush() From d1446ad75c9d93c079851aca21352517162aee3c Mon Sep 17 00:00:00 2001 From: Tristan Trouwen Date: Fri, 7 Nov 2025 19:31:51 +0000 Subject: [PATCH 216/651] Register floor_divide.out for MTIA (#167280) Differential Revision: D86468749 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167280 Approved by: https://github.com/albanD --- aten/src/ATen/native/native_functions.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 4424f51827d45..491521bdc9601 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2803,7 +2803,7 @@ - func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: floor_divide_out + CPU, CUDA, MPS, MTIA: floor_divide_out SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim - func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor From 28615a765d5eecbe08cb005d5e2dc5fce4152a4b Mon Sep 17 00:00:00 2001 From: Kushagra Rastogi Date: Fri, 7 Nov 2025 19:32:40 +0000 Subject: [PATCH 217/651] Fix: list index out of range with softmax when using 0 dim (#166547) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #163971 Problem: PyTorch's inductor compiler crashed with IndexError: list index out of range when compiling code that uses 0-dimensional tensors with operations like torch.softmax(scalar_tensor, dim=0). A 0-dim tensor has shape = torch.Size([]) (empty shape) ``` ndim = 0 (zero dimensions) len(shape) = 0 (no indices to access) # Line 972: Pad other_shape to match inp dimensions other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape) # For scalar tensors: # inp_ndim = 0 # as input is scalar # other_shape = [] # Result: [1] * (0 - 0) + [] = [] (still empty!) dim = match.kwargs["dim"] # dim = 0 if isinstance(dim, int): dim = (dim,) # crash is happening here! return all(statically_known_true(other_shape[d] == 1) for d in dim) # ^^^^^^^^^^^^^^^^ # Tries other_shape[0] but other_shape = [] (empty!) # → IndexError: list index out of range ``` The function _other_is_broadcasted_in_dim() is an optimization check for a softmax fusion pattern. It verifies whether it's safe to rewrite: ``` # From scaled = inp * other result = scaled - scaled.amax(dim, keepdim=True) # To this more stable form: result = (inp - inp.amax(dim, keepdim=True)) * other ``` The optimization is only valid if other is constant across the reduction dimension (i.e., broadcasted to size 1 in that dimension). Otherwise, scaling changes which element is the maximum. 
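A minimal repro, mirroring the `test_softmax_with_zero_dim` test added in this patch (shapes and dtype taken from that test; the fix simply makes the pattern check bail out instead of indexing past `other_shape`):

```
import torch

def fn(x):
    return torch.softmax(x, 0)

# 0-dim tensor: shape == torch.Size([]), so the padded other_shape stays empty
x = torch.rand([], dtype=torch.bfloat16)
out = torch.compile(fn)(x)  # previously crashed with IndexError in the joint-graph pattern check
torch.testing.assert_close(out, fn(x))
```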
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166547 Approved by: https://github.com/jansel --- test/inductor/test_cpu_repro.py | 9 +++++++++ torch/_inductor/fx_passes/joint_graph.py | 3 +++ 2 files changed, 12 insertions(+) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 937208d9fd531..cf4900c8536bf 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -3278,6 +3278,15 @@ def fn(x): metrics.reset() self.common(fn, (x,)) + def test_softmax_with_zero_dim(self): + def fn(x): + x = torch.softmax(x, 0) + return x + + x = torch.rand([], dtype=torch.bfloat16) + metrics.reset() + self.common(fn, (x,)) + @config.patch({"fx_graph_cache": False, "fx_graph_remote_cache": False}) def test_local_buffer_in_outer_loop_fusion(self): def fn(x): diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 25b10966cfa96..9db694f1d8629 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -893,6 +893,9 @@ def _other_is_broadcasted_in_dim(match): if isinstance(dim, int): dim = (dim,) + if any(d >= len(other_shape) for d in dim): + return False + return all(statically_known_true(other_shape[d] == 1) for d in dim) From 2f5223564ea9caf5741f01e098722608a85c1fdb Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 7 Nov 2025 19:38:36 +0000 Subject: [PATCH 218/651] [ez] Remove experiment for uploading all test runs (#167133) reverts #165484 after #166988 they are just uploaded while its running Pull Request resolved: https://github.com/pytorch/pytorch/pull/167133 Approved by: https://github.com/malfet --- tools/stats/upload_test_stats.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index eb5ac9b6f474c..45a390fc5051d 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -288,12 +288,4 @@ def init_value(test_case: dict[str, Any]) -> dict[str, Any]: remove_nan_inf(test_cases), ) - # Part of an experiment to see if we can handle all the data as is - upload_workflow_stats_to_s3( - args.workflow_run_id, - args.workflow_run_attempt, - "all_test_runs", - remove_nan_inf(test_cases), - ) - upload_additional_info(args.workflow_run_id, args.workflow_run_attempt, test_cases) From 4c41e9bde7a7f4adbc518a02762fc50039a2d68f Mon Sep 17 00:00:00 2001 From: Dylan Maloy Date: Fri, 7 Nov 2025 19:48:19 +0000 Subject: [PATCH 219/651] making TORCH_CHECK_{COND} non-fatal (#167004) TORCH_CHECK is non-fatal by design, but TORCH_CHECK_{COND} macros are fatal. this is confusing, and we should limit fatality to the set of debug macros. 
Differential Revision: D86168955 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167004 Approved by: https://github.com/malfet --- c10/test/build.bzl | 9 ++ c10/test/util/nofatal_test.cpp | 53 +++++++++++ c10/util/Exception.h | 92 +++++++++++++++++++ c10/util/Logging.cpp | 53 +++++++++-- c10/util/logging_common.h | 74 +++++++++++++++ c10/util/logging_is_google_glog.h | 94 ++++++++++--------- c10/util/logging_is_not_google_glog.h | 125 +++++--------------------- 7 files changed, 341 insertions(+), 159 deletions(-) create mode 100644 c10/test/util/nofatal_test.cpp create mode 100644 c10/util/logging_common.h diff --git a/c10/test/build.bzl b/c10/test/build.bzl index deb917dd8fcf3..7b4028ab4afed 100644 --- a/c10/test/build.bzl +++ b/c10/test/build.bzl @@ -66,6 +66,15 @@ def define_targets(rules): ], ) + rules.cc_test( + name = "util/nofatal_test", + srcs = ["util/nofatal_test.cpp"], + deps = [ + "//c10/util:base", + "@com_google_googletest//:gtest_main", + ], + ) + rules.cc_test( name = "util/ssize_test", srcs = ["util/ssize_test.cpp"], diff --git a/c10/test/util/nofatal_test.cpp b/c10/test/util/nofatal_test.cpp new file mode 100644 index 0000000000000..ba4b40b6f917e --- /dev/null +++ b/c10/test/util/nofatal_test.cpp @@ -0,0 +1,53 @@ +#include + +#include +#include + +namespace { +template +inline void expectThrowsEq(T&& fn, const char* expected_msg) { + try { + std::forward(fn)(); + } catch (const c10::Error& e) { + EXPECT_TRUE( + std::string(e.what_without_backtrace()).find(expected_msg) != + std::string::npos); + return; + } + ADD_FAILURE() << "Expected to throw exception with message \"" << expected_msg + << "\" but didn't throw"; +} +} // namespace + +TEST(NofatalTest, TorchCheckComparisons) { + // quick make sure that no-op works as expected + TORCH_CHECK_EQ(1, 1) << "i am a silly message " << 1; + expectThrowsEq( + []() { TORCH_CHECK_EQ(1, 2) << "i am a silly message " << 1; }, + "Check failed: 1 == 2 (1 vs. 2). i am a silly message 1"); + expectThrowsEq( + []() { TORCH_CHECK_NE(2, 2); }, "Check failed: 2 != 2 (2 vs. 2)."); + expectThrowsEq( + []() { TORCH_CHECK_LT(2, 2); }, "Check failed: 2 < 2 (2 vs. 2)."); + expectThrowsEq( + []() { TORCH_CHECK_LE(3, 2); }, "Check failed: 3 <= 2 (3 vs. 2)."); + expectThrowsEq( + []() { TORCH_CHECK_GT(2, 2); }, "Check failed: 2 > 2 (2 vs. 2)."); + expectThrowsEq( + []() { TORCH_CHECK_GE(2, 3); }, "Check failed: 2 >= 3 (2 vs. 3)."); + expectThrowsEq( + []() { + void* p = nullptr; + TORCH_CHECK_NOTNULL(p); + }, + "Check failed: 'p' must be non NULL."); + +#if GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + // if dbg build, DCHECK should result in deth + EXPECT_DEATH(TORCH_DCHECK_EQ(1, 2), "Check failed"); +#else + TORCH_DCHECK_EQ(1, 2); // no-op +#endif +#endif // GTEST_HAS_DEATH_TEST +} diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 6b2fd626bfb5e..28a2ee06ecd3e 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -702,6 +702,98 @@ namespace c10::detail { #define TORCH_CHECK_ARG(cond, argN, ...) \ TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__) +#ifndef FATAL_IF +#ifdef C10_USE_GLOG +#define FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \ + .stream() +#else +#define FATAL_IF(condition) \ + condition ? 
(void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream() +#endif +#endif + +#ifndef NON_FATAL_IF +#ifdef C10_USE_GLOG +#define NON_FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger( \ + __FILE__, __LINE__, ::google::GLOG_FATAL, false) \ + .stream() +#else +#define NON_FATAL_IF(condition) \ + condition ? (void)0 \ + : ::c10::LoggerVoidify() & \ + ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \ + .stream() +#endif +#endif + +// Binary comparison check macros +#define TORCH_CHECK_OP(val1, val2, op) \ + NON_FATAL_IF(((val1)op(val2))) \ + << "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \ + << (val2) << "). " + +#define TORCH_DCHECK_OP(val1, val2, op) \ + FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \ + << (val1) << " vs. " << (val2) << "). " + +#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==) +#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=) +#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=) +#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <) +#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=) +#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >) + +// Debug versions of TORCH_CHECK_OP macros +#ifndef NDEBUG +#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==) +#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=) +#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=) +#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <) +#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=) +#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >) +#else // !NDEBUG +// Optimized versions - generate no code +#define TORCH_DCHECK_EQ(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, ==) +#define TORCH_DCHECK_NE(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, !=) +#define TORCH_DCHECK_LE(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, <=) +#define TORCH_DCHECK_LT(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, <) +#define TORCH_DCHECK_GE(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, >=) +#define TORCH_DCHECK_GT(val1, val2) \ + while (false) \ + TORCH_DCHECK_OP(val1, val2, >) +#endif // NDEBUG + +// Null pointer check macro +#define TORCH_CHECK_NOTNULL(val) \ + ::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false) + +#ifndef NDEBUG +#define TORCH_DCHECK_NOTNULL(val) \ + ::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true) +#else // !NDEBUG +#define TORCH_DCHECK_NOTNULL(val) \ + while (false) \ + TORCH_CHECK_NOTNULL(val) +#endif // NDEBUG + // ---------------------------------------------------------------------------- // Deprecated macros // ---------------------------------------------------------------------------- diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp index 555ab685c0b5f..4bf96b1b6808a 100644 --- a/c10/util/Logging.cpp +++ b/c10/util/Logging.cpp @@ -291,6 +291,32 @@ namespace c10 { using fLB::FLAGS_logtostderr; using fLI::FLAGS_minloglevel; using fLI::FLAGS_v; + +MessageLogger::MessageLogger( + const char* file, + int line, + int severity, + bool exit_on_fatal) + : stream_(), severity_(severity), exit_on_fatal_(exit_on_fatal) {} + +MessageLogger::~MessageLogger() noexcept(false) { + if (severity_ == ::google::GLOG_FATAL) { + DealWithFatal(); + } +} + 
+std::stringstream& MessageLogger::stream() { + return stream_; +} + +void MessageLogger::DealWithFatal() { + if (exit_on_fatal_) { + LOG(FATAL) << stream_.str(); + } else { + throw c10::Error(stream_.str(), nullptr, nullptr); + } +} + } // namespace c10 C10_DEFINE_int( @@ -412,17 +438,16 @@ void ShowLogInfoToStderr() { FLAGS_caffe2_log_level = GLOG_INFO; } -MessageLogger::MessageLogger(const char* file, int line, int severity) - : severity_(severity) { +MessageLogger::MessageLogger( + const char* file, + int line, + int severity, + bool exit_on_fatal) + : severity_(severity), exit_on_fatal_(exit_on_fatal) { if (severity_ < FLAGS_caffe2_log_level) { // Nothing needs to be logged. return; } -#ifdef ANDROID - tag_ = "native"; -#else // !ANDROID - tag_ = ""; -#endif // ANDROID time_t rawtime = 0; time(&rawtime); @@ -458,7 +483,7 @@ MessageLogger::MessageLogger(const char* file, int line, int severity) } // Output the contents of the stream to the proper channel on destruction. -MessageLogger::~MessageLogger() { +MessageLogger::~MessageLogger() noexcept(false) { if (severity_ < FLAGS_caffe2_log_level) { // Nothing needs to be logged. return; @@ -498,6 +523,18 @@ MessageLogger::~MessageLogger() { } } +std::stringstream& MessageLogger::stream() { + return stream_; +} + +void MessageLogger::DealWithFatal() { + if (exit_on_fatal_) { + abort(); + } else { + throw c10::Error(stream_.str(), nullptr, nullptr); + } +} + } // namespace c10 #endif // !C10_USE_GLOG diff --git a/c10/util/logging_common.h b/c10/util/logging_common.h new file mode 100644 index 0000000000000..df65da21c2b22 --- /dev/null +++ b/c10/util/logging_common.h @@ -0,0 +1,74 @@ +#ifndef C10_UTIL_LOGGING_COMMON_H_ +#define C10_UTIL_LOGGING_COMMON_H_ + +#include +#include + +namespace c10 { + +// MessageLogger that throws exceptions instead of aborting (glog version) +// or logs and may abort (non-glog version). +class C10_API MessageLogger { + public: + MessageLogger( + const char* file, + int line, + int severity, + bool exit_on_fatal = true); + ~MessageLogger() noexcept(false); + + // Return the stream associated with the logger object. + std::stringstream& stream(); + + private: + // When there is a fatal log, and fatal == true, we abort + // otherwise, we throw. + void DealWithFatal(); + +#if defined(ANDROID) && !defined(C10_USE_GLOG) + const char* tag_{"native"}; +#endif + std::stringstream stream_; + int severity_; + bool exit_on_fatal_; +}; + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". 
+class C10_API LoggerVoidify { + public: + LoggerVoidify() = default; + // This has to be an operator with a precedence lower than << but + // higher than ?: + void operator&(const std::ostream& s [[maybe_unused]]) {} +}; + +// Forward declarations for CheckNotNull functions +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal = true); + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal = true); + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal = true); + +} // namespace c10 + +#endif // C10_UTIL_LOGGING_COMMON_H_ diff --git a/c10/util/logging_is_google_glog.h b/c10/util/logging_is_google_glog.h index e5470d22cecd3..f4e2ff979088f 100644 --- a/c10/util/logging_is_google_glog.h +++ b/c10/util/logging_is_google_glog.h @@ -47,57 +47,53 @@ INSTANTIATE_FOR_CONTAINER(set) #endif +#include #include -// Additional macros on top of glog -#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2) -#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2) -#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2) -#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2) -#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2) -#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2) - -#ifndef NDEBUG -#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2) -#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2) -#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2) -#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2) -#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2) -#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2) -#else // !NDEBUG -// These versions generate no code in optimized mode. -#define TORCH_DCHECK_EQ(val1, val2) \ - while (false) \ - DCHECK_EQ(val1, val2) -#define TORCH_DCHECK_NE(val1, val2) \ - while (false) \ - DCHECK_NE(val1, val2) -#define TORCH_DCHECK_LE(val1, val2) \ - while (false) \ - DCHECK_LE(val1, val2) -#define TORCH_DCHECK_LT(val1, val2) \ - while (false) \ - DCHECK_LT(val1, val2) -#define TORCH_DCHECK_GE(val1, val2) \ - while (false) \ - DCHECK_GE(val1, val2) -#define TORCH_DCHECK_GT(val1, val2) \ - while (false) \ - DCHECK_GT(val1, val2) -#endif // NDEBUG - -// Check that a pointer is not null. -#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val) - -#ifndef NDEBUG -// Debug only version of TORCH_CHECK_NOTNULL -#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val) -#else // !NDEBUG -// Optimized version - generates no code. -#define TORCH_DCHECK_NOTNULL(val) \ - while (false) \ - DCHECK_NOTNULL(val) -#endif // NDEBUG +namespace c10 { + +[[noreturn]] void ThrowEnforceNotMet( + const char* file, + const int line, + const char* condition, + const std::string& msg, + const void* caller); + +template +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + if (t == nullptr) { + MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream() + << "Check failed: '" << names << "' must be non NULL. 
"; + } + return t; +} + +template +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +template +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); +} + +} // namespace c10 // Log with source location information override (to be used in generic // warning/error handlers implemented as functions, not macros) diff --git a/c10/util/logging_is_not_google_glog.h b/c10/util/logging_is_not_google_glog.h index 803a833c3cae4..b921cbff47d46 100644 --- a/c10/util/logging_is_not_google_glog.h +++ b/c10/util/logging_is_not_google_glog.h @@ -13,6 +13,7 @@ #include #include +#include const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV"; @@ -24,61 +25,40 @@ const int GLOG_ERROR = 2; const int GLOG_WARNING = 1; const int GLOG_INFO = 0; -class C10_API MessageLogger { - public: - MessageLogger(const char* file, int line, int severity); - ~MessageLogger(); - // Return the stream associated with the logger object. - std::stringstream& stream() { - return stream_; - } - - private: - // When there is a fatal log, we simply abort. - void DealWithFatal() { - abort(); - } - - const char* tag_; - std::stringstream stream_; - int severity_; -}; - -// This class is used to explicitly ignore values in the conditional -// logging macros. This avoids compiler warnings like "value computed -// is not used" and "statement has no effect". -class C10_API LoggerVoidify { - public: - LoggerVoidify() = default; - // This has to be an operator with a precedence lower than << but - // higher than ?: - void operator&(const std::ostream& s [[maybe_unused]]) {} -}; - -// Log a message and terminate. -template -void LogMessageFatal(const char* file, int line, const T& message) { - MessageLogger(file, line, GLOG_FATAL).stream() << message; -} - // Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw // pointers and smart pointers. template -T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) { +T& CheckNotNullCommon( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { if (t == nullptr) { - LogMessageFatal(file, line, std::string(names)); + MessageLogger(file, line, GLOG_FATAL, fatal).stream() + << "Check failed: '" << names << "' must be non NULL. "; } return t; } template -T* CheckNotNull(const char* file, int line, const char* names, T* t) { - return CheckNotNullCommon(file, line, names, t); +T* CheckNotNull( + const char* file, + int line, + const char* names, + T* t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); } template -T& CheckNotNull(const char* file, int line, const char* names, T& t) { - return CheckNotNullCommon(file, line, names, t); +T& CheckNotNull( + const char* file, + int line, + const char* names, + T& t, + bool fatal) { + return CheckNotNullCommon(file, line, names, t, fatal); } } // namespace c10 @@ -136,65 +116,6 @@ static_assert( ::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream() #endif // NDEBUG -#define TORCH_CHECK_OP(val1, val2, op) \ - FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \ - << (val1) << " vs. 
" << (val2) << ") " - -// TORCH_CHECK_OP macro definitions -#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==) -#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=) -#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=) -#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <) -#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=) -#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >) - -#ifndef NDEBUG -// Debug only versions of TORCH_CHECK_OP macros. -#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==) -#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=) -#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=) -#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <) -#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=) -#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >) -#else // !NDEBUG -// These versions generate no code in optimized mode. -#define TORCH_DCHECK_EQ(val1, val2) \ - while (false) \ - TORCH_CHECK_OP(val1, val2, ==) -#define TORCH_DCHECK_NE(val1, val2) \ - while (false) \ - TORCH_CHECK_OP(val1, val2, !=) -#define TORCH_DCHECK_LE(val1, val2) \ - while (false) \ - TORCH_CHECK_OP(val1, val2, <=) -#define TORCH_DCHECK_LT(val1, val2) \ - while (false) \ - TORCH_CHECK_OP(val1, val2, <) -#define TORCH_DCHECK_GE(val1, val2) \ - while (false) \ - TORCH_CHECK_OP(val1, val2, >=) -#define TORCH_DCHECK_GT(val1, val2) \ - while (false) \ - TORCH_CHECK_OP(val1, val2, >) -#endif // NDEBUG - -// Check that a pointer is not null. -#define TORCH_CHECK_NOTNULL(val) \ - ::c10::CheckNotNull( \ - __FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val)) - -#ifndef NDEBUG -// Debug only version of TORCH_CHECK_NOTNULL -#define TORCH_DCHECK_NOTNULL(val) \ - ::c10::CheckNotNull( \ - __FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val)) -#else // !NDEBUG -// Optimized version - generates no code. -#define TORCH_DCHECK_NOTNULL(val) \ - while (false) \ - TORCH_CHECK_NOTNULL(val) -#endif // NDEBUG - // ---------------------- Support for std objects -------------------------- // These are adapted from glog to support a limited set of logging capability // for STL objects. From c20308b79e314f9d0bc47517d31403224eabe519 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Fri, 7 Nov 2025 20:05:10 +0000 Subject: [PATCH 220/651] [Test CI] Bump ruff to 0.14.4 (#167286) This PR bumps ruff to 0.14.4. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167286 Approved by: https://github.com/janeyx99, https://github.com/Skylion007 --- .lintrunner.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index c7e3797c9b80c..d8bdcc2eefd1b 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1402,7 +1402,7 @@ init_command = [ '--dry-run={{DRYRUN}}', 'usort==1.0.8.post1', 'isort==6.0.1', - 'ruff==0.13.1', # sync with RUFF + 'ruff==0.14.4', # sync with RUFF ] is_formatter = true @@ -1537,7 +1537,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.13.1', # sync with PYFMT + 'ruff==0.14.4', # sync with PYFMT ] is_formatter = true From 289b47e6577754c33e3decbe1570d5d7d99a2e9c Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Fri, 7 Nov 2025 20:05:41 +0000 Subject: [PATCH 221/651] [MPS] empty matrix x vec mul fix (#166561) Fixes empty matrix x vector. 
Discovered when implementing an op for sparse tensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/166561 Approved by: https://github.com/eqy, https://github.com/albanD --- aten/src/ATen/native/mps/operations/Blas.mm | 3 +++ test/test_mps.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/aten/src/ATen/native/mps/operations/Blas.mm b/aten/src/ATen/native/mps/operations/Blas.mm index 16d744cedb8ef..5ebf5f604bfc1 100644 --- a/aten/src/ATen/native/mps/operations/Blas.mm +++ b/aten/src/ATen/native/mps/operations/Blas.mm @@ -141,6 +141,9 @@ Tensor dot_mps(const Tensor& self, const Tensor& other) { }; MPSStream* stream = at::mps::getCurrentMPSStream(); + if (result.numel() == 0) { + return result; + } Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1); @autoreleasepool { diff --git a/test/test_mps.py b/test/test_mps.py index 765ec3c52e036..ca95839e7a7fb 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -647,6 +647,11 @@ def test_large_matmul(self): self.assertEqual(matmul_cpu, matmul_mps.to("cpu")) + def test_empty_matmul_vec(self): + tensor_1 = torch.rand((0, 100), device="mps") + tensor_2 = torch.rand((100, ), device="mps") + self.assertEqual((tensor_1 @ tensor_2).cpu(), tensor_1.cpu() @ tensor_2.cpu()) + class MPSLeakyReluTest(TestCaseMPS): def _npLeakyRelu(self, np_features, negative_slope=0.1): return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype) From b83a3f6e87b9c78655e7080b708acc2be490c1fc Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 7 Nov 2025 09:25:22 -0800 Subject: [PATCH 222/651] compile time comm benchmarking (#167100) Adds an option to do compile time collective benchmarking for comms/compute overlap scheduling. As with the comm benchmarks, these are all gathered, and each rank uses the median result to ensure consistency. thanks to @ruisizhang123 who had done this previously. We log the compile time benchmark, the inductor analytic result, and the nccl estimator result to tlparse. TODO: - mechanism to seed collective estimates with the existing tlparse (or perfetto) to use for deterministic, pgo'd estimates - interpolate results between powers of 2, and also do the actual benchmarking for latency calculation. both of these need to be meta aware since reduce scatter needs to be divisible by group_size, not hard but leaving for a subsequent pr. 
Example output tlparse: https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/rank_0/-_0_0_0/node_runtime_estimation_10.json?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167100 Approved by: https://github.com/IvanKobzarev --- .../test_aten_comm_compute_reordering.py | 45 +++++ test/dynamo/test_logging.py | 1 + torch/_inductor/config.py | 5 + .../fx_passes/node_runtime_estimation.py | 179 ++++++++++++++++++ .../_inductor/fx_passes/overlap_scheduling.py | 132 ++++++++++++- torch/_inductor/fx_passes/post_grad.py | 1 + torch/_logging/_registrations.py | 5 + 7 files changed, 361 insertions(+), 7 deletions(-) create mode 100644 torch/_inductor/fx_passes/node_runtime_estimation.py diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 5b1db2d8dfe14..eff762ec323c5 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -54,6 +54,7 @@ def apply_reordering_and_get_graph(graph, out_li) -> None: "max_compute_pre_fetch", "custom_runtime_estimation", "insert_overlap_deps", + "collective_estimator", ) for key in config_keys: if (val := getattr(dist_opts, key)) is not None: @@ -943,6 +944,50 @@ def func(a, b, *, ranks): correct = func(inputs_a, inputs_b, ranks=ranks) self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_collective_benchmarking_with_real_pg(self): + """Test collective benchmarking with real process group (falls back on fake).""" + + def func(a): + # Test all three collective types with 8x8 (power of 2 size = 256 elements = 1024 bytes for fp32) + ar = _functional_collectives.all_reduce(a, "sum", "0") + ag = _functional_collectives.all_gather_tensor( + a, 0, list(range(self.world_size)) + ) + rs = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0") + + b = torch.matmul(a, a) + c = torch.matmul(ar, b) + return c.sum() + ag.sum() + rs.sum() + + patches = { + **get_patches(), + "aten_distributed_optimizations.collective_estimator": "benchmark", + "aten_distributed_optimizations.custom_runtime_estimation": None, # Remove custom estimation so benchmarking happens + } + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + inputs = torch.ones(8, 8, dtype=torch.float, device=device_type) + self.rank + + with torch._inductor.config.patch(patches): + compiled = torch.compile(func) + out, aten_graph_str = run_and_get_aten_graph(compiled, inputs) + + # Verify all three collective types are present + FileCheck().check("all_reduce").check("all_gather").check( + "reduce_scatter" + ).run(aten_graph_str) + + # Test passes if compilation succeeded with benchmarking enabled + # Cache verification is tricky due to multiprocess test setup + correct = func(inputs) + self.assertTrue(same(out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @torch._inductor.config.patch(get_bucket_patches()) def test_multidtype_bucketing(self): diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 860d82784ea70..162bc5c111d07 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -988,6 +988,7 @@ def bar(): "hierarchical_compile", "compute_dependencies", "annotation", + "node_runtime_estimation", } for name in 
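A minimal usage sketch, adapted from the `test_collective_benchmarking_with_real_pg` test added here. It assumes an initialized default process group and CUDA; the group name `"0"` and tensor shape match the test's setup, and the test additionally enables the pass's other reordering options via its `get_patches()` helper.

```
import torch
import torch.distributed._functional_collectives as funcol

def step(a):
    ar = funcol.all_reduce(a, "sum", "0")
    return torch.matmul(ar, a).sum()

patches = {
    # switch collective runtime estimation from the analytical bandwidth
    # model to CUDA-event benchmarking (medians aligned across ranks)
    "aten_distributed_optimizations.collective_estimator": "benchmark",
}
with torch._inductor.config.patch(patches):
    out = torch.compile(step)(torch.ones(8, 8, device="cuda"))
```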
torch._logging._internal.log_registry.artifact_names: if name not in exclusions: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 3eaa840961fa8..fb43a9b859ffb 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -914,6 +914,11 @@ class aten_distributed_optimizations: None ) + # Method for estimating collective runtime + # "analytical": Use bandwidth formulas (default) + # "benchmark": Use CUDA events with power-of-2 rounding and interpolation + collective_estimator: Literal["analytical", "benchmark"] = "analytical" + def parallel_compile_enabled_internally() -> bool: """ diff --git a/torch/_inductor/fx_passes/node_runtime_estimation.py b/torch/_inductor/fx_passes/node_runtime_estimation.py new file mode 100644 index 0000000000000..43d3647b916a2 --- /dev/null +++ b/torch/_inductor/fx_passes/node_runtime_estimation.py @@ -0,0 +1,179 @@ +""" +Collective runtime estimation using CUDA events and power-of-2 rounding. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Any, Optional + +import torch +from torch._inductor.utils import clear_on_fresh_cache +from torch._logging import getArtifactLogger +from torch.fx.operator_schemas import normalize_function + + +# Setup logger for artifact logging +log = getArtifactLogger(__name__, "node_runtime_estimation") + + +# TODO: Consider using a distributed-aware cache or rank-local disk cache +# not using local cache because different ranks might write to it concurrently. +# solvable in future, potentially with workflow to seed cache +@clear_on_fresh_cache +@lru_cache +def _get_collective_cache() -> dict[str, float]: + """Get process-local cache for collective benchmarks.""" + return {} + + +def get_cached_runtime(key: str) -> Optional[float]: + """Get cached runtime from process-local cache.""" + return _get_collective_cache().get(key) + + +def set_cached_runtime(key: str, value: float) -> None: + """Set cached runtime in process-local cache.""" + _get_collective_cache()[key] = value + + +def get_hint(x: int | torch.SymInt) -> Optional[int]: + if isinstance(x, int): + return x + assert isinstance(x, torch.SymInt) + return x.node.hint if x.node.has_hint() else None + + +def can_benchmark_collective() -> bool: + """Check if we can benchmark collectives (not fake process group).""" + import torch.distributed as c10d + + if not c10d.is_initialized(): + return False + + pg = c10d.distributed_c10d._get_default_group() + if torch.distributed.distributed_c10d.get_backend(pg) == "fake": + return False + + return True + + +def _benchmark_collective_with_cuda_events_impl( + n: torch.fx.Node, + args: tuple[Any, ...], + kwargs: dict[str, Any], + nruns: int, +) -> float | None: + """ + Core benchmarking logic using CUDA events and barriers. + Returns runtime in ms or None on failure. 
+ """ + import torch.distributed as c10d + + # Warmup: call collective once and wait + torch.cuda.synchronize() + result = n.target(*args, **kwargs) # type: ignore[operator] + torch.ops._c10d_functional.wait_tensor(result) + + # Benchmark with CUDA events + comm_time = 0.0 + for _ in range(nruns): + c10d.barrier() + torch.cuda.synchronize() + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + result = n.target(*args, **kwargs) # type: ignore[operator] + torch.ops._c10d_functional.wait_tensor(result) + end_evt.record() + end_evt.synchronize() + + comm_time += start_evt.elapsed_time(end_evt) + + return comm_time / nruns + + +def benchmark_collective_with_cuda_events( + n: torch.fx.Node, + nruns: int = 2, +) -> tuple[float | None, str]: + """ + Benchmark collective with CUDA events. Returns (runtime_ms, cache_key) or (None, "") on failure. + """ + # context manager not allowed with profiler. + with torch.utils._python_dispatch._disable_current_modes(): + return benchmark_collective_with_cuda_events_impl(n, nruns) + + +def benchmark_collective_with_cuda_events_impl( + n: torch.fx.Node, + nruns: int = 2, +) -> tuple[float | None, str]: + """ + Benchmark collective with CUDA events. Returns (runtime_ms, cache_key) or (None, "") on failure. + """ + from torch._inductor import fx_utils + from torch.distributed.distributed_c10d import _get_group_size_by_name + + # Early check: can we actually run collectives? + if not can_benchmark_collective(): + return None, "" + + success, args, kwargs = fx_utils.get_fake_args_kwargs(n) + + opt_args_kwargs = normalize_function( + n.target, # type: ignore[arg-type] + args=n.args, + kwargs=n.kwargs, + normalize_to_only_use_kwargs=True, + ) + assert opt_args_kwargs is not None + group_name = opt_args_kwargs[1]["group_name"] + group_size = _get_group_size_by_name(group_name) + + if not success: + return None, "" + + # Extract actual input size in BYTES (first tensor argument) + actual_bytes: Optional[int] = None + + def extract_tensor_info(t: torch.Tensor) -> torch.Tensor: + nonlocal actual_bytes + if actual_bytes is None: + shape = [get_hint(dim) for dim in t.shape] + if any(s is None for s in shape): + return t + + total_elems = 1 + for dim in shape: + assert dim is not None + total_elems *= dim + + actual_bytes = total_elems * t.dtype.itemsize + else: + raise RuntimeError(f"should only be one input tensor to collective {n}") + return t + + torch.utils._pytree.tree_map_only(torch.Tensor, extract_tensor_info, (args, kwargs)) + + if actual_bytes is None: + return None, "" + + # Cache key by BYTES (dtype-agnostic) + key = f"{n.target}: ({group_size} group size, {actual_bytes} bytes)" + + # Check cache + if (cached := get_cached_runtime(key)) is not None: + return cached, key + + # Benchmark using CUDA events with actual args/kwargs + runtime = _benchmark_collective_with_cuda_events_impl(n, args, kwargs, nruns) + + if runtime is None: + return None, key + + # Cache the result + set_cached_runtime(key, runtime) + return runtime, key diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 80ef2a95139a3..4f5d280869f99 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -6,7 +6,7 @@ from collections import Counter, defaultdict from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any +from typing import Any, Literal import torch 
import torch.fx as fx @@ -61,6 +61,7 @@ def estimate_collective_time( if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None: return est + # Use analytical model (benchmarking is handled separately in alignment) return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( n, override_size ) @@ -109,6 +110,7 @@ def benchmark_node_with_cache_key( n: fx.Node, custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, ) -> tuple[float, str | None]: + """Benchmark a compute node and return (runtime, cache_key).""" assert is_compute_node(n) from torch._dynamo.testing import rand_strided @@ -244,6 +246,7 @@ def __init__( compute_overlap_multipler: float, max_coll_distance: int, custom_runtime_estimation: Callable[[fx.Node], float | None] | None, + collective_estimator: Literal["analytical", "benchmark"], ): self.gm = gm self.graph = gm.graph @@ -254,6 +257,7 @@ def __init__( self.collective_bucketing = collective_bucketing self.insert_overlap_deps = insert_overlap_deps self.max_compute_pre_fetch = max_compute_pre_fetch + self.collective_estimator = collective_estimator # Build structures stable_topological_sort(self.graph) @@ -356,25 +360,104 @@ def _calculate_compute_node_domination_index(self) -> dict[fx.Node, int]: return domination_index + def _log_collective_benchmarks( + self, + collective_nodes: list[fx.Node], + collective_keys: list[str], + benchmarked_medians: list[float], + world_size: int, + ) -> None: + """Log collective benchmarks with analytical comparisons for tlparse.""" + collective_benchmarks = {} + for key, benchmarked_ms, coll_node in zip( + collective_keys, benchmarked_medians, collective_nodes + ): + # NCCL estimator (deterministic, no need to align) + nccl_ms = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( + coll_node, None, use_nccl_estimator=True + ) + + # Inductor analytical (deterministic, no need to align) + inductor_ms = torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( + coll_node, None, use_nccl_estimator=False + ) + + collective_benchmarks[key] = { + "benchmarked_ms": benchmarked_ms, + "analytical_nccl_ms": nccl_ms, + "analytical_inductor_ms": inductor_ms, + } + + # Emit tlparse artifact + from torch._logging import trace_structured + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "node_runtime_estimation", + "encoding": "json", + }, + payload_fn=lambda: { + "world_size": world_size, + "collective_benchmarks": collective_benchmarks, + }, + ) + def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( self, ) -> None: + """Align runtime estimations across ranks (compute + collectives).""" log.info( "Overlap scheduling: Aligning runtime estimations across all distributed ranks" ) + + # Benchmark compute nodes runtime_estimations_keys: list[str | None] = [] runtime_estimations: list[float] = [] + compute_key_count = 0 + for n in self.compute_nodes: val, key = benchmark_node_with_cache_key(n, self.custom_runtime_estimation) runtime_estimations.append(val) runtime_estimations_keys.append(key) + compute_key_count += 1 + + # Benchmark collectives if enabled (only CUDA events - others are deterministic) + # Skip if custom estimation is provided for collectives + collective_nodes: list[fx.Node] = [] + benchmarked_collective_nodes: list[ + fx.Node + ] = [] # Track which were actually benchmarked + if self.collective_estimator == "benchmark": + from torch._inductor.fx_passes.node_runtime_estimation import ( + 
benchmark_collective_with_cuda_events, + ) + + collective_nodes = [ + info.start_node for info in self.collective_info.values() + ] + + # Benchmark CUDA events (non-deterministic, needs alignment) + # Skip collectives with custom estimation + for n in collective_nodes: + if get_custom_estimation(n, self.custom_runtime_estimation) is not None: + continue + + # Benchmark actual size + cuda_val, cuda_key = benchmark_collective_with_cuda_events(n, nruns=2) + if cuda_val is not None: + runtime_estimations.append(cuda_val) + runtime_estimations_keys.append(cuda_key) + benchmarked_collective_nodes.append(n) + # Single all_gather and compute medians import torch.distributed as dist from torch._subclasses.fake_tensor import unset_fake_temporarily from torch.distributed.distributed_c10d import _get_default_group world_size = dist.get_world_size() pg = _get_default_group() + with unset_fake_temporarily(): gathered_runtime_estimations: list[list[float]] = [ [] for _ in range(world_size) @@ -385,15 +468,46 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( median_runtime_estimations = torch.median( torch.tensor(gathered_runtime_estimations), dim=0 ).values.tolist() - for key, median_runtime_estimation in zip( - runtime_estimations_keys, median_runtime_estimations + + # Cache medians + collective_keys = [] + collective_medians = [] + for idx, (key, median_runtime_estimation) in enumerate( + zip(runtime_estimations_keys, median_runtime_estimations) ): if key is None: continue - set_cached_node_time(key, median_runtime_estimation) - log.info( - "Overlap scheduling: Runtime estimations across all distributed ranks were aligned" - ) + if idx < compute_key_count: + # Compute node + set_cached_node_time(key, median_runtime_estimation) + else: + # Collective CUDA event benchmark + from torch._inductor.fx_passes.node_runtime_estimation import ( + set_cached_runtime, + ) + + set_cached_runtime(key, median_runtime_estimation) + + # Update CollectiveInfo with aligned benchmark + coll_idx = idx - compute_key_count + coll_node = benchmarked_collective_nodes[coll_idx] + info = self.collective_info[coll_node] + info.estimated_time_ms = median_runtime_estimation + info.exposed_time_ms = median_runtime_estimation + + collective_keys.append(key) + collective_medians.append(median_runtime_estimation) + + # Log benchmarks with analytical comparisons + if collective_keys: + self._log_collective_benchmarks( + benchmarked_collective_nodes, + collective_keys, + collective_medians, + world_size, + ) + + log.info("Overlap scheduling: Runtime estimations aligned") def run(self) -> torch.fx.GraphModule: """Run the scheduling algorithm.""" @@ -894,6 +1008,7 @@ def schedule_overlap_bucketing( compute_overlap_multipler: float = 1.0, max_coll_distance: int = 1000, custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None, + collective_estimator: Literal["analytical", "benchmark"] = "analytical", ) -> torch.fx.GraphModule: """Schedule nodes to maximize compute-collective overlap. @@ -910,6 +1025,8 @@ def schedule_overlap_bucketing( max_coll_distance: Maximum node distance for overlap or bucketing. Mostly intended to reduce compile time. custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node. If None, uses default estimations. This is currently limited to collectives and compute nodes. + collective_estimator: Method for estimating collective runtime. 
"analytical" uses bandwidth formulas, + "benchmark" uses CUDA events with power-of-2 rounding and interpolation. """ return OverlapScheduler( @@ -921,4 +1038,5 @@ def schedule_overlap_bucketing( custom_runtime_estimation=custom_runtime_estimation, collective_bucketing=collective_bucketing, insert_overlap_deps=insert_overlap_deps, + collective_estimator=collective_estimator, ).run() diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 9c7c01c785f4e..958a52fcdf510 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -290,6 +290,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): "max_compute_pre_fetch", "custom_runtime_estimation", "insert_overlap_deps", + "collective_estimator", ) for key in config_keys: if (val := getattr(dist_opts, key)) is not None: diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 162ad53a63ccd..f0077f0f9bb7d 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -225,6 +225,11 @@ "Detailed Inductor benchmarking information.", off_by_default=True, ) +register_artifact( + "node_runtime_estimation", + "Node runtime estimation for compile-time optimization decisions.", + off_by_default=True, +) register_artifact( "autotuning", "Autotuning choice logs, such as kernel source, perf, and tuning parameters.", From 8eb21304ab4fba8ba43865607b3609efa14c2a10 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 6 Nov 2025 19:34:07 -0800 Subject: [PATCH 223/651] [DTensor] ignore fresh unbacked symbols in shard prop (#166989) This fixes 2 issues with the DTensor data-dependent test case: 1) ShapeEnv not found when doing shard prop on data-dependent ops - fix was to detect the outer tracing fake mode. Maybe ShardingPropagator should just own a FakeMode & ShapeEnv for these purposes? The previous behavior was to initialize a new fake mode on every call. 2) Pending unbacked symbols not found. This happens because DTensor dispatch runs fake prop twice, once while figuring out the output sharding: https://github.com/pytorch/pytorch/blob/2bba37309bc8996fc6a190592e5ad9aac53761c9/torch/distributed/tensor/_sharding_prop.py#L175 and again to actually get the resulting local tensor: https://github.com/pytorch/pytorch/blob/2bba37309bc8996fc6a190592e5ad9aac53761c9/torch/distributed/tensor/_dispatch.py#L254-L255 With data-dependent ops, both calls will produce an unbacked symbol, but symbols in the first invocation are never surfaced, producing this error, so we ignore pending symbols from this site. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166989 Approved by: https://github.com/ezyang --- test/distributed/tensor/test_dtensor_export.py | 13 +++++++++++++ torch/distributed/tensor/_sharding_prop.py | 13 ++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index d2104066811be..b9749e3bc4e23 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -535,6 +535,19 @@ def nest_fn(leaf: torch.Tensor | DTensor): self.assertEqual(fn(z), gm(z)[0]) + def test_dtensor_data_dependent_index(self): + device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,)) + + class Foo(torch.nn.Module): + def forward(self, x, y): + return x[y] + + x = torch.randn(10) + y = torch.randint(1, (10,)).bool() + x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()]) + y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()]) + _dynamo_graph_capture_for_export(Foo())(x_dt, y_dt) + instantiate_parametrized_tests(DTensorExportTest) diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index c1af2c1317174..08ccb493f10c0 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import contextlib import threading from collections.abc import Callable, Sequence from functools import lru_cache @@ -6,6 +7,7 @@ from typing import cast, Optional, Union import torch +from torch._guards import detect_fake_mode from torch._ops import OpOverload from torch._subclasses import FakeTensorMode from torch.distributed._functional_collectives import _are_we_tracing @@ -169,7 +171,16 @@ def _propagate_tensor_meta_non_cached( # these operators to be inserted in the fx graph. from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing - with FakeTensorMode(), disable_proxy_modes_tracing(): + # DTensor.dispatch runs fake tensor prop twice, once here, and once for the actual + # local tensor result. The result here is never surfaced to tracing, and so if + # the op is data-dependent, can result in PendingUnbackedSymbolNotFound errors. 
+ fake_mode = detect_fake_mode() or FakeTensorMode() + suppress_fresh_symbols_ctx = ( + fake_mode.shape_env.ignore_fresh_unbacked_symbols() + if fake_mode.shape_env + else contextlib.nullcontext() + ) + with fake_mode, disable_proxy_modes_tracing(), suppress_fresh_symbols_ctx: fake_args = op_schema.gen_fake_args() fake_kwargs = op_schema.gen_fake_kwargs() fake_out = op_schema.op(*fake_args, **fake_kwargs) From ba327b7a5c0a326e3b587f62620d9f2e688b699a Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 7 Nov 2025 20:38:03 +0000 Subject: [PATCH 224/651] [BE][Typing][Dynamo] Type torch/_dynamo/variables/functions.py (#167103) Provides type coverage to torch/_dynamo/variables/dicts.py Coverage report: `mypy torch/_dynamo/variables/functions.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 2698 lines and 166 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/167103 Approved by: https://github.com/mlazos, https://github.com/fxdawnn --- torch/_dynamo/variables/builtin.py | 2 +- torch/_dynamo/variables/functions.py | 771 +++++++++++++++++---------- torch/_dynamo/variables/iter.py | 2 +- torch/_dynamo/variables/lists.py | 1 + torch/_dynamo/variables/torch.py | 9 +- 5 files changed, 485 insertions(+), 300 deletions(-) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 2ac7bc7fe60b4..e15eb83c72573 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1991,7 +1991,7 @@ def call_iter( # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() # with an integer argument starting at 0, until __getitem__ raises IndexError ret = variables.UserFunctionVariable( - polyfills.builtins.iter_ + polyfills.builtins.iter_ # type: ignore[arg-type] ).call_function(tx, [obj, *args], {}) if args: diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 5fd903e7bbfdf..2f64c825a07fc 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Function-related variable tracking classes for Dynamo's symbolic execution. @@ -32,13 +30,14 @@ import traceback import types from collections.abc import Callable, Sequence -from types import FunctionType +from types import CellType, FunctionType from typing import Any, Optional, TYPE_CHECKING, TypeVar from typing_extensions import Never from weakref import WeakKeyDictionary import torch from torch._dynamo.exc import get_stack_above_dynamo +from torch._guards import Source from .. 
import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_rot_n, is_generator @@ -87,25 +86,32 @@ try: from torch.distributed.fsdp._fully_shard import _fsdp_param_group except ModuleNotFoundError: - _fsdp_param_group = None + _fsdp_param_group = None # type: ignore[assignment] if TYPE_CHECKING: from torch._dynamo.codegen import PyCodegen - from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.symbolic_convert import ( + InstructionTranslator, + InstructionTranslatorBase, + ) + from torch._dynamo.variables.ctx_manager import ContextWrappingVariable from torch._higher_order_ops.triton_kernel_wrap import ( TritonGridType, TritonKernelType, ) + from .lists import BaseListVariable, ListVariable + from .tensor import TensorVariable -_F = TypeVar("_F", bound=Callable) + +_F = TypeVar("_F", bound=Callable[..., Any]) CO_VARARGS = 0x04 CO_VARKEYWORDS = 0x08 # Module-level cache keyed by the function object -_spec_cache = WeakKeyDictionary() +_spec_cache: WeakKeyDictionary[Any, Any] = WeakKeyDictionary() class FunctionSpec: @@ -127,7 +133,7 @@ def __init__(self, func: FunctionType): off += 1 if self.varargs_name else 0 self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None - def update_defaults(self, func: FunctionType): + def update_defaults(self, func: FunctionType) -> None: # Defaults can change from function call to function call. So re-update # them on every call. self.defaults = func.__defaults__ or () @@ -147,7 +153,13 @@ def _get_spec(func: FunctionType) -> FunctionSpec: return spec -def bind_args_cached(func, tx, fn_source, args, kwargs): +def bind_args_cached( + func: FunctionType, + tx: "InstructionTranslator", + fn_source: Optional[Source], + args: Sequence[Any], + kwargs: dict[str, Any], +) -> dict[str, VariableTracker]: spec = _get_spec(func) spec.update_defaults(func) ba = {} @@ -240,7 +252,9 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): return ba -def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): +def wrap_bound_arg( + tx: "InstructionTranslator", val: Any, source: Optional[Source] = None +) -> VariableTracker: # Source propagation is best effort since not every object we encounter has a source to begin with. 
if isinstance(val, VariableTracker): return val @@ -252,14 +266,18 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): return variables.LazyVariableTracker.create(val, source) -def wrap_args_kwargs(tx: "InstructionTranslator", result): +def wrap_args_kwargs(tx: "InstructionTranslator", result: dict[str, Any]) -> None: for k, v in list(result.items()): if isinstance(v, (tuple, dict)): # args/kwargs result[k] = wrap_bound_arg(tx, v) -def init_cellvars(parent, result: dict[str, VariableTracker], code): +def init_cellvars( + parent: "InstructionTranslator", + result: dict[str, VariableTracker], + code: types.CodeType, +) -> None: """ Update `result` to add mapping from local name to new cells created directly by `code`, or update SideEffects in `parent` if the a local cell is @@ -277,8 +295,14 @@ def init_cellvars(parent, result: dict[str, VariableTracker], code): def _create_nested_fn( - code, f_globals, name, defaults, closure, kwdefaults, annotations -): + code: types.CodeType, + f_globals: dict[str, Any], + name: str, + defaults: Optional[tuple[object, ...]], + closure: Optional[tuple[CellType]], + kwdefaults: Optional[dict[str, Any]], + annotations: Optional[dict[str, Any]], +) -> types.FunctionType: from types import FunctionType func = FunctionType(code, f_globals, name, defaults, closure) @@ -291,7 +315,7 @@ def _create_nested_fn( # TypeError: __annotations__ must be set to a dict object assert annotations is None or isinstance(annotations, dict) - func.__annotations__ = annotations + func.__annotations__ = annotations # type: ignore[assignment] return func @@ -307,7 +331,9 @@ def _create_nested_fn( } -def fn_var_getattr(tx, fn, source, name): +def fn_var_getattr( + tx: "InstructionTranslator", fn: object, source: Optional[Source], name: str +) -> VariableTracker: source = source and AttrSource(source, name) if source and name == "__annotations__": @@ -316,6 +342,7 @@ def fn_var_getattr(tx, fn, source, name): # graph is even rarer. So skip guards. 
source = SkipGuardSource(source) + subobj = None try: subobj = inspect.getattr_static(fn, name) except AttributeError: @@ -332,19 +359,19 @@ def fn_var_getattr(tx, fn, source, name): class BaseUserFunctionVariable(VariableTracker): - def get_filename(self): - return self.get_code().co_filename + def get_filename(self) -> str: + return self.get_code().co_filename # type: ignore[attr-defined] - def get_name(self): - return self.get_code().co_name + def get_name(self) -> str: + return self.get_code().co_name # type: ignore[attr-defined] def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined] def call_obj_hasattr( self, tx: "InstructionTranslator", name: str @@ -352,16 +379,16 @@ def call_obj_hasattr( result = False try: - result = hasattr(self.get_function(), name) + result = hasattr(self.get_function(), name) # type: ignore[attr-defined] except NotImplementedError: if name == "__name__" and isinstance(self, NestedUserFunctionVariable): result = True return variables.ConstantVariable.create(result) - def inspect_parameter_names(self): - return list(inspect.signature(self.get_function()).parameters) + def inspect_parameter_names(self) -> list[str]: + return list(inspect.signature(self.get_function()).parameters) # type: ignore[attr-defined] - def closure_vars(self, tx): + def closure_vars(self, tx: "InstructionTranslator") -> dict[str, VariableTracker]: return {} @@ -375,11 +402,16 @@ class UserFunctionVariable(BaseUserFunctionVariable): } @classmethod - def create_with_source(cls, value, source): + def create_with_source(cls, value: Any, source: Any) -> "UserFunctionVariable": install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) - def __init__(self, fn, is_constant=False, **kwargs) -> None: + def __init__( + self, + fn: types.FunctionType | torch.jit.ScriptFunction, # type: ignore[type-arg] + is_constant: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) if getattr(fn, "_dynamo_marked_constant", False): # This method should be treated as a constant for the purposes of compilation @@ -403,40 +435,45 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None: # VariableBuilder, which handles the wrapping of _torchdynamo_inline. 
# unpack @torch._dynamo.optimize()(fn) wrapped function fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn) - self.fn: types.FunctionType = fn + self.fn = fn - def as_python_constant(self): + def as_python_constant(self) -> Any: if istype(self, UserFunctionVariable): return self.fn # subclasses (such as methods) usually aren't a constant return super().as_python_constant() - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [] - def get_function(self): + def get_function(self) -> types.FunctionType: return self.fn - def get_code(self): + def get_code(self) -> types.CodeType: return self.fn.__code__ - def python_type(self): + def python_type(self) -> type: return types.FunctionType - def has_self(self): + def has_self(self) -> bool: return getattr(self.fn, "__self__", None) is not None - def get_globals(self): + def get_globals(self) -> dict[str, Any]: return self.fn.__globals__ - def get_source(self): + def get_source(self) -> Source: source = self.source if source and isinstance(self, variables.UserMethodVariable): - source = self.source_fn - return source + source = self.source_fn # type: ignore[assignment] + return source # type: ignore[return-value] - def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: """ Assume `args` and `kwargs` are VariableTracker arguments for a call to this function, create new bindings for initial locals. @@ -450,7 +487,7 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: root_tx = parent.output.root_tx source = self.get_source() - result = bind_args_cached(fn, root_tx, source, args, kwargs) + result = bind_args_cached(fn, root_tx, source, args, kwargs) # type: ignore[arg-type] init_cellvars(parent, result, fn.__code__) closure = self.fn.__closure__ or () @@ -491,7 +528,7 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]: return result - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) source = self.get_source() @@ -506,9 +543,9 @@ def call_obj_hasattr( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # Handle patch_dynamo_config call if self.fn is torch._dynamo.patch_dynamo_config: try: @@ -548,7 +585,7 @@ def call_function( msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" unimplemented_v2( gb_type="TypeError from user code", - context=f"call_function({self.value}, {args}, {kwargs})", + context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined] explanation=msg, hints=[ *graph_break_hints.USER_ERROR, @@ -567,7 +604,7 @@ def call_function( "`torch.compile` region", ], ) - + # pyrefly: ignore[missing-attribute] fn = fn_var.fn return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True) @@ -593,7 +630,7 @@ def call_function( try: from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState except Exception: - FSDPState = None + FSDPState = None # type: ignore[assignment, misc] if FSDPState is not None and self.fn in [ 
FSDPState._pre_forward, FSDPState._post_forward, @@ -604,13 +641,15 @@ def call_function( class BuiltinMethodVariable(BaseUserFunctionVariable): - def __init__(self, fn, is_constant=False, **kwargs) -> None: + def __init__( + self, fn: types.BuiltinMethodType, is_constant: bool = False, **kwargs: Any + ) -> None: super().__init__(**kwargs) assert isinstance(fn, types.BuiltinMethodType) self.fn = fn @staticmethod - def is_supported_builtin_method(obj): + def is_supported_builtin_method(obj: Any) -> bool: method_self = obj.__self__ method_name = obj.__name__ @@ -623,9 +662,9 @@ def is_supported_builtin_method(obj): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: method_self = self.fn.__self__ name = self.fn.__name__ obj_source = self.source and AttrSource(self.source, "__self__") @@ -637,39 +676,39 @@ class LocalGeneratorObjectVariable(VariableTracker): def __init__( self, code: types.CodeType, - f_globals, + f_globals: dict[str, Any], inline_tracer: Optional["InstructionTranslator"], - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.code = code self.f_globals = f_globals self.inline_tracer = inline_tracer - def get_code(self): + def get_code(self) -> types.CodeType: return self.code - def get_filename(self): + def get_filename(self) -> str: return self.get_code().co_filename - def get_name(self): + def get_name(self) -> str: return self.get_code().co_name - def get_function(self): + def get_function(self) -> Never: raise NotImplementedError - def has_self(self): + def has_self(self) -> bool: return False - def __name__(self): + def __name__(self) -> str: return self.get_name() - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}({self.get_name()})" __repr__ = __str__ - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: from torch._dynamo.side_effects import disallow_side_effects_in_generator from torch._dynamo.symbolic_convert import ( InstructionTranslator, @@ -688,25 +727,30 @@ def reconstruct(self, codegen: "PyCodegen"): self.remaining_items = self.force_unpack_var_sequence(tx) variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen) - def bind_args(self, tx, args, kwargs): - return self.fn.bind_args(tx, args, kwargs) + def bind_args( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: + return self.vt.bind_args(tx, args, kwargs) # type: ignore[attr-defined] - def get_globals(self): + def get_globals(self) -> dict[str, Any]: return self.f_globals - def python_type(self): + def python_type(self) -> type: return types.GeneratorType - def _get_inline_tracer(self, tx): + def _get_inline_tracer(self, tx: "InstructionTranslator") -> Any: from torch._dynamo.symbolic_convert import InliningInstructionTranslator if self.inline_tracer is None: - self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( + self.inline_tracer = InliningInstructionTranslator.build_inline_tracer( # type: ignore[assignment] tx, self, [], {} ) return self.inline_tracer - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: tracer = self._get_inline_tracer(tx) if self._is_generator_exhausted(): @@ -727,23 +771,29 @@ def next_variable(self, 
tx): torch._dynamo.eval_frame.skip_code(self.get_code()) raise SkipFrame from e - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) - def has_unpack_var_sequence(self, tx): + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return False - def has_force_unpack_var_sequence(self, tx) -> builtins.bool: + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return True - def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: - result = [] + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] self.force_apply_to_var_sequence(tx, result.append) return result - def force_apply_to_var_sequence(self, tx, fn) -> None: + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[VariableTracker], Any] + ) -> None: while True: try: fn(self.next_variable(tx)) @@ -751,7 +801,9 @@ def force_apply_to_var_sequence(self, tx, fn) -> None: handle_observed_exception(tx) break - def _setup_exception(self, tx, exc): + def _setup_exception( + self, tx: "InstructionTranslator", exc: VariableTracker + ) -> None: tracer = self._get_inline_tracer(tx) try: tracer._raise_exception_variable(exc) @@ -760,19 +812,19 @@ def _setup_exception(self, tx, exc): # exception is raised again. tracer.exception_handler(e) - def _is_generator_just_started(self): + def _is_generator_just_started(self) -> bool: return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0 - def _is_generator_exhausted(self): + def _is_generator_exhausted(self) -> bool: return getattr(self.inline_tracer, "generator_exhausted", False) def call_method( self, tx: "InstructionTranslator", name: str, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__next__": return self.next_variable(tx) elif name == "__iter__": @@ -952,7 +1004,7 @@ def call_method( raise_observed_exception(RuntimeError, tracer) return retval - super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) class ContextlibContextManagerLocalGeneratorObjectVariable( @@ -980,9 +1032,9 @@ def __init__( self, vt: VariableTracker, *, - generator_cls=LocalGeneratorObjectVariable, - **kwargs, - ): + generator_cls: type = LocalGeneratorObjectVariable, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.vt = vt self.generator_cls = generator_cls @@ -992,7 +1044,12 @@ def __getattr__(self, name): return getattr(self, name) return getattr(self.vt, name) - def _build_inline_tracer(self, tx, args, kwargs): + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InstructionTranslatorBase": from torch._dynamo.symbolic_convert import InliningInstructionTranslator return InliningInstructionTranslator.build_inline_tracer( @@ -1005,13 +1062,13 @@ def _build_inline_tracer(self, tx, args, kwargs): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - if not is_generator(self.vt.get_code()): + args: Sequence[VariableTracker], + kwargs: dict[str, 
VariableTracker], + ) -> VariableTracker: + if not is_generator(self.vt.get_code()): # type: ignore[attr-defined] unimplemented_v2( gb_type="non-generator contextlib.contextmanager", - context=str(self.vt.get_code()), + context=str(self.vt.get_code()), # type: ignore[attr-defined] explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" ", i.e. does not use `yield`", hints=[ @@ -1020,15 +1077,15 @@ def call_function( ], ) - inline_tracer = self._build_inline_tracer(tx, args, kwargs) - code = self.vt.get_code() - f_globals = self.vt.get_globals() + inline_tracer = self._build_inline_tracer(tx, list(args), kwargs) + code = self.vt.get_code() # type: ignore[attr-defined] + f_globals = self.vt.get_globals() # type: ignore[attr-defined] # calling a generator returns a generator object return self.generator_cls( code, f_globals, - inline_tracer, + inline_tracer, # type: ignore[arg-type] source=self.source, ) @@ -1042,14 +1099,19 @@ class FunctionDecoratedByContextlibContextManagerVariable( This is only used when the function is annotated with @contextlib.contextmanager """ - def __init__(self, vt, **kwargs): + def __init__(self, vt: VariableTracker, **kwargs: Any): super().__init__( vt, generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable, **kwargs, ) - def _build_inline_tracer(self, tx, args, kwargs): + def _build_inline_tracer( + self, + tx: "InstructionTranslatorBase", + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "InstructionTranslatorBase": # NOTE: This only exists to not break support for context manager when # config.enable_faithful_generator_behavior = False and # config.enable_trace_contextlib = True. In case the former is false, @@ -1066,8 +1128,14 @@ def _build_inline_tracer(self, tx, args, kwargs): class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" - def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: - super().__init__(fn=fn, **kwargs) + def __init__( + self, + fn: Callable[..., Any], + obj: VariableTracker, + source_fn: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__(fn=fn, **kwargs) # type: ignore[arg-type] self.obj = obj self.source_fn = source_fn # Note on source and source_fn @@ -1083,24 +1151,24 @@ def __init__(self, fn, obj, source_fn=None, **kwargs) -> None: # operates on the unbound function, most guards should target # `source_fn` rather than the original `source`. if source_fn is None and kwargs.get("source") is not None: - self.source_fn = AttrSource(kwargs.get("source"), "__func__") + self.source_fn = AttrSource(kwargs.get("source"), "__func__") # type: ignore[assignment, arg-type] def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [self.obj] - def python_type(self): + def python_type(self) -> type[types.MethodType]: return types.MethodType def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - # NOTE this is to handle methods annotated by `nonstrict_trace`. Usually + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # NOTE this is to handle methods annotated by `nonstrict_trace`. 
# a `nonstrict_trace`-ed function will be wrapped by # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`, # but in the case of method, we manually wrap it with `UserMethodVariable` @@ -1141,36 +1209,41 @@ def call_function( or self.is_constant ): return self.obj.call_method( - tx, self.fn.__name__, args, kwargs, constant=self.is_constant + tx, self.fn.__name__, list(args), kwargs, constant=self.is_constant ) elif ( _fsdp_param_group is not None - and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state + and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state # type: ignore[attr-defined] ): return variables.TorchCtxManagerClassVariable(self.fn).call_function( tx, (self.obj, *args), kwargs ) if self.is_constant: - fn = getattr(self.obj.value, self.fn.__name__) + fn = getattr(self.obj.value, self.fn.__name__) # type: ignore[attr-defined] return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs) return super().call_function(tx, args, kwargs) - def inspect_parameter_names(self): + def inspect_parameter_names(self) -> list[str]: return super().inspect_parameter_names()[1:] - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__self__": return self.obj if name == "__func__": # We might have a better way to access the function object, this # information is stored in self.source_fn, use that to construct the # variable tracker. - return VariableTracker.build(tx, self.fn, self.source_fn) + return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type] return super().var_getattr(tx, name) class WrappedUserMethodVariable(UserMethodVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: UserMethodVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("fn", None) kwargs.pop("obj", None) super().__init__(wrapped.fn, wrapped.obj, **kwargs) @@ -1180,22 +1253,27 @@ def __init__(self, wrapped, context, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) class WrappedUserFunctionVariable(UserFunctionVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: UserFunctionVariable, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("fn", None) super().__init__(wrapped.fn, **kwargs) self.wrapped = wrapped @@ -1204,22 +1282,28 @@ def __init__(self, wrapped, context, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, 
codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) -def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs): - def convert(x): +def invoke_and_store_as_constant( + tx: "InstructionTranslator", + fn: Callable[..., Any], + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], +) -> VariableTracker: + def convert(x: VariableTracker) -> Any: if isinstance(x, variables.TensorVariable): return x.get_real_value() return x.as_python_constant() @@ -1242,17 +1326,17 @@ class NestedUserFunctionVariable(BaseUserFunctionVariable): def __init__( self, - fn_name, - code, - f_globals, - defaults, - kwdefaults, - annotations, - closure, + fn_name: VariableTracker, + code: VariableTracker, + f_globals: dict[str, Any], + defaults: Optional[VariableTracker], + kwdefaults: Optional[VariableTracker], + annotations: Optional[VariableTracker], + closure: Optional[VariableTracker], # This is present when this function is created by # `functools.wrap(wrapped_fn)(this_fn)`. - wrapped_fn=None, - **kwargs, + wrapped_fn: Optional[VariableTracker] = None, + **kwargs: Any, ) -> None: if kwargs.get("mutation_type") is None: kwargs.update(mutation_type=AttributeMutationNew()) @@ -1269,16 +1353,16 @@ def __init__( self.closure = closure self.wrapped_fn: Optional[VariableTracker] = wrapped_fn - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [] - def get_code(self): + def get_code(self) -> types.CodeType: return self.code.as_python_constant() - def python_type(self): + def python_type(self) -> type: return types.FunctionType - def get_function(self): + def get_function(self) -> types.FunctionType: if self.closure: raise NotImplementedError func = types.FunctionType( @@ -1307,19 +1391,25 @@ def call_setattr( tx: "InstructionTranslator", name_var: VariableTracker, val: VariableTracker, - ): - tx.output.side_effects.store_attr(self, name_var.value, val) + ) -> VariableTracker: + tx.output.side_effects.store_attr(self, name_var.value, val) # type: ignore[attr-defined] return ConstantVariable(None) - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__setattr__": return self.call_setattr(tx, *args) - return super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, list(args), kwargs) - def has_closure(self): + def has_closure(self) -> bool: return self.closure is not None - def const_getattr(self, tx, name): + def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: if name == "__name__": return self.get_name() if name == "__code__": @@ -1329,50 +1419,57 @@ def const_getattr(self, tx, name): return d.as_python_constant() if d else None return super().const_getattr(tx, name) - def call_obj_hasattr(self, tx: "InstructionTranslator", name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: if name == "__code__": return variables.ConstantVariable.create(hasattr(self, "code")) if name == "__defaults__": return variables.ConstantVariable.create(hasattr(self, "defaults")) return super().call_obj_hasattr(tx, name) - def has_self(self): + def has_self(self) -> bool: return 
False - def get_globals(self): + def get_globals(self) -> dict[str, Any]: return self.f_globals - def bind_args(self, parent, args, kwargs): + def bind_args( + self, + parent: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> dict[str, VariableTracker]: code = self.get_code() func = types.FunctionType( code, self.f_globals, self.fn_name.as_python_constant(), - tuple(self.defaults.items) if self.defaults else None, + tuple(self.defaults.items) if self.defaults else None, # type: ignore[attr-defined] tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))), ) if self.kwdefaults: - func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() + func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant() # type: ignore[attr-defined] bound = inspect.signature(func).bind(*args, **kwargs) bound.apply_defaults() result = dict(bound.arguments.items()) - wrap_args_kwargs(parent.output.root_tx, result) + wrap_args_kwargs(parent.output.root_tx, result) # type: ignore[arg-type] init_cellvars(parent, result, code) for idx, name in enumerate(code.co_freevars): assert name not in result - cell = self.closure.items[idx] + cell = self.closure.items[idx] # type: ignore[attr-defined, union-attr] result[name] = cell return result - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from(__name__, "_create_nested_fn") ) codegen(self.code) codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)]) - codegen(ConstantVariable.create(self.code.value.co_name)) + codegen(ConstantVariable.create(self.code.value.co_name)) # type: ignore[attr-defined] if self.defaults: codegen(self.defaults) @@ -1426,7 +1523,12 @@ def reconstruct(self, codegen: "PyCodegen"): class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: Any, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("fn_name", None) kwargs.pop("code", None) kwargs.pop("f_globals", None) @@ -1451,16 +1553,16 @@ def __init__(self, wrapped, context, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) @@ -1472,16 +1574,16 @@ class SkipFunctionVariable(VariableTracker): *VariableTracker._nonvar_fields, } - def __init__(self, value, reason=None, **kwargs) -> None: + def __init__(self, value: Any, reason: Optional[str] = None, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value self.reason = reason - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value @classmethod - def create_with_source(cls, value, source): + def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": # Use closure match guard (i.e. 
guard on __code__ object instead of # function id) to avoid guarding on nested functions. if inspect.getattr_static(value, "_torchdynamo_disable", False): @@ -1510,9 +1612,9 @@ def create_with_source(cls, value, source): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if inspect.getattr_static(self.value, "_torchdynamo_disable", False): msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) unimplemented_v2( @@ -1525,7 +1627,7 @@ def call_function( ], ) elif self.value is torch._dynamo.graph_break: - graph_break_msg = kwargs.get("msg", None) + graph_break_msg = kwargs.get("msg") if graph_break_msg: graph_break_msg = graph_break_msg.as_python_constant() unimplemented_v2( @@ -1537,7 +1639,7 @@ def call_function( ], ) elif self.value is torch._dynamo.skip_frame: - skip_frame_msg = kwargs.get("msg", None) + skip_frame_msg = kwargs.get("msg") if skip_frame_msg: skip_frame_msg = skip_frame_msg.as_python_constant() raise SkipFrame( @@ -1629,10 +1731,12 @@ def call_function( hints=hints, ) - def call_obj_hasattr(self, tx: "InstructionTranslator", name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: return variables.ConstantVariable.create(hasattr(self.value, name)) - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: return variables.GetAttrVariable(self, name) @@ -1640,26 +1744,31 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): class WrappedSkipFunctionVariable(SkipFunctionVariable): - def __init__(self, wrapped, context, **kwargs) -> None: + def __init__( + self, + wrapped: VariableTracker, + context: "ContextWrappingVariable", + **kwargs: Any, + ) -> None: kwargs.pop("value", None) kwargs.pop("reason", None) - super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) # type: ignore[attr-defined] self.wrapped = wrapped self.context = context def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) return result - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen(self.context)) + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null(lambda: codegen(self.context)) # type: ignore[arg-type] codegen(self.wrapped) codegen.extend_output(create_call_function(1, False)) @@ -1672,12 +1781,12 @@ class WrapperUserFunctionVariable(VariableTracker): __script_if_tracing_wrapper have the original attr at "__original_fn". 
""" - def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None: + def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.wrapper_obj = wrapper_obj self.attr_to_trace = attr_to_trace - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == self.attr_to_trace: val = getattr(self.wrapper_obj, self.attr_to_trace) source = self.source and AttrSource(self.source, name) @@ -1685,15 +1794,15 @@ def var_getattr(self, tx: "InstructionTranslator", name): return super().var_getattr(tx, name) - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [] def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if hasattr(self.wrapper_obj, "cache_info"): target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None) module_name = getattr(target_fn, "__module__", "") or "" @@ -1719,9 +1828,9 @@ def call_function( user_stack_trace += str(user_stack_formatted) dynamo_logger.debug(user_stack_trace) - all_args = self.self_args() + args + all_args = self.self_args() + list(args) return variables.UserFunctionVariable( - polyfills.getattr_and_trace + polyfills.getattr_and_trace # type: ignore[arg-type] ).call_function( tx, [self, variables.ConstantVariable(self.attr_to_trace), *all_args], @@ -1736,15 +1845,21 @@ class WrapperUserMethodVariable(WrapperUserFunctionVariable): WrapperUserFunctionVariable in `call_function` method. """ - def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None: + def __init__( + self, + wrapper_obj: Any, + attr_to_trace: str, + self_obj: VariableTracker, + **kwargs: Any, + ) -> None: super().__init__(wrapper_obj, attr_to_trace, **kwargs) self.obj = self_obj - def self_args(self): + def self_args(self) -> list[VariableTracker]: return [self.obj] -def _traceable_collective_remaps(): +def _traceable_collective_remaps() -> dict[Any, Any]: # We can't rely on importing from distributed, since it's not always built if torch.distributed.is_available(): from torch.distributed._functional_collectives import ( @@ -1755,7 +1870,9 @@ def _traceable_collective_remaps(): return {} -def _traceable_collectives_source(tx: "InstructionTranslator", fn): +def _traceable_collectives_source( + tx: "InstructionTranslator", fn: Callable[..., Any] +) -> AttrSource: assert torch.distributed.is_available(), "Illegal invocation." assert fn in _traceable_collective_remaps().values() @@ -1775,13 +1892,24 @@ class CollectiveFunctionRewriteVariable(UserFunctionVariable): than status-quo as we currently graph-break on all distributed.* collectives. 
""" - def __init__(self, fn, *, replacement_var, **kwargs) -> None: - super().__init__(fn, **kwargs) + def __init__( + self, + fn: Callable[..., Any], + *, + replacement_var: UserFunctionVariable, + **kwargs: Any, + ) -> None: + super().__init__(fn, **kwargs) # type: ignore[arg-type] assert isinstance(replacement_var, UserFunctionVariable) self.replacement_var = replacement_var @staticmethod - def create(tx: "InstructionTranslator", old_fn, source, **options): + def create( + tx: "InstructionTranslator", + old_fn: Callable[..., Any], + source: Source, + **options: Any, + ) -> "CollectiveFunctionRewriteVariable": new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn) return CollectiveFunctionRewriteVariable( old_fn, @@ -1791,22 +1919,24 @@ def create(tx: "InstructionTranslator", old_fn, source, **options): ) @staticmethod - def can_rewrite(variable): + def can_rewrite(variable: Any) -> bool: return ( inspect.isfunction(variable) and variable in _traceable_collective_remaps() ) @staticmethod - def rewrite(tx: "InstructionTranslator", fn): + def rewrite( + tx: "InstructionTranslator", fn: Callable[..., Any] + ) -> tuple[Any, AttrSource]: new_fn = _traceable_collective_remaps()[fn] return new_fn, _traceable_collectives_source(tx, new_fn) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # call_function must check any unsupported arguments and graph-break. # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn, # since that's the contract for putting a mapping in `traceable_collective_remaps` @@ -1836,7 +1966,7 @@ def call_function( ): reduce_op_var = kwargs.get("op") reduce_op = ( - reduce_op_var.value + reduce_op_var.value # type: ignore[attr-defined] if reduce_op_var is not None else signature.parameters["op"].default ) @@ -1852,12 +1982,12 @@ class FunctoolsWrapsVariable(UserFunctionVariable): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if not kwargs and len(args) == 1: - def wraps(fn): + def wraps(fn: Any) -> VariableTracker: if isinstance(fn, variables.NestedUserFunctionVariable): return fn.clone(wrapped_fn=args[0]) unimplemented_v2( @@ -1875,15 +2005,15 @@ def wraps(fn): class CollectionsNamedTupleFunction(UserFunctionVariable): - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.fn def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: constant_args = check_constant_args(args, kwargs) if constant_args: try: @@ -1898,7 +2028,9 @@ def call_function( args=list(map(ConstantVariable.create, exc.args)), ) return variables.UserDefinedClassVariable( - value, mutation_type=ValueMutationNew() + # pyrefly: ignore[unbound-name] + value, + mutation_type=ValueMutationNew(), ) unimplemented_v2( gb_type="namedtuple construction", @@ -1911,7 +2043,13 @@ def call_function( class FunctoolsPartialVariable(VariableTracker): - def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: + def __init__( + self, + func: 
VariableTracker, + args: Sequence[VariableTracker], + keywords: dict[str, VariableTracker], + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.func = func assert isinstance(args, list) @@ -1922,10 +2060,10 @@ def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: # on it is sufficient for the tracing purposes. self.fake_value = functools.partial(identity) - def python_type(self): + def python_type(self) -> type: return functools.partial - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: @@ -1940,16 +2078,16 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False) ) - def get_function(self): + def get_function(self) -> Any: return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - merged_args = self.args + args + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + merged_args = self.args + list(args) merged_kwargs = {**self.keywords, **kwargs} return self.func.call_function(tx, merged_args, merged_kwargs) @@ -1961,7 +2099,7 @@ def call_obj_hasattr( hasattr(functools.partial(identity), name) ) - def var_getattr(self, tx: "InstructionTranslator", name: str): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: source = self.source and AttrSource(self.source, name) # Handle __slots__ if name == "func": @@ -1975,14 +2113,14 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): return variables.GetAttrVariable(self, name) raise_observed_exception(AttributeError, tx) - def as_python_constant(self): + def as_python_constant(self) -> Any: return functools.partial( self.func.as_python_constant(), *[arg.as_python_constant() for arg in self.args], **{k: v.as_python_constant() for k, v in self.keywords.items()}, ) - def guard_as_python_constant(self): + def guard_as_python_constant(self) -> Any: """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" return functools.partial( self.func.guard_as_python_constant(), @@ -2005,16 +2143,20 @@ def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]: return {} @classmethod - def create_with_source(cls, value, source): + def create_with_source( + cls, value: Any, source: Source + ) -> "PolyfilledFunctionVariable": install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH)) return cls(value, source=source) - def __init__(self, fn: _F, **kwargs) -> None: + def __init__(self, fn: _F, **kwargs: Any) -> None: super().__init__(**kwargs) + # pyrefly: ignore[invalid-type-var] self.fn: _F = fn handler = self._get_polyfill_handlers().get(fn, fn) + traceable_fn = None assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}" for candidate_attr in ( "__torch_dynamo_polyfill__", # registered polyfill @@ -2029,28 +2171,29 @@ def __init__(self, fn: _F, **kwargs) -> None: raise RuntimeError( f"Polyfill handler {handler} does not have a traceable function" ) - - self.wrapped_fn: _F = handler + # pyrefly: ignore[invalid-type-var] + self.wrapped_fn = handler + # pyrefly: ignore[invalid-type-var] self.traceable_fn: _F = traceable_fn @property - def polyfill_fn(self) -> _F: + def polyfill_fn(self) -> 
Callable[..., Any]: return self.traceable_fn - def can_constant_fold_through(self): + def can_constant_fold_through(self) -> bool: return getattr( self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False ) - def get_function(self): + def get_function(self) -> Any: return self.as_python_constant() def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -2087,7 +2230,7 @@ def call_function( ( x.value if isinstance(x, variables.ConstantVariable) - else x.sym_num + else x.sym_num # type: ignore[attr-defined] ) for x in args[0].items ] @@ -2099,11 +2242,11 @@ def call_function( def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__call__": return self.call_function(tx, args, kwargs) @@ -2113,27 +2256,33 @@ def call_method( options = {} if self.source: options["source"] = AttrSource(self.source, name) + # pyrefly: ignore[bad-specialization] polyfilled_method_variable = PolyfilledFunctionVariable(method, **options) return polyfilled_method_variable.call_function(tx, args, kwargs) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.fn class TracebackVariable(VariableTracker): # We don't track traceback. A call to any function in this module is a no-op - def call_function(self, tx, args, kwargs): ... + def call_function( # type: ignore[empty-body] + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: ... 
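# --- Illustrative check (plain functools, hypothetical example; not Dynamo code) ---
# The merge in FunctoolsPartialVariable.call_function above
# (`merged_args = self.args + list(args)`, `merged_kwargs = {**self.keywords, **kwargs}`)
# mirrors what functools.partial itself does: call-site positionals are appended
# after the pre-bound ones, and call-site keywords override the pre-bound keywords.
import functools
from typing import Any

def report(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]:
    return args, kwargs

p = functools.partial(report, 1, 2, mode="fast", device="cpu")
got_args, got_kwargs = p(3, device="cuda")
assert got_args == (1, 2, 3)
assert got_kwargs == {"mode": "fast", "device": "cuda"}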
class SysFunctionVariable(VariableTracker): - def __init__(self, value, **kwargs): + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value - def exc_info(self, tx): + def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": if len(tx.exn_vt_stack): exn = tx.exn_vt_stack[-1] - typ = exn.exc_type + typ = exn.exc_type # type: ignore[union-attr] tb = None items = [ VariableTracker.build(tx, typ), @@ -2146,12 +2295,17 @@ def exc_info(self, tx): variables.ConstantVariable(None), variables.ConstantVariable(None), ] - return variables.TupleVariable(items) + return variables.TupleVariable(items) # type: ignore[arg-type] - def exception(self, tx): + def exception(self, tx: "InstructionTranslator") -> VariableTracker: return self.exc_info(tx).items[1] - def call_function(self, tx, args, kwargs): + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.value is sys.exc_info: return self.exc_info(tx) assert self.value is sys.exception @@ -2170,15 +2324,15 @@ class DynamoTritonHOPifier(TritonHOPifier): def raise_unsupported(self, msg: str) -> Never: raise Unsupported(msg) - def is_callable(self, maybe_callable: Any) -> bool: + def is_callable(self, maybe_callable: VariableTracker) -> bool: return isinstance( maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) ) - def get_value(self, val: Any) -> Any: - return val.value + def get_value(self, val: VariableTracker) -> Any: + return val.value # type: ignore[attr-defined] - def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]: + def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, ...]: from .lists import BaseListVariable if isinstance(grid, BaseListVariable): @@ -2193,20 +2347,35 @@ def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]: ], ) - def call_grid(self, grid, meta, tx): - meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()} - grid = grid.call_function(tx, [meta], {}) + def call_grid( + self, grid: Any, meta: dict[str, Any], tx: "InstructionTranslator" + ) -> Any: + meta_var = {variables.ConstantVariable.create(k): v for k, v in meta.items()} + grid = grid.call_function(tx, [meta_var], {}) return grid # We use this function to wrap call_prune_configs - def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable): + def call_user_defined_fn( + self, + user_fn: Callable[..., Any], + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + tx: Optional["InstructionTranslator"], + variable: Any, + ) -> VariableTracker: from .builder import SourcelessBuilder - wrapped_user_function = SourcelessBuilder.create(tx, user_fn) + wrapped_user_function = SourcelessBuilder.create(tx, user_fn) # type: ignore[arg-type] result = wrapped_user_function.call_function(tx, args, kwargs) return result - def wrap_user_defined_obj(self, user_obj, tx, variable, name): + def wrap_user_defined_obj( + self, + user_obj: Any, + tx: Optional["InstructionTranslator"], + variable: Any, + name: str, + ) -> VariableTracker: from .builder import VariableBuilder wrapped_user_obj = VariableBuilder( @@ -2214,7 +2383,9 @@ def wrap_user_defined_obj(self, user_obj, tx, variable, name): )._wrap(user_obj) return wrapped_user_obj - def maybe_unpack_configs(self, configs, tx): + def maybe_unpack_configs( + self, configs: Any, tx: Optional["InstructionTranslator"] + ) -> list[Any]: # unpack the list of configs configs = 
configs.unpack_var_sequence(tx) @@ -2223,7 +2394,7 @@ def maybe_unpack_configs(self, configs, tx): return configs - def maybe_unpack_heuristic_result(self, result: Any) -> Any: + def maybe_unpack_heuristic_result(self, result: VariableTracker) -> Any: if not result.is_python_constant(): self.raise_unsupported( "@triton.heuristics must return constant values because configs can only contain constant values." @@ -2233,7 +2404,7 @@ def maybe_unpack_heuristic_result(self, result: Any) -> Any: # We need to override call_getitem here so that we can add the source in the case # where we call the triton kernel with a grid - def call_getitem( + def call_getitem( # type: ignore[override] self, variable: "TritonKernelVariable", args: Sequence[Any], @@ -2251,7 +2422,13 @@ def call_getitem( kernel_source=variable.source, ) - def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable: + def call_HOP( + self, + variable: "TritonKernelVariable", + grids: Any, + combined_args_raw: dict[str, Any], + tx: "InstructionTranslator", + ) -> "variables.ConstantVariable": from .constant import ConstantVariable from .dicts import ConstDictVariable @@ -2330,7 +2507,9 @@ class TritonKernelVariable(VariableTracker): kernel_idx: Optional[int] kernel_source: "AttrSource" - def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: + def __init__( + self, kernel: Any, kernel_idx: Optional[int], grid: Any, **kwargs: Any + ) -> None: self.kernel_source = kwargs.pop("kernel_source", None) super().__init__(**kwargs) dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) @@ -2338,24 +2517,24 @@ def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": - return dynamo_triton_hopifier_singleton.call_triton_kernel( + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return dynamo_triton_hopifier_singleton.call_triton_kernel( # type: ignore[return-value] self, args, kwargs, tx ) def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": return dynamo_triton_hopifier_singleton.call_getitem(self, args) elif name == "run": - return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) + return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value] # Bail out to parent's implementation return super().call_method(tx, name, args, kwargs) @@ -2374,11 +2553,11 @@ class TMADescriptorExperimentalVariable(VariableTracker): def __init__( self, data_ptr: "variables.DataPtrVariable", - dims: "list[ConstantVariable]", - block_dims: "list[ConstantVariable]", - element_size: "ConstantVariable", - **kwargs, - ): + dims: list[VariableTracker], + block_dims: list[VariableTracker], + element_size: VariableTracker, + **kwargs: Any, + ) -> None: assert isinstance(data_ptr, variables.DataPtrVariable) super().__init__(**kwargs) self.data_ptr = data_ptr @@ -2386,14 +2565,14 @@ def __init__( self.block_dims = block_dims self.element_size = element_size - def to_metadata(self): + def to_metadata(self) -> Any: return create_tma_experimental_metadata( [dim.as_proxy() for dim in self.dims], [dim.as_proxy() for dim in 
self.block_dims], self.element_size.as_proxy(), ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.experimental_descriptor", @@ -2405,28 +2584,28 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.foreach(args) codegen.call_function(len(args) + 1, False) - def get_tensor(self): + def get_tensor(self) -> VariableTracker: return self.data_ptr.from_tensor class TMADescriptorStableVariable(VariableTracker): def __init__( self, - tensor: "variables.TensorVariable", - block_shape: "variables.ListVariable", - **kwargs, - ): + tensor: "TensorVariable", + block_shape: "ListVariable", + **kwargs: Any, + ) -> None: assert isinstance(tensor, variables.TensorVariable) super().__init__(**kwargs) self.tensor = tensor self.block_shape = block_shape - def to_metadata(self): + def to_metadata(self) -> Any: return create_tma_stable_metadata( self.block_shape.as_proxy(), ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.tensor_descriptor", @@ -2438,7 +2617,7 @@ def reconstruct(self, codegen: "PyCodegen"): codegen(self.block_shape) codegen.call_method(2) - def get_tensor(self) -> "variables.TensorVariable": + def get_tensor(self) -> Any: return self.tensor @@ -2446,7 +2625,7 @@ class CreateTMADescriptorExperimentalVariable(VariableTracker): def __init__( self, rank: int, - **kwargs, + **kwargs: Any, ) -> None: assert rank in (1, 2) super().__init__(**kwargs) @@ -2455,9 +2634,9 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] if not isinstance(ptr, variables.DataPtrVariable): @@ -2507,13 +2686,13 @@ class CreateTMADescriptorStableVariable(VariableTracker): def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: tensor = kwargs["tensor"] if "tensor" in kwargs else args[0] block_shape = kwargs["block_shape"] if "block_shape" in kwargs else args[1] return TMADescriptorStableVariable( - tensor=tensor, - block_shape=block_shape, + tensor=tensor, # type: ignore[arg-type] + block_shape=block_shape, # type: ignore[arg-type] ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index ecad58920d7c2..624844382d53a 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -590,7 +590,7 @@ def _next() -> VariableTracker: else: res = self.fn.call_function(tx, [item], {}) pred_res = variables.UserFunctionVariable( - polyfills.predicate + polyfills.predicate # type: ignore[arg-type] ).call_function(tx, [res], {}) if pred_res.as_python_constant(): return item diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index e4731697868e5..d6c005bccbda3 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -1498,6 +1498,7 @@ def check_and_create_method() -> Optional[VariableTracker]: variables.UserDefinedClassVariable(self.tuple_cls), ) elif isinstance(method, staticmethod): + # pyrefly: ignore[bad-argument-type] 
return UserFunctionVariable(method.__func__) elif inspect.isfunction(method): return UserMethodVariable(method, self) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e8e246be968eb..a687d77c186db 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -472,7 +472,12 @@ def call_function( ) elif self.value is torch.nn.attention.sdpa_kernel.__wrapped__: # type: ignore[attr-defined] name_to_arg_map = bind_args_cached( - self.value, tx, self.source, args, kwargs + # pyrefly: ignore[bad-argument-type] + self.value, + tx, + self.source, + args, + kwargs, ) backends = name_to_arg_map["backends"].as_python_constant() set_priority = name_to_arg_map["set_priority"].as_python_constant() @@ -1429,7 +1434,7 @@ def call_function( packed_input_vt = TupleVariable.build( tx, (TupleVariable.build(tx, args), ConstDictVariable.build(tx, kwargs)) ) - out_vt = variables.UserFunctionVariable(tree_flatten).call_function( + out_vt = variables.UserFunctionVariable(tree_flatten).call_function( # type: ignore[arg-type] tx, [packed_input_vt], {} ) assert isinstance(out_vt, TupleVariable) and len(out_vt.items) == 2 From 4e277e63231799257902a217a455aa8c97971ce5 Mon Sep 17 00:00:00 2001 From: Colin L Reliability Rice Date: Fri, 7 Nov 2025 20:49:31 +0000 Subject: [PATCH 225/651] inductor: compile_worker - Fix potential race condition with quiesce waitcounters (#167025) Summary: If quiesce ends up called twice (which is likely not possible with the timer based implementation, but possible with either manual calls, or with the context manager implementation), this assertion fires. Instead make this assertion tolerant to rentrant calling of quiesce Test Plan: Added a explicit test which calls quiesce twice. Differential Revision: D86251534 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167025 Approved by: https://github.com/masnesral --- test/inductor/test_compile_worker.py | 17 +++++++++++++++++ torch/_inductor/compile_worker/subproc_pool.py | 10 +++++----- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 7237d5a01c6b2..270b15fdf49d8 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -73,6 +73,23 @@ def test_quiesce(self): finally: pool.shutdown() + @skipIfWindows(msg="pass_fds not supported on Windows.") + def test_quiesce_repeatedly(self): + pool = SubprocPool(2) + try: + a = pool.submit(operator.add, 100, 1) + pool.quiesce() + pool.wakeup() + b = pool.submit(operator.sub, 100, 1) + pool.quiesce() + pool.quiesce() + pool.wakeup() + b = pool.submit(operator.sub, 100, 1) + self.assertEqual(a.result(), 101) + self.assertEqual(b.result(), 99) + finally: + pool.shutdown() + @skipIfWindows(msg="pass_fds not supported on Windows.") def test_logging(self): os.environ["MAST_HPC_JOB_NAME"] = "test_job" diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index a4114644026ca..2bc87e6f3eb95 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -319,11 +319,11 @@ def _read_thread(self) -> None: def quiesce(self) -> None: self._send(MsgHeader.QUIESCE) - assert self.quiesce_waitcounter is None - self.quiesce_waitcounter = _WaitCounter( - "pytorch.wait_counter.subproc_pool.running" - ).guard() - self.quiesce_waitcounter.__enter__() + if self.quiesce_waitcounter is None: + self.quiesce_waitcounter = 
_WaitCounter( + "pytorch.wait_counter.subproc_pool.running" + ).guard() + self.quiesce_waitcounter.__enter__() def wakeup(self) -> None: self._send(MsgHeader.WAKEUP) From fb9e10fe255732c7ba0590099a85326d5ad51d8e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 20:53:14 +0000 Subject: [PATCH 226/651] Revert "Update pythoncapi_compat.h (#167138)" This reverts commit c90a976370945af052bb7b0db86240fa6f321cd6. Reverted https://github.com/pytorch/pytorch/pull/167138 on behalf of https://github.com/donigian due to Sorry but this is breaking internally. See diff D86458778 for details. ([comment](https://github.com/pytorch/pytorch/pull/167138#issuecomment-3504895388)) --- torch/csrc/utils/python_compat.h | 8 + torch/csrc/utils/pythoncapi_compat.h | 1420 +------------------------- 2 files changed, 20 insertions(+), 1408 deletions(-) diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index 8488d5d0917b5..16308dad4421d 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -33,6 +33,14 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) { #endif } +// Provided by CPython but getting the header for them is very hard +#if IS_PYTHON_3_11_PLUS +// NOLINTNEXTLINE(readability-redundant-declaration) +PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self); +#else +extern void _PyWeakref_ClearRef(PyWeakReference* self); +#endif + #ifdef __cplusplus } #endif diff --git a/torch/csrc/utils/pythoncapi_compat.h b/torch/csrc/utils/pythoncapi_compat.h index bb45c18531106..05e80b5ee8607 100644 --- a/torch/csrc/utils/pythoncapi_compat.h +++ b/torch/csrc/utils/pythoncapi_compat.h @@ -7,7 +7,7 @@ // https://github.com/python/pythoncapi_compat // // Latest version: -// https://raw.githubusercontent.com/python/pythoncapi-compat/main/pythoncapi_compat.h +// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h // // SPDX-License-Identifier: 0BSD @@ -19,15 +19,11 @@ extern "C" { #endif #include -#include // offsetof() // Python 3.11.0b4 added PyFrame_Back() to Python.h #if PY_VERSION_HEX < 0x030b00B4 && !defined(PYPY_VERSION) # include "frameobject.h" // PyFrameObject, PyFrame_GetBack() #endif -#if PY_VERSION_HEX < 0x030C00A3 -# include // T_SHORT, READONLY -#endif #ifndef _Py_CAST @@ -37,13 +33,11 @@ extern "C" { // Static inline functions should use _Py_NULL rather than using directly NULL // to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, // _Py_NULL is defined as nullptr. -#ifndef _Py_NULL -# if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ - || (defined(__cplusplus) && __cplusplus >= 201103) -# define _Py_NULL nullptr -# else -# define _Py_NULL NULL -# endif +#if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ + || (defined(__cplusplus) && __cplusplus >= 201103) +# define _Py_NULL nullptr +#else +# define _Py_NULL NULL #endif // Cast argument to PyObject* type. 
@@ -51,13 +45,6 @@ extern "C" { # define _PyObject_CAST(op) _Py_CAST(PyObject*, op) #endif -#ifndef Py_BUILD_ASSERT -# define Py_BUILD_ASSERT(cond) \ - do { \ - (void)sizeof(char [1 - 2 * !(cond)]); \ - } while(0) -#endif - // bpo-42262 added Py_NewRef() to Python 3.10.0a3 #if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) @@ -81,16 +68,6 @@ static inline PyObject* _Py_XNewRef(PyObject *obj) #endif -// bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4 -#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT) -static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) -{ - ob->ob_refcnt = refcnt; -} -#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt) -#endif - - // Py_SETREF() and Py_XSETREF() were added to Python 3.5.2. // It is excluded from the limited C API. #if (PY_VERSION_HEX < 0x03050200 && !defined(Py_SETREF)) && !defined(Py_LIMITED_API) @@ -127,37 +104,6 @@ static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) # define Py_IsFalse(x) Py_Is(x, Py_False) #endif - -// bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4 -#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) -static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type) -{ - ob->ob_type = type; -} -#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type) -#endif - - -// bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4 -#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE) -static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) -{ - ob->ob_size = size; -} -#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size) -#endif - - -// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 -#if PY_VERSION_HEX < 0x030900B1 || defined(PYPY_VERSION) -static inline PyCodeObject* PyFrame_GetCode(PyFrameObject *frame) -{ - assert(frame != _Py_NULL); - assert(frame->f_code != _Py_NULL); - return _Py_CAST(PyCodeObject*, Py_NewRef(frame->f_code)); -} -#endif - static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) { PyCodeObject *code = PyFrame_GetCode(frame); @@ -166,15 +112,6 @@ static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) } -// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1 -#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) -static inline PyFrameObject* PyFrame_GetBack(PyFrameObject *frame) -{ - assert(frame != _Py_NULL); - return _Py_CAST(PyFrameObject*, Py_XNewRef(frame->f_back)); -} -#endif - #if !defined(PYPY_VERSION) static inline PyFrameObject* _PyFrame_GetBackBorrow(PyFrameObject *frame) { @@ -292,26 +229,6 @@ PyFrame_GetVarString(PyFrameObject *frame, const char *name) #endif -// bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5 -#if PY_VERSION_HEX < 0x030900A5 || (defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030B0000) -static inline PyInterpreterState * -PyThreadState_GetInterpreter(PyThreadState *tstate) -{ - assert(tstate != _Py_NULL); - return tstate->interp; -} -#endif - - -// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 -#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) -static inline PyFrameObject* PyThreadState_GetFrame(PyThreadState *tstate) -{ - assert(tstate != _Py_NULL); - return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); -} -#endif - #if !defined(PYPY_VERSION) static inline PyFrameObject* _PyThreadState_GetFrameBorrow(PyThreadState *tstate) @@ -323,35 +240,6 @@ _PyThreadState_GetFrameBorrow(PyThreadState *tstate) #endif -// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5 -#if PY_VERSION_HEX < 0x030900A5 || 
defined(PYPY_VERSION) -static inline PyInterpreterState* PyInterpreterState_Get(void) -{ - PyThreadState *tstate; - PyInterpreterState *interp; - - tstate = PyThreadState_GET(); - if (tstate == _Py_NULL) { - Py_FatalError("GIL released (tstate is NULL)"); - } - interp = tstate->interp; - if (interp == _Py_NULL) { - Py_FatalError("no current interpreter"); - } - return interp; -} -#endif - - -// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a6 -#if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) -static inline uint64_t PyThreadState_GetID(PyThreadState *tstate) -{ - assert(tstate != _Py_NULL); - return tstate->id; -} -#endif - // bpo-43760 added PyThreadState_EnterTracing() to Python 3.11.0a2 #if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) static inline void PyThreadState_EnterTracing(PyThreadState *tstate) @@ -381,27 +269,6 @@ static inline void PyThreadState_LeaveTracing(PyThreadState *tstate) #endif -// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1 -// PyObject_CallNoArgs() added to PyPy 3.9.16-v7.3.11 -#if !defined(PyObject_CallNoArgs) && PY_VERSION_HEX < 0x030900A1 -static inline PyObject* PyObject_CallNoArgs(PyObject *func) -{ - return PyObject_CallFunctionObjArgs(func, NULL); -} -#endif - - -// bpo-39245 made PyObject_CallOneArg() public (previously called -// _PyObject_CallOneArg) in Python 3.9.0a4 -// PyObject_CallOneArg() added to PyPy 3.9.16-v7.3.11 -#if !defined(PyObject_CallOneArg) && PY_VERSION_HEX < 0x030900A4 -static inline PyObject* PyObject_CallOneArg(PyObject *func, PyObject *arg) -{ - return PyObject_CallFunctionObjArgs(func, arg, NULL); -} -#endif - - // bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3 #if PY_VERSION_HEX < 0x030A00A3 static inline int @@ -427,58 +294,6 @@ PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value) #endif -// bpo-40024 added PyModule_AddType() to Python 3.9.0a5 -#if PY_VERSION_HEX < 0x030900A5 -static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) -{ - const char *name, *dot; - - if (PyType_Ready(type) < 0) { - return -1; - } - - // inline _PyType_Name() - name = type->tp_name; - assert(name != _Py_NULL); - dot = strrchr(name, '.'); - if (dot != _Py_NULL) { - name = dot + 1; - } - - return PyModule_AddObjectRef(module, name, _PyObject_CAST(type)); -} -#endif - - -// bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6. -// bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2. -#if PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) -static inline int PyObject_GC_IsTracked(PyObject* obj) -{ - return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj)); -} -#endif - -// bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6. -// bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final. -#if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 && !defined(PYPY_VERSION) -static inline int PyObject_GC_IsFinalized(PyObject *obj) -{ - PyGC_Head *gc = _Py_CAST(PyGC_Head*, obj) - 1; - return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED(gc)); -} -#endif - - -// bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 -#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) -static inline int _Py_IS_TYPE(PyObject *ob, PyTypeObject *type) { - return Py_TYPE(ob) == type; -} -#define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST(ob), type) -#endif - - // bpo-46906 added PyFloat_Pack2() and PyFloat_Unpack2() to Python 3.11a7. // bpo-11734 added _PyFloat_Pack2() and _PyFloat_Unpack2() to Python 3.6.0b1. 
// Python 3.11a2 moved _PyFloat_Pack2() and _PyFloat_Unpack2() to the internal @@ -586,7 +401,7 @@ static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj) return 0; } *pobj = Py_NewRef(obj); - return 1; + return (*pobj != NULL); } #endif @@ -605,81 +420,6 @@ static inline Py_ssize_t PyVectorcall_NARGS(size_t n) #endif -// gh-105922 added PyObject_Vectorcall() to Python 3.9.0a4 -#if PY_VERSION_HEX < 0x030900A4 -static inline PyObject* -PyObject_Vectorcall(PyObject *callable, PyObject *const *args, - size_t nargsf, PyObject *kwnames) -{ -#if PY_VERSION_HEX >= 0x030800B1 && !defined(PYPY_VERSION) - // bpo-36974 added _PyObject_Vectorcall() to Python 3.8.0b1 - return _PyObject_Vectorcall(callable, args, nargsf, kwnames); -#else - PyObject *posargs = NULL, *kwargs = NULL; - PyObject *res; - Py_ssize_t nposargs, nkwargs, i; - - if (nargsf != 0 && args == NULL) { - PyErr_BadInternalCall(); - goto error; - } - if (kwnames != NULL && !PyTuple_Check(kwnames)) { - PyErr_BadInternalCall(); - goto error; - } - - nposargs = (Py_ssize_t)PyVectorcall_NARGS(nargsf); - if (kwnames) { - nkwargs = PyTuple_GET_SIZE(kwnames); - } - else { - nkwargs = 0; - } - - posargs = PyTuple_New(nposargs); - if (posargs == NULL) { - goto error; - } - if (nposargs) { - for (i=0; i < nposargs; i++) { - PyTuple_SET_ITEM(posargs, i, Py_NewRef(*args)); - args++; - } - } - - if (nkwargs) { - kwargs = PyDict_New(); - if (kwargs == NULL) { - goto error; - } - - for (i = 0; i < nkwargs; i++) { - PyObject *key = PyTuple_GET_ITEM(kwnames, i); - PyObject *value = *args; - args++; - if (PyDict_SetItem(kwargs, key, value) < 0) { - goto error; - } - } - } - else { - kwargs = NULL; - } - - res = PyObject_Call(callable, posargs, kwargs); - Py_DECREF(posargs); - Py_XDECREF(kwargs); - return res; - -error: - Py_DECREF(posargs); - Py_XDECREF(kwargs); - return NULL; -#endif -} -#endif - - // gh-106521 added PyObject_GetOptionalAttr() and // PyObject_GetOptionalAttrString() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 @@ -924,7 +664,7 @@ static inline int PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg) { PyObject **dict = _PyObject_GetDictPtr(obj); - if (dict == NULL || *dict == NULL) { + if (*dict == NULL) { return -1; } Py_VISIT(*dict); @@ -935,7 +675,7 @@ static inline void PyObject_ClearManagedDict(PyObject *obj) { PyObject **dict = _PyObject_GetDictPtr(obj); - if (dict == NULL || *dict == NULL) { + if (*dict == NULL) { return; } Py_CLEAR(*dict); @@ -1210,11 +950,11 @@ static inline int PyTime_PerfCounter(PyTime_t *result) #endif // gh-111389 added hash constants to Python 3.13.0a5. These constants were -// added first as private macros to Python 3.4.0b1 and PyPy 7.3.8. +// added first as private macros to Python 3.4.0b1 and PyPy 7.3.9. 
#if (!defined(PyHASH_BITS) \ && ((!defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x030400B1) \ || (defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03070000 \ - && PYPY_VERSION_NUM >= 0x07030800))) + && PYPY_VERSION_NUM >= 0x07090000))) # define PyHASH_BITS _PyHASH_BITS # define PyHASH_MODULUS _PyHASH_MODULUS # define PyHASH_INF _PyHASH_INF @@ -1456,18 +1196,6 @@ PyUnicodeWriter_WriteUTF8(PyUnicodeWriter *writer, return res; } -static inline int -PyUnicodeWriter_WriteASCII(PyUnicodeWriter *writer, - const char *str, Py_ssize_t size) -{ - if (size < 0) { - size = (Py_ssize_t)strlen(str); - } - - return _PyUnicodeWriter_WriteASCIIString((_PyUnicodeWriter*)writer, - str, size); -} - static inline int PyUnicodeWriter_WriteWideChar(PyUnicodeWriter *writer, const wchar_t *str, Py_ssize_t size) @@ -1491,8 +1219,7 @@ PyUnicodeWriter_WriteSubstring(PyUnicodeWriter *writer, PyObject *str, Py_ssize_t start, Py_ssize_t end) { if (!PyUnicode_Check(str)) { - PyErr_Format(PyExc_TypeError, "expect str, not %s", - Py_TYPE(str)->tp_name); + PyErr_Format(PyExc_TypeError, "expect str, not %T", str); return -1; } if (start < 0 || start > end) { @@ -1539,1129 +1266,6 @@ static inline int PyLong_GetSign(PyObject *obj, int *sign) } #endif -// gh-126061 added PyLong_IsPositive/Negative/Zero() to Python in 3.14.0a2 -#if PY_VERSION_HEX < 0x030E00A2 -static inline int PyLong_IsPositive(PyObject *obj) -{ - if (!PyLong_Check(obj)) { - PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); - return -1; - } - return _PyLong_Sign(obj) == 1; -} - -static inline int PyLong_IsNegative(PyObject *obj) -{ - if (!PyLong_Check(obj)) { - PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); - return -1; - } - return _PyLong_Sign(obj) == -1; -} - -static inline int PyLong_IsZero(PyObject *obj) -{ - if (!PyLong_Check(obj)) { - PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); - return -1; - } - return _PyLong_Sign(obj) == 0; -} -#endif - - -// gh-124502 added PyUnicode_Equal() to Python 3.14.0a0 -#if PY_VERSION_HEX < 0x030E00A0 -static inline int PyUnicode_Equal(PyObject *str1, PyObject *str2) -{ - if (!PyUnicode_Check(str1)) { - PyErr_Format(PyExc_TypeError, "first argument must be str, not %s", - Py_TYPE(str1)->tp_name); - return -1; - } - if (!PyUnicode_Check(str2)) { - PyErr_Format(PyExc_TypeError, "second argument must be str, not %s", - Py_TYPE(str2)->tp_name); - return -1; - } - -#if PY_VERSION_HEX >= 0x030d0000 && !defined(PYPY_VERSION) - PyAPI_FUNC(int) _PyUnicode_Equal(PyObject *str1, PyObject *str2); - - return _PyUnicode_Equal(str1, str2); -#elif PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) - return _PyUnicode_EQ(str1, str2); -#elif PY_VERSION_HEX >= 0x03090000 && defined(PYPY_VERSION) - return _PyUnicode_EQ(str1, str2); -#else - return (PyUnicode_Compare(str1, str2) == 0); -#endif -} -#endif - - -// gh-121645 added PyBytes_Join() to Python 3.14.0a0 -#if PY_VERSION_HEX < 0x030E00A0 -static inline PyObject* PyBytes_Join(PyObject *sep, PyObject *iterable) -{ - return _PyBytes_Join(sep, iterable); -} -#endif - - -#if PY_VERSION_HEX < 0x030E00A0 -static inline Py_hash_t Py_HashBuffer(const void *ptr, Py_ssize_t len) -{ -#if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) - PyAPI_FUNC(Py_hash_t) _Py_HashBytes(const void *src, Py_ssize_t len); - - return _Py_HashBytes(ptr, len); -#else - Py_hash_t hash; - PyObject *bytes = PyBytes_FromStringAndSize((const char*)ptr, len); - if (bytes == NULL) { - return -1; - } - hash = PyObject_Hash(bytes); - 
Py_DECREF(bytes); - return hash; -#endif -} -#endif - - -#if PY_VERSION_HEX < 0x030E00A0 -static inline int PyIter_NextItem(PyObject *iter, PyObject **item) -{ - iternextfunc tp_iternext; - - assert(iter != NULL); - assert(item != NULL); - - tp_iternext = Py_TYPE(iter)->tp_iternext; - if (tp_iternext == NULL) { - *item = NULL; - PyErr_Format(PyExc_TypeError, "expected an iterator, got '%s'", - Py_TYPE(iter)->tp_name); - return -1; - } - - if ((*item = tp_iternext(iter))) { - return 1; - } - if (!PyErr_Occurred()) { - return 0; - } - if (PyErr_ExceptionMatches(PyExc_StopIteration)) { - PyErr_Clear(); - return 0; - } - return -1; -} -#endif - - -#if PY_VERSION_HEX < 0x030E00A0 -static inline PyObject* PyLong_FromInt32(int32_t value) -{ - Py_BUILD_ASSERT(sizeof(long) >= 4); - return PyLong_FromLong(value); -} - -static inline PyObject* PyLong_FromInt64(int64_t value) -{ - Py_BUILD_ASSERT(sizeof(long long) >= 8); - return PyLong_FromLongLong(value); -} - -static inline PyObject* PyLong_FromUInt32(uint32_t value) -{ - Py_BUILD_ASSERT(sizeof(unsigned long) >= 4); - return PyLong_FromUnsignedLong(value); -} - -static inline PyObject* PyLong_FromUInt64(uint64_t value) -{ - Py_BUILD_ASSERT(sizeof(unsigned long long) >= 8); - return PyLong_FromUnsignedLongLong(value); -} - -static inline int PyLong_AsInt32(PyObject *obj, int32_t *pvalue) -{ - Py_BUILD_ASSERT(sizeof(int) == 4); - int value = PyLong_AsInt(obj); - if (value == -1 && PyErr_Occurred()) { - return -1; - } - *pvalue = (int32_t)value; - return 0; -} - -static inline int PyLong_AsInt64(PyObject *obj, int64_t *pvalue) -{ - Py_BUILD_ASSERT(sizeof(long long) == 8); - long long value = PyLong_AsLongLong(obj); - if (value == -1 && PyErr_Occurred()) { - return -1; - } - *pvalue = (int64_t)value; - return 0; -} - -static inline int PyLong_AsUInt32(PyObject *obj, uint32_t *pvalue) -{ - Py_BUILD_ASSERT(sizeof(long) >= 4); - unsigned long value = PyLong_AsUnsignedLong(obj); - if (value == (unsigned long)-1 && PyErr_Occurred()) { - return -1; - } -#if SIZEOF_LONG > 4 - if ((unsigned long)UINT32_MAX < value) { - PyErr_SetString(PyExc_OverflowError, - "Python int too large to convert to C uint32_t"); - return -1; - } -#endif - *pvalue = (uint32_t)value; - return 0; -} - -static inline int PyLong_AsUInt64(PyObject *obj, uint64_t *pvalue) -{ - Py_BUILD_ASSERT(sizeof(long long) == 8); - unsigned long long value = PyLong_AsUnsignedLongLong(obj); - if (value == (unsigned long long)-1 && PyErr_Occurred()) { - return -1; - } - *pvalue = (uint64_t)value; - return 0; -} -#endif - - -// gh-102471 added import and export API for integers to 3.14.0a2. -#if PY_VERSION_HEX < 0x030E00A2 && PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) -// Helpers to access PyLongObject internals. -static inline void -_PyLong_SetSignAndDigitCount(PyLongObject *op, int sign, Py_ssize_t size) -{ -#if PY_VERSION_HEX >= 0x030C0000 - op->long_value.lv_tag = (uintptr_t)(1 - sign) | ((uintptr_t)(size) << 3); -#elif PY_VERSION_HEX >= 0x030900A4 - Py_SET_SIZE(op, sign * size); -#else - Py_SIZE(op) = sign * size; -#endif -} - -static inline Py_ssize_t -_PyLong_DigitCount(const PyLongObject *op) -{ -#if PY_VERSION_HEX >= 0x030C0000 - return (Py_ssize_t)(op->long_value.lv_tag >> 3); -#else - return _PyLong_Sign((PyObject*)op) < 0 ? 
-Py_SIZE(op) : Py_SIZE(op); -#endif -} - -static inline digit* -_PyLong_GetDigits(const PyLongObject *op) -{ -#if PY_VERSION_HEX >= 0x030C0000 - return (digit*)(op->long_value.ob_digit); -#else - return (digit*)(op->ob_digit); -#endif -} - -typedef struct PyLongLayout { - uint8_t bits_per_digit; - uint8_t digit_size; - int8_t digits_order; - int8_t digit_endianness; -} PyLongLayout; - -typedef struct PyLongExport { - int64_t value; - uint8_t negative; - Py_ssize_t ndigits; - const void *digits; - Py_uintptr_t _reserved; -} PyLongExport; - -typedef struct PyLongWriter PyLongWriter; - -static inline const PyLongLayout* -PyLong_GetNativeLayout(void) -{ - static const PyLongLayout PyLong_LAYOUT = { - PyLong_SHIFT, - sizeof(digit), - -1, // least significant first - PY_LITTLE_ENDIAN ? -1 : 1, - }; - - return &PyLong_LAYOUT; -} - -static inline int -PyLong_Export(PyObject *obj, PyLongExport *export_long) -{ - if (!PyLong_Check(obj)) { - memset(export_long, 0, sizeof(*export_long)); - PyErr_Format(PyExc_TypeError, "expected int, got %s", - Py_TYPE(obj)->tp_name); - return -1; - } - - // Fast-path: try to convert to a int64_t - PyLongObject *self = (PyLongObject*)obj; - int overflow; -#if SIZEOF_LONG == 8 - long value = PyLong_AsLongAndOverflow(obj, &overflow); -#else - // Windows has 32-bit long, so use 64-bit long long instead - long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); -#endif - Py_BUILD_ASSERT(sizeof(value) == sizeof(int64_t)); - // the function cannot fail since obj is a PyLongObject - assert(!(value == -1 && PyErr_Occurred())); - - if (!overflow) { - export_long->value = value; - export_long->negative = 0; - export_long->ndigits = 0; - export_long->digits = 0; - export_long->_reserved = 0; - } - else { - export_long->value = 0; - export_long->negative = _PyLong_Sign(obj) < 0; - export_long->ndigits = _PyLong_DigitCount(self); - if (export_long->ndigits == 0) { - export_long->ndigits = 1; - } - export_long->digits = _PyLong_GetDigits(self); - export_long->_reserved = (Py_uintptr_t)Py_NewRef(obj); - } - return 0; -} - -static inline void -PyLong_FreeExport(PyLongExport *export_long) -{ - PyObject *obj = (PyObject*)export_long->_reserved; - - if (obj) { - export_long->_reserved = 0; - Py_DECREF(obj); - } -} - -static inline PyLongWriter* -PyLongWriter_Create(int negative, Py_ssize_t ndigits, void **digits) -{ - if (ndigits <= 0) { - PyErr_SetString(PyExc_ValueError, "ndigits must be positive"); - return NULL; - } - assert(digits != NULL); - - PyLongObject *obj = _PyLong_New(ndigits); - if (obj == NULL) { - return NULL; - } - _PyLong_SetSignAndDigitCount(obj, negative?-1:1, ndigits); - - *digits = _PyLong_GetDigits(obj); - return (PyLongWriter*)obj; -} - -static inline void -PyLongWriter_Discard(PyLongWriter *writer) -{ - PyLongObject *obj = (PyLongObject *)writer; - - assert(Py_REFCNT(obj) == 1); - Py_DECREF(obj); -} - -static inline PyObject* -PyLongWriter_Finish(PyLongWriter *writer) -{ - PyObject *obj = (PyObject *)writer; - PyLongObject *self = (PyLongObject*)obj; - Py_ssize_t j = _PyLong_DigitCount(self); - Py_ssize_t i = j; - int sign = _PyLong_Sign(obj); - - assert(Py_REFCNT(obj) == 1); - - // Normalize and get singleton if possible - while (i > 0 && _PyLong_GetDigits(self)[i-1] == 0) { - --i; - } - if (i != j) { - if (i == 0) { - sign = 0; - } - _PyLong_SetSignAndDigitCount(self, sign, i); - } - if (i <= 1) { - long val = sign * (long)(_PyLong_GetDigits(self)[0]); - Py_DECREF(obj); - return PyLong_FromLong(val); - } - - return obj; -} -#endif - - -#if 
PY_VERSION_HEX < 0x030C00A3 -# define Py_T_SHORT T_SHORT -# define Py_T_INT T_INT -# define Py_T_LONG T_LONG -# define Py_T_FLOAT T_FLOAT -# define Py_T_DOUBLE T_DOUBLE -# define Py_T_STRING T_STRING -# define _Py_T_OBJECT T_OBJECT -# define Py_T_CHAR T_CHAR -# define Py_T_BYTE T_BYTE -# define Py_T_UBYTE T_UBYTE -# define Py_T_USHORT T_USHORT -# define Py_T_UINT T_UINT -# define Py_T_ULONG T_ULONG -# define Py_T_STRING_INPLACE T_STRING_INPLACE -# define Py_T_BOOL T_BOOL -# define Py_T_OBJECT_EX T_OBJECT_EX -# define Py_T_LONGLONG T_LONGLONG -# define Py_T_ULONGLONG T_ULONGLONG -# define Py_T_PYSSIZET T_PYSSIZET - -# if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) -# define _Py_T_NONE T_NONE -# endif - -# define Py_READONLY READONLY -# define Py_AUDIT_READ READ_RESTRICTED -# define _Py_WRITE_RESTRICTED PY_WRITE_RESTRICTED -#endif - - -// gh-127350 added Py_fopen() and Py_fclose() to Python 3.14a4 -#if PY_VERSION_HEX < 0x030E00A4 -static inline FILE* Py_fopen(PyObject *path, const char *mode) -{ -#if 0x030400A2 <= PY_VERSION_HEX && !defined(PYPY_VERSION) - PyAPI_FUNC(FILE*) _Py_fopen_obj(PyObject *path, const char *mode); - - return _Py_fopen_obj(path, mode); -#else - FILE *f; - PyObject *bytes; -#if PY_VERSION_HEX >= 0x03000000 - if (!PyUnicode_FSConverter(path, &bytes)) { - return NULL; - } -#else - if (!PyString_Check(path)) { - PyErr_SetString(PyExc_TypeError, "except str"); - return NULL; - } - bytes = Py_NewRef(path); -#endif - const char *path_bytes = PyBytes_AS_STRING(bytes); - - f = fopen(path_bytes, mode); - Py_DECREF(bytes); - - if (f == NULL) { - PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, path); - return NULL; - } - return f; -#endif -} - -static inline int Py_fclose(FILE *file) -{ - return fclose(file); -} -#endif - - -#if 0x03080000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030E0000 && !defined(PYPY_VERSION) -static inline PyObject* -PyConfig_Get(const char *name) -{ - typedef enum { - _PyConfig_MEMBER_INT, - _PyConfig_MEMBER_UINT, - _PyConfig_MEMBER_ULONG, - _PyConfig_MEMBER_BOOL, - _PyConfig_MEMBER_WSTR, - _PyConfig_MEMBER_WSTR_OPT, - _PyConfig_MEMBER_WSTR_LIST, - } PyConfigMemberType; - - typedef struct { - const char *name; - size_t offset; - PyConfigMemberType type; - const char *sys_attr; - } PyConfigSpec; - -#define PYTHONCAPI_COMPAT_SPEC(MEMBER, TYPE, sys_attr) \ - {#MEMBER, offsetof(PyConfig, MEMBER), \ - _PyConfig_MEMBER_##TYPE, sys_attr} - - static const PyConfigSpec config_spec[] = { - PYTHONCAPI_COMPAT_SPEC(argv, WSTR_LIST, "argv"), - PYTHONCAPI_COMPAT_SPEC(base_exec_prefix, WSTR_OPT, "base_exec_prefix"), - PYTHONCAPI_COMPAT_SPEC(base_executable, WSTR_OPT, "_base_executable"), - PYTHONCAPI_COMPAT_SPEC(base_prefix, WSTR_OPT, "base_prefix"), - PYTHONCAPI_COMPAT_SPEC(bytes_warning, UINT, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(exec_prefix, WSTR_OPT, "exec_prefix"), - PYTHONCAPI_COMPAT_SPEC(executable, WSTR_OPT, "executable"), - PYTHONCAPI_COMPAT_SPEC(inspect, BOOL, _Py_NULL), -#if 0x030C0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(int_max_str_digits, UINT, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(interactive, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(module_search_paths, WSTR_LIST, "path"), - PYTHONCAPI_COMPAT_SPEC(optimization_level, UINT, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(parser_debug, BOOL, _Py_NULL), -#if 0x03090000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(platlibdir, WSTR, "platlibdir"), -#endif - PYTHONCAPI_COMPAT_SPEC(prefix, WSTR_OPT, "prefix"), - PYTHONCAPI_COMPAT_SPEC(pycache_prefix, WSTR_OPT, "pycache_prefix"), - 
PYTHONCAPI_COMPAT_SPEC(quiet, BOOL, _Py_NULL), -#if 0x030B0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(stdlib_dir, WSTR_OPT, "_stdlib_dir"), -#endif - PYTHONCAPI_COMPAT_SPEC(use_environment, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(verbose, UINT, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(warnoptions, WSTR_LIST, "warnoptions"), - PYTHONCAPI_COMPAT_SPEC(write_bytecode, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(xoptions, WSTR_LIST, "_xoptions"), - PYTHONCAPI_COMPAT_SPEC(buffered_stdio, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(check_hash_pycs_mode, WSTR, _Py_NULL), -#if 0x030B0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(code_debug_ranges, BOOL, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(configure_c_stdio, BOOL, _Py_NULL), -#if 0x030D0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(cpu_count, INT, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(dev_mode, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(dump_refs, BOOL, _Py_NULL), -#if 0x030B0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(dump_refs_file, WSTR_OPT, _Py_NULL), -#endif -#ifdef Py_GIL_DISABLED - PYTHONCAPI_COMPAT_SPEC(enable_gil, INT, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(faulthandler, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(filesystem_encoding, WSTR, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(filesystem_errors, WSTR, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(hash_seed, ULONG, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(home, WSTR_OPT, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(import_time, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(install_signal_handlers, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(isolated, BOOL, _Py_NULL), -#ifdef MS_WINDOWS - PYTHONCAPI_COMPAT_SPEC(legacy_windows_stdio, BOOL, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(malloc_stats, BOOL, _Py_NULL), -#if 0x030A0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(orig_argv, WSTR_LIST, "orig_argv"), -#endif - PYTHONCAPI_COMPAT_SPEC(parse_argv, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(pathconfig_warnings, BOOL, _Py_NULL), -#if 0x030C0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(perf_profiling, UINT, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(program_name, WSTR, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(run_command, WSTR_OPT, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(run_filename, WSTR_OPT, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(run_module, WSTR_OPT, _Py_NULL), -#if 0x030B0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(safe_path, BOOL, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(show_ref_count, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(site_import, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(skip_source_first_line, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(stdio_encoding, WSTR, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(stdio_errors, WSTR, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(tracemalloc, UINT, _Py_NULL), -#if 0x030B0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(use_frozen_modules, BOOL, _Py_NULL), -#endif - PYTHONCAPI_COMPAT_SPEC(use_hash_seed, BOOL, _Py_NULL), - PYTHONCAPI_COMPAT_SPEC(user_site_directory, BOOL, _Py_NULL), -#if 0x030A0000 <= PY_VERSION_HEX - PYTHONCAPI_COMPAT_SPEC(warn_default_encoding, BOOL, _Py_NULL), -#endif - }; - -#undef PYTHONCAPI_COMPAT_SPEC - - const PyConfigSpec *spec; - int found = 0; - for (size_t i=0; i < sizeof(config_spec) / sizeof(config_spec[0]); i++) { - spec = &config_spec[i]; - if (strcmp(spec->name, name) == 0) { - found = 1; - break; - } - } - if (found) { - if (spec->sys_attr != NULL) { - PyObject *value = PySys_GetObject(spec->sys_attr); - if (value == NULL) { - PyErr_Format(PyExc_RuntimeError, "lost sys.%s", spec->sys_attr); - return NULL; - } - return Py_NewRef(value); - } - - 
PyAPI_FUNC(const PyConfig*) _Py_GetConfig(void); - - const PyConfig *config = _Py_GetConfig(); - void *member = (char *)config + spec->offset; - switch (spec->type) { - case _PyConfig_MEMBER_INT: - case _PyConfig_MEMBER_UINT: - { - int value = *(int *)member; - return PyLong_FromLong(value); - } - case _PyConfig_MEMBER_BOOL: - { - int value = *(int *)member; - return PyBool_FromLong(value != 0); - } - case _PyConfig_MEMBER_ULONG: - { - unsigned long value = *(unsigned long *)member; - return PyLong_FromUnsignedLong(value); - } - case _PyConfig_MEMBER_WSTR: - case _PyConfig_MEMBER_WSTR_OPT: - { - wchar_t *wstr = *(wchar_t **)member; - if (wstr != NULL) { - return PyUnicode_FromWideChar(wstr, -1); - } - else { - return Py_NewRef(Py_None); - } - } - case _PyConfig_MEMBER_WSTR_LIST: - { - const PyWideStringList *list = (const PyWideStringList *)member; - PyObject *tuple = PyTuple_New(list->length); - if (tuple == NULL) { - return NULL; - } - - for (Py_ssize_t i = 0; i < list->length; i++) { - PyObject *item = PyUnicode_FromWideChar(list->items[i], -1); - if (item == NULL) { - Py_DECREF(tuple); - return NULL; - } - PyTuple_SET_ITEM(tuple, i, item); - } - return tuple; - } - default: - Py_UNREACHABLE(); - } - } - - PyErr_Format(PyExc_ValueError, "unknown config option name: %s", name); - return NULL; -} - -static inline int -PyConfig_GetInt(const char *name, int *value) -{ - PyObject *obj = PyConfig_Get(name); - if (obj == NULL) { - return -1; - } - - if (!PyLong_Check(obj)) { - Py_DECREF(obj); - PyErr_Format(PyExc_TypeError, "config option %s is not an int", name); - return -1; - } - - int as_int = PyLong_AsInt(obj); - Py_DECREF(obj); - if (as_int == -1 && PyErr_Occurred()) { - PyErr_Format(PyExc_OverflowError, - "config option %s value does not fit into a C int", name); - return -1; - } - - *value = as_int; - return 0; -} -#endif // PY_VERSION_HEX > 0x03090000 && !defined(PYPY_VERSION) - -// gh-133144 added PyUnstable_Object_IsUniquelyReferenced() to Python 3.14.0b1. -// Adapted from _PyObject_IsUniquelyReferenced() implementation. -#if PY_VERSION_HEX < 0x030E00B0 -static inline int PyUnstable_Object_IsUniquelyReferenced(PyObject *obj) -{ -#if !defined(Py_GIL_DISABLED) - return Py_REFCNT(obj) == 1; -#else - // NOTE: the entire ob_ref_shared field must be zero, including flags, to - // ensure that other threads cannot concurrently create new references to - // this object. - return (_Py_IsOwnedByCurrentThread(obj) && - _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local) == 1 && - _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared) == 0); -#endif -} -#endif - -// gh-128926 added PyUnstable_TryIncRef() and PyUnstable_EnableTryIncRef() to -// Python 3.14.0a5. Adapted from _Py_TryIncref() and _PyObject_SetMaybeWeakref(). 
-#if PY_VERSION_HEX < 0x030E00A5 -static inline int PyUnstable_TryIncRef(PyObject *op) -{ -#ifndef Py_GIL_DISABLED - if (Py_REFCNT(op) > 0) { - Py_INCREF(op); - return 1; - } - return 0; -#else - // _Py_TryIncrefFast() - uint32_t local = _Py_atomic_load_uint32_relaxed(&op->ob_ref_local); - local += 1; - if (local == 0) { - // immortal - return 1; - } - if (_Py_IsOwnedByCurrentThread(op)) { - _Py_INCREF_STAT_INC(); - _Py_atomic_store_uint32_relaxed(&op->ob_ref_local, local); -#ifdef Py_REF_DEBUG - _Py_INCREF_IncRefTotal(); -#endif - return 1; - } - - // _Py_TryIncRefShared() - Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared); - for (;;) { - // If the shared refcount is zero and the object is either merged - // or may not have weak references, then we cannot incref it. - if (shared == 0 || shared == _Py_REF_MERGED) { - return 0; - } - - if (_Py_atomic_compare_exchange_ssize( - &op->ob_ref_shared, - &shared, - shared + (1 << _Py_REF_SHARED_SHIFT))) { -#ifdef Py_REF_DEBUG - _Py_INCREF_IncRefTotal(); -#endif - _Py_INCREF_STAT_INC(); - return 1; - } - } -#endif -} - -static inline void PyUnstable_EnableTryIncRef(PyObject *op) -{ -#ifdef Py_GIL_DISABLED - // _PyObject_SetMaybeWeakref() - if (_Py_IsImmortal(op)) { - return; - } - for (;;) { - Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared); - if ((shared & _Py_REF_SHARED_FLAG_MASK) != 0) { - // Nothing to do if it's in WEAKREFS, QUEUED, or MERGED states. - return; - } - if (_Py_atomic_compare_exchange_ssize( - &op->ob_ref_shared, &shared, shared | _Py_REF_MAYBE_WEAKREF)) { - return; - } - } -#else - (void)op; // unused argument -#endif -} -#endif - - -#if PY_VERSION_HEX < 0x030F0000 -static inline PyObject* -PySys_GetAttrString(const char *name) -{ -#if PY_VERSION_HEX >= 0x03000000 - PyObject *value = Py_XNewRef(PySys_GetObject(name)); -#else - PyObject *value = Py_XNewRef(PySys_GetObject((char*)name)); -#endif - if (value != NULL) { - return value; - } - if (!PyErr_Occurred()) { - PyErr_Format(PyExc_RuntimeError, "lost sys.%s", name); - } - return NULL; -} - -static inline PyObject* -PySys_GetAttr(PyObject *name) -{ -#if PY_VERSION_HEX >= 0x03000000 - const char *name_str = PyUnicode_AsUTF8(name); -#else - const char *name_str = PyString_AsString(name); -#endif - if (name_str == NULL) { - return NULL; - } - - return PySys_GetAttrString(name_str); -} - -static inline int -PySys_GetOptionalAttrString(const char *name, PyObject **value) -{ -#if PY_VERSION_HEX >= 0x03000000 - *value = Py_XNewRef(PySys_GetObject(name)); -#else - *value = Py_XNewRef(PySys_GetObject((char*)name)); -#endif - if (*value != NULL) { - return 1; - } - return 0; -} - -static inline int -PySys_GetOptionalAttr(PyObject *name, PyObject **value) -{ -#if PY_VERSION_HEX >= 0x03000000 - const char *name_str = PyUnicode_AsUTF8(name); -#else - const char *name_str = PyString_AsString(name); -#endif - if (name_str == NULL) { - *value = NULL; - return -1; - } - - return PySys_GetOptionalAttrString(name_str, value); -} -#endif // PY_VERSION_HEX < 0x030F00A1 - - -#if PY_VERSION_HEX < 0x030F00A1 -typedef struct PyBytesWriter { - char small_buffer[256]; - PyObject *obj; - Py_ssize_t size; -} PyBytesWriter; - -static inline Py_ssize_t -_PyBytesWriter_GetAllocated(PyBytesWriter *writer) -{ - if (writer->obj == NULL) { - return sizeof(writer->small_buffer); - } - else { - return PyBytes_GET_SIZE(writer->obj); - } -} - - -static inline int -_PyBytesWriter_Resize_impl(PyBytesWriter *writer, Py_ssize_t size, - int resize) -{ - int overallocate = resize; 
- assert(size >= 0); - - if (size <= _PyBytesWriter_GetAllocated(writer)) { - return 0; - } - - if (overallocate) { -#ifdef MS_WINDOWS - /* On Windows, overallocate by 50% is the best factor */ - if (size <= (PY_SSIZE_T_MAX - size / 2)) { - size += size / 2; - } -#else - /* On Linux, overallocate by 25% is the best factor */ - if (size <= (PY_SSIZE_T_MAX - size / 4)) { - size += size / 4; - } -#endif - } - - if (writer->obj != NULL) { - if (_PyBytes_Resize(&writer->obj, size)) { - return -1; - } - assert(writer->obj != NULL); - } - else { - writer->obj = PyBytes_FromStringAndSize(NULL, size); - if (writer->obj == NULL) { - return -1; - } - - if (resize) { - assert((size_t)size > sizeof(writer->small_buffer)); - memcpy(PyBytes_AS_STRING(writer->obj), - writer->small_buffer, - sizeof(writer->small_buffer)); - } - } - return 0; -} - -static inline void* -PyBytesWriter_GetData(PyBytesWriter *writer) -{ - if (writer->obj == NULL) { - return writer->small_buffer; - } - else { - return PyBytes_AS_STRING(writer->obj); - } -} - -static inline Py_ssize_t -PyBytesWriter_GetSize(PyBytesWriter *writer) -{ - return writer->size; -} - -static inline void -PyBytesWriter_Discard(PyBytesWriter *writer) -{ - if (writer == NULL) { - return; - } - - Py_XDECREF(writer->obj); - PyMem_Free(writer); -} - -static inline PyBytesWriter* -PyBytesWriter_Create(Py_ssize_t size) -{ - if (size < 0) { - PyErr_SetString(PyExc_ValueError, "size must be >= 0"); - return NULL; - } - - PyBytesWriter *writer = (PyBytesWriter*)PyMem_Malloc(sizeof(PyBytesWriter)); - if (writer == NULL) { - PyErr_NoMemory(); - return NULL; - } - - writer->obj = NULL; - writer->size = 0; - - if (size >= 1) { - if (_PyBytesWriter_Resize_impl(writer, size, 0) < 0) { - PyBytesWriter_Discard(writer); - return NULL; - } - writer->size = size; - } - return writer; -} - -static inline PyObject* -PyBytesWriter_FinishWithSize(PyBytesWriter *writer, Py_ssize_t size) -{ - PyObject *result; - if (size == 0) { - result = PyBytes_FromStringAndSize("", 0); - } - else if (writer->obj != NULL) { - if (size != PyBytes_GET_SIZE(writer->obj)) { - if (_PyBytes_Resize(&writer->obj, size)) { - goto error; - } - } - result = writer->obj; - writer->obj = NULL; - } - else { - result = PyBytes_FromStringAndSize(writer->small_buffer, size); - } - PyBytesWriter_Discard(writer); - return result; - -error: - PyBytesWriter_Discard(writer); - return NULL; -} - -static inline PyObject* -PyBytesWriter_Finish(PyBytesWriter *writer) -{ - return PyBytesWriter_FinishWithSize(writer, writer->size); -} - -static inline PyObject* -PyBytesWriter_FinishWithPointer(PyBytesWriter *writer, void *buf) -{ - Py_ssize_t size = (char*)buf - (char*)PyBytesWriter_GetData(writer); - if (size < 0 || size > _PyBytesWriter_GetAllocated(writer)) { - PyBytesWriter_Discard(writer); - PyErr_SetString(PyExc_ValueError, "invalid end pointer"); - return NULL; - } - - return PyBytesWriter_FinishWithSize(writer, size); -} - -static inline int -PyBytesWriter_Resize(PyBytesWriter *writer, Py_ssize_t size) -{ - if (size < 0) { - PyErr_SetString(PyExc_ValueError, "size must be >= 0"); - return -1; - } - if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { - return -1; - } - writer->size = size; - return 0; -} - -static inline int -PyBytesWriter_Grow(PyBytesWriter *writer, Py_ssize_t size) -{ - if (size < 0 && writer->size + size < 0) { - PyErr_SetString(PyExc_ValueError, "invalid size"); - return -1; - } - if (size > PY_SSIZE_T_MAX - writer->size) { - PyErr_NoMemory(); - return -1; - } - size = writer->size + size; - 
- if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { - return -1; - } - writer->size = size; - return 0; -} - -static inline void* -PyBytesWriter_GrowAndUpdatePointer(PyBytesWriter *writer, - Py_ssize_t size, void *buf) -{ - Py_ssize_t pos = (char*)buf - (char*)PyBytesWriter_GetData(writer); - if (PyBytesWriter_Grow(writer, size) < 0) { - return NULL; - } - return (char*)PyBytesWriter_GetData(writer) + pos; -} - -static inline int -PyBytesWriter_WriteBytes(PyBytesWriter *writer, - const void *bytes, Py_ssize_t size) -{ - if (size < 0) { - size_t len = strlen((const char*)bytes); - if (len > (size_t)PY_SSIZE_T_MAX) { - PyErr_NoMemory(); - return -1; - } - size = (Py_ssize_t)len; - } - - Py_ssize_t pos = writer->size; - if (PyBytesWriter_Grow(writer, size) < 0) { - return -1; - } - char *buf = (char*)PyBytesWriter_GetData(writer); - memcpy(buf + pos, bytes, (size_t)size); - return 0; -} - -static inline int -PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) - Py_GCC_ATTRIBUTE((format(printf, 2, 3))); - -static inline int -PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) -{ - va_list vargs; - va_start(vargs, format); - PyObject *str = PyBytes_FromFormatV(format, vargs); - va_end(vargs); - - if (str == NULL) { - return -1; - } - int res = PyBytesWriter_WriteBytes(writer, - PyBytes_AS_STRING(str), - PyBytes_GET_SIZE(str)); - Py_DECREF(str); - return res; -} -#endif // PY_VERSION_HEX < 0x030F00A1 - - -#if PY_VERSION_HEX < 0x030F00A1 -static inline PyObject* -PyTuple_FromArray(PyObject *const *array, Py_ssize_t size) -{ - PyObject *tuple = PyTuple_New(size); - if (tuple == NULL) { - return NULL; - } - for (Py_ssize_t i=0; i < size; i++) { - PyObject *item = array[i]; - PyTuple_SET_ITEM(tuple, i, Py_NewRef(item)); - } - return tuple; -} -#endif - - -#if PY_VERSION_HEX < 0x030F00A1 -static inline Py_hash_t -PyUnstable_Unicode_GET_CACHED_HASH(PyObject *op) -{ -#ifdef PYPY_VERSION - (void)op; // unused argument - return -1; -#elif PY_VERSION_HEX >= 0x03000000 - return ((PyASCIIObject*)op)->hash; -#else - return ((PyUnicodeObject*)op)->hash; -#endif -} -#endif - #ifdef __cplusplus } From 1727a71cb69378eba3cd1169074d0101a1a43c28 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 7 Nov 2025 11:19:37 -0800 Subject: [PATCH 227/651] Create pallas test shard (#167143) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167143 Approved by: https://github.com/malfet ghstack dependencies: #167243 --- .ci/docker/build.sh | 7 +++++ .ci/docker/ci_commit_pins/jax.txt | 1 + .ci/docker/common/install_jax.sh | 40 +++++++++++++++++++++++++ .ci/docker/ubuntu/Dockerfile | 9 ++++++ .ci/pytorch/test.sh | 7 +++++ .github/workflows/docker-builds.yml | 1 + .github/workflows/inductor-unittest.yml | 26 ++++++++++++++++ 7 files changed, 91 insertions(+) create mode 100644 .ci/docker/ci_commit_pins/jax.txt create mode 100755 .ci/docker/common/install_jax.sh diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 5609b9e30dc2b..90c87b55ea416 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -260,6 +260,12 @@ case "$tag" in HALIDE=yes TRITON=yes ;; + pytorch-linux-jammy-cuda13.0-py3.12-pallas) + CUDA_VERSION=13.0.0 + ANACONDA_PYTHON_VERSION=3.12 + GCC_VERSION=11 + PALLAS=yes + ;; pytorch-linux-jammy-py3.12-triton-cpu) CUDA_VERSION=12.6 ANACONDA_PYTHON_VERSION=3.12 @@ -381,6 +387,7 @@ docker build \ --build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \ --build-arg "EXECUTORCH=${EXECUTORCH}" \ --build-arg "HALIDE=${HALIDE}" \ + --build-arg 
"PALLAS=${PALLAS}" \ --build-arg "XPU_VERSION=${XPU_VERSION}" \ --build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \ --build-arg "ACL=${ACL:-}" \ diff --git a/.ci/docker/ci_commit_pins/jax.txt b/.ci/docker/ci_commit_pins/jax.txt new file mode 100644 index 0000000000000..a3df0a6959e15 --- /dev/null +++ b/.ci/docker/ci_commit_pins/jax.txt @@ -0,0 +1 @@ +0.8.0 diff --git a/.ci/docker/common/install_jax.sh b/.ci/docker/common/install_jax.sh new file mode 100755 index 0000000000000..184aedf0f94fe --- /dev/null +++ b/.ci/docker/common/install_jax.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set -ex + +source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" + +# Get the pinned JAX version (same for all CUDA versions) +JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax) + +function install_jax_12() { + echo "Installing JAX ${JAX_VERSION} with CUDA 12 support" + pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + + # Verify installation + python -c "import jax" # check for errors + echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12" +} + +function install_jax_13() { + echo "Installing JAX ${JAX_VERSION} with CUDA 13 support" + pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + + # Verify installation + python -c "import jax" # check for errors + echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13" +} + +# idiomatic parameter and option handling in sh +while test $# -gt 0 +do + case "$1" in + 12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12; + ;; + 13.0|13.0.*) install_jax_13; + ;; + *) echo "bad argument $1"; exit 1 + ;; + esac + shift +done diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 84a74114c381e..2081dcbdffd17 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi RUN rm install_halide.sh common_utils.sh halide.txt +ARG PALLAS +ARG CUDA_VERSION +# Install JAX with CUDA support (for Pallas) +COPY ./common/install_jax.sh install_jax.sh +COPY ./common/common_utils.sh common_utils.sh +COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt +RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi +RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt + ARG ONNX # Install ONNX dependencies COPY ./common/install_onnx.sh ./common/common_utils.sh ./ diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 26996b5a32d56..37adb0282c999 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -824,6 +824,11 @@ test_inductor_halide() { assert_git_not_dirty } +test_inductor_pallas() { + python test/run_test.py --include inductor/test_pallas.py --verbose + assert_git_not_dirty +} + test_inductor_triton_cpu() { python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose assert_git_not_dirty @@ -1724,6 +1729,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then test_inductor_distributed elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then test_inductor_halide +elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then + test_inductor_pallas elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then test_inductor_triton_cpu elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then diff --git a/.github/workflows/docker-builds.yml 
b/.github/workflows/docker-builds.yml index 6d3a5c321a1eb..f632d4a858abb 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -67,6 +67,7 @@ jobs: pytorch-linux-jammy-py3.10-gcc11, pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, + pytorch-linux-jammy-cuda13.0-py3.12-pallas, pytorch-linux-jammy-xpu-n-1-py3, pytorch-linux-noble-xpu-n-py3, pytorch-linux-noble-xpu-n-py3-inductor-benchmarks, diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 3ce917567aec2..f55267caba93f 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -81,6 +81,32 @@ jobs: test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }} secrets: inherit + inductor-pallas-build: + name: inductor-pallas-build + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + build-environment: linux-jammy-py3.12-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas + cuda-arch-list: '8.9' + runner: linux.8xlarge.memory + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + test-matrix: | + { include: [ + { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + ]} + secrets: inherit + + inductor-pallas-test: + name: inductor-pallas-test + uses: ./.github/workflows/_linux-test.yml + needs: inductor-pallas-build + with: + build-environment: linux-jammy-py3.12-gcc11 + docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }} + test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }} + secrets: inherit + inductor-triton-cpu-build: name: inductor-triton-cpu-build uses: ./.github/workflows/_linux-build.yml From 84b2147b8527436f80df4f14377bd865a5f3fa3f Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 7 Nov 2025 09:36:18 -0800 Subject: [PATCH 228/651] Introducing the StableIValue representation of list :D (#165953) Some important notes: a) Just like IValues steal the ownership of ArrayRefs and any std::vectors in order to convert the inner elements into IValues, we do the same thing with StableIValue. This O(N) traverse is ineluctable. b) As a result, since StableIValues are owning and our contract is that to(StableIValue) transfers ownership, you cannot ever convert from StableIValue to a nonowning HeaderOnlyArrayRef. We handle memory similar to AtenTensorHandle, but we have a StableListHandle! 
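For illustration only (not part of this patch): a minimal sketch of the ownership contract, assuming the from()/to<>() helpers and the stable Tensor type that the test kernel below uses. The helper name `roundtrip` and the call site are hypothetical.

    // Hypothetical helper inside an extension kernel (same includes as the test kernel below).
    std::vector<Tensor> roundtrip(const std::vector<Tensor>& xs) {
      // from() converts every element and stores the results in a new owning StableListHandle.
      StableIValue siv = from(xs);
      // to<std::vector<Tensor>>() steals the elements and frees the underlying list.
      // to<HeaderOnlyArrayRef<Tensor>>(siv) is deliberately not supported, since
      // HeaderOnlyArrayRef is non-owning and cannot steward the owned result.
      return to<std::vector<Tensor>>(siv);
    }

As with AtenTensorHandle, whoever ends up holding a StableListHandle is responsible for its memory, either by calling torch_delete_list or by converting it with to<>, which frees the list internally.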
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165953 Approved by: https://github.com/malfet ghstack dependencies: #164991, #165152, #165153 --- .../libtorch_agnostic/csrc/kernel.cpp | 51 ++++++++++++ .../libtorch_agnostic/ops.py | 42 ++++++++++ .../test/test_libtorch_agnostic.py | 51 ++++++++++++ torch/csrc/shim_common.cpp | 83 ++++++++++++++++++- torch/csrc/shim_conversion_utils.h | 22 +++++ torch/csrc/stable/c/shim.h | 28 +++++++ torch/csrc/stable/stableivalue_conversions.h | 73 ++++++++++++++++ 7 files changed, 349 insertions(+), 1 deletion(-) create mode 100644 torch/csrc/shim_conversion_utils.h diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 7154322641c32..9f4079ab79748 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -478,6 +478,56 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("my_amax_vec", &boxed_my_amax_vec); } +std::vector my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) { + std::array stack = {from(self), from(other)}; + aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data()); + return to>(stack[0]); +} + +void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + // Why is the following NOT to>(stack[0])? Because calling `to` + // on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef + // is not owning, so it cannot safely steward the result of the to<>. + auto res = my__foreach_mul(to>(stack[0]), to>(stack[1])); + stack[0] = from(res); +} + +void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) { + std::array stack = {from(self), from(other)}; + aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data()); +} + +void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + my__foreach_mul_(to>(stack[0]), to>(stack[1])); +} + +std::vector make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) { + // This function tests that my__foreach_mul can take in std::initializer_lists + // in addition to std::vectors. 
+ Tensor t1_1 = my_clone(t1); + Tensor t1_2 = my_clone(t1); + Tensor t2_1 = my_clone(t2); + Tensor t2_2 = my_clone(t2); + return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2}); +} + +void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + auto res = make_tensor_clones_and_call_foreach(to(stack[0]), to(stack[1])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]"); + m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()"); + m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("my__foreach_mul", &boxed_my__foreach_mul); + m.impl("my__foreach_mul_", &boxed_my__foreach_mul_); + m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach); +} + // Test functions for torch::stable::accelerator APIs #ifdef LAE_USE_CUDA @@ -565,4 +615,5 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_stream", &boxed_test_stream); m.impl("test_get_current_device_index", &boxed_test_get_current_device_index); } + #endif // LAE_USE_CUDA diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 0000d667e1cbc..e0e5cef216375 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -333,3 +333,45 @@ def my_new_zeros_dtype_variant(t) -> Tensor: Returns: New zeros tensor """ return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t) + + +def my__foreach_mul_(tensors, others) -> (): + """ + Updates tensors to be the result of pointwise multiplying with others. + + Args: + tensors: list of tensors + others: list of tensors (with the same corresponding shapes as tensors) + + Returns: nothing, tensors is updated in place. + """ + torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others) + + +def my__foreach_mul(tensors, others) -> list[Tensor]: + """ + Returns a list of tensors that are the results of pointwise multiplying + tensors and others. + + Args: + tensors: list of tensors + others: list of tensors (with the same corresponding shapes as tensors) + + Returns: list of multiplied tensors + """ + return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others) + + +def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]: + """ + Returns a list of 2 tensors corresponding to the square of the inputs. 
+ + Args: + t1: Tensor + t2: Tensor + + Returns: list of [t1^2, t2^2] + """ + return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default( + t1, t2 + ) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index 35610332a36cd..e94c740861a11 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -367,6 +367,57 @@ def test_my_clone(self, device): self.assertNotEqual(result.data_ptr(), expected.data_ptr()) self.assertEqual(result.stride(), expected.stride()) + def test_my__foreach_mul_(self, device): + import libtorch_agnostic + + N = 5 + tensors = [torch.rand(32, 16, device=device) for _ in range(N)] + tensors_c = [t.clone() for t in tensors] + others = [torch.rand(32, 16, device=device) for _ in range(N)] + + libtorch_agnostic.ops.my__foreach_mul_(tensors, others) + expected_values = torch._foreach_mul(tensors_c, others) + + for tensor_t, expected_t in zip(tensors, expected_values): + self.assertEqual(tensor_t, expected_t) + + def test_my__foreach_mul(self, device): + import libtorch_agnostic + + N = 5 + tensors = [torch.rand(32, 16, device=device) for _ in range(N)] + others = [torch.rand(32, 16, device=device) for _ in range(N)] + + result = libtorch_agnostic.ops.my__foreach_mul(tensors, others) + expected = torch._foreach_mul(tensors, others) + + for result_t, expected_t in zip(result, expected): + self.assertEqual(result_t, expected_t) + + def _make_cuda_tensors(prior_mem): + cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others) + self.assertGreater(torch.cuda.memory_allocated(device), prior_mem) + + expected = torch._foreach_mul(tensors, others) + for result_t, expected_t in zip(cuda_res, expected): + self.assertEqual(result_t, expected_t) + + if tensors[0].is_cuda: + init_mem = torch.cuda.memory_allocated(device) + for _ in range(3): + _make_cuda_tensors(init_mem) + curr_mem = torch.cuda.memory_allocated(device) + self.assertEqual(curr_mem, init_mem) + + def test_make_tensor_clones_and_call_foreach(self, device): + import libtorch_agnostic + + t1 = torch.rand(2, 5, device=device) + t2 = torch.rand(3, 4, device=device) + result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2) + self.assertEqual(result[0], t1 * t1) + self.assertEqual(result[1], t2 * t2) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index 23effad1a36b2..15b9b986a3463 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -4,12 +4,65 @@ #include #include #include -#include #include #include +#include #include +AOTITorchError torch_new_list_reserve_size(size_t size, StableListHandle* ret) { + auto list_ptr = std::make_unique>(); + list_ptr->reserve(size); + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret = list_pointer_to_list_handle(list_ptr.release()); }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_list_size(StableListHandle list_handle, size_t* size) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::vector* list = list_handle_to_list_pointer(list_handle); + *size = list->size(); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError torch_list_get_item( + StableListHandle list_handle, + size_t index, + StableIValue* element) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + 
std::vector* list = list_handle_to_list_pointer(list_handle); + *element = list->at(index); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError torch_list_set_item( + StableListHandle list_handle, + size_t index, + StableIValue element) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::vector* list = list_handle_to_list_pointer(list_handle); + list->at(index) = element; + }); +} + +AOTITorchError torch_list_push_back( + StableListHandle list_handle, + StableIValue element) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::vector* list = list_handle_to_list_pointer(list_handle); + list->push_back(element); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_delete_list(StableListHandle list_handle) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::vector* list_ptr = + list_handle_to_list_pointer(list_handle); + delete list_ptr; + }); +} + static StableIValue from_ivalue( const c10::TypePtr& type, const c10::IValue& ivalue, @@ -71,6 +124,19 @@ static StableIValue from_ivalue( from_ivalue(inner_type, ivalue, extension_build_version)); return torch::stable::detail::_from(sivp, extension_build_version); } + case c10::TypeKind::ListType: { + auto inner_type = type->castRaw()->getElementType(); + auto ivalue_list = ivalue.toList(); + auto stableivalue_list = std::make_unique>(); + stableivalue_list->reserve(ivalue_list.size()); + for (const auto& elem : ivalue_list) { + stableivalue_list->emplace_back( + from_ivalue(inner_type, elem, extension_build_version)); + } + return torch::stable::detail::_from( + list_pointer_to_list_handle(stableivalue_list.release()), + extension_build_version); + } default: { TORCH_CHECK( false, @@ -145,6 +211,21 @@ static c10::IValue to_ivalue( delete sivp; return ival; } + case c10::TypeKind::ListType: { + auto inner_type = type->castRaw()->getElementType(); + auto list_handle = torch::stable::detail::_to( + stable_ivalue, extension_build_version); + std::vector* stableivalue_list = + list_handle_to_list_pointer(list_handle); + auto ivalue_list = c10::impl::GenericList(inner_type); + ivalue_list.reserve(stableivalue_list->size()); + for (const auto& elem : *stableivalue_list) { + ivalue_list.emplace_back( + to_ivalue(inner_type, elem, extension_build_version)); + } + TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); + return ivalue_list; + } default: { TORCH_CHECK( false, diff --git a/torch/csrc/shim_conversion_utils.h b/torch/csrc/shim_conversion_utils.h new file mode 100644 index 0000000000000..e0e1d25e65ef7 --- /dev/null +++ b/torch/csrc/shim_conversion_utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +#include + +inline std::vector* list_handle_to_list_pointer( + StableListHandle handle) { + return reinterpret_cast*>(handle); +} + +inline StableListHandle list_pointer_to_list_handle( + std::vector* list_ptr) { + return reinterpret_cast(list_ptr); +} + +inline StableListHandle new_list_handle(std::vector&& list) { + std::vector* new_list = new std::vector(list); + return list_pointer_to_list_handle(new_list); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 365c954dbe787..ea6cea0726659 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -37,6 +37,34 @@ AOTI_TORCH_EXPORT AOTITorchError torch_library_impl( void (*fn)(StableIValue*, uint64_t, uint64_t), uint64_t extension_build_version); +struct StableListOpaque; +using StableListHandle = StableListOpaque*; + +// returns an owning reference of a StableList. callee is responsible for +// freeing memory. 
+AOTI_TORCH_EXPORT AOTITorchError +torch_new_list_reserve_size(size_t size, StableListHandle* ret); + +AOTI_TORCH_EXPORT AOTITorchError +torch_list_size(StableListHandle list_handle, size_t* size); + +AOTI_TORCH_EXPORT AOTITorchError torch_list_get_item( + StableListHandle list_handle, + size_t index, + StableIValue* element); + +AOTI_TORCH_EXPORT AOTITorchError torch_list_set_item( + StableListHandle list_handle, + size_t index, + StableIValue element); + +AOTI_TORCH_EXPORT AOTITorchError +torch_list_push_back(StableListHandle list_handle, StableIValue element); + +// deletes the underlying list referenced by list_handle +AOTI_TORCH_EXPORT AOTITorchError +torch_delete_list(StableListHandle list_handle); + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 #ifdef __cplusplus diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 8004e91b77f8e..6885b1e4bdfeb 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -192,6 +193,46 @@ struct FromImpl { } }; +// Specialization for torch::headeronly::HeaderOnlyArrayRef => StableIValue +// Returns a new owning reference of the underlying list. +template +struct FromImpl> { + static StableIValue call( + const torch::headeronly::HeaderOnlyArrayRef& val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + StableListHandle new_list_handle; + try { + TORCH_ERROR_CODE_CHECK( + torch_new_list_reserve_size(val.size(), &new_list_handle)); + for (const auto& elem : val) { + TORCH_ERROR_CODE_CHECK( + torch_list_push_back(new_list_handle, from(elem))); + } + return from(new_list_handle); + } catch (const std::runtime_error& e) { + if (new_list_handle != nullptr) { + // clean up memory if an error was thrown + TORCH_ERROR_CODE_CHECK(torch_delete_list(new_list_handle)); + } + throw; + } + } +}; + +// Specialization for std::vector => StableIValue, which is implemented the +// same way as HeaderOnlyArrayRef => StableIValue +// Returns a new owning reference of the underlying list. +template +struct FromImpl> { + static StableIValue call( + const std::vector& val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + return from>(val); + } +}; + // ============================================================================= // TO CONVERSIONS (StableIValue -> T) // ============================================================================= @@ -342,6 +383,38 @@ struct ToImpl { } }; +// Specialization for StableIValue => std::vector +// std::vector should be represented as a StableListHandle +// filled with StableIValues +// The new std::vector steals ownership of the underlying elements +// and we free the underlying list referred by the input StableListHandle. 
+template +struct ToImpl> { + static std::vector call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + auto list_handle = to(val); + size_t size; + try { + TORCH_ERROR_CODE_CHECK(torch_list_size(list_handle, &size)); + std::vector result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + StableIValue element; + TORCH_ERROR_CODE_CHECK(torch_list_get_item(list_handle, i, &element)); + result.push_back(to(element)); + } + TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); + return result; + } catch (const std::runtime_error& e) { + // clean up memory if an exception is thrown, and rethrow + TORCH_ERROR_CODE_CHECK(torch_delete_list(list_handle)); + throw; + } + } +}; + // ============================================================================= // end to helpers for converting between StableIValue and T // ============================================================================= From 46516efa85e25499da12790d8fbb475fd9dadbbd Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 7 Nov 2025 09:36:19 -0800 Subject: [PATCH 229/651] [BE] use undeprecated from/to in libtorch_agnostic tests (#167126) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167126 Approved by: https://github.com/Skylion007 ghstack dependencies: #164991, #165152, #165153, #165953 --- .../libtorch_agnostic/csrc/kernel.cpp | 178 +++++++++--------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 9f4079ab79748..92a4af8b72733 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -67,13 +67,13 @@ Tensor sgd_out_of_place( void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { Tensor res = sgd_out_of_place( - to(stack[0]), - to(stack[1]), - float(to(stack[2])), - to(stack[3]), - to(stack[4])); + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1]), + float(torch::stable::detail::to(stack[2])), + torch::stable::detail::to(stack[3]), + torch::stable::detail::to(stack[4])); - stack[0] = from(res); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY(libtorch_agnostic, m) { @@ -89,8 +89,8 @@ Tensor identity(Tensor t) { } void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor res = identity(to(stack[0])); - stack[0] = from(res); + Tensor res = identity(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) { Tensor my_abs(Tensor t) { const auto num_args = 1; StableIValue stack[num_args]; - stack[0] = from(t); + stack[0] = torch::stable::detail::from(t); aoti_torch_call_dispatcher("aten::abs", "", stack); - return to(stack[0]); + return torch::stable::detail::to(stack[0]); } void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor tensor_res = my_abs(to(stack[0])); - stack[0] = from(tensor_res); + Tensor tensor_res = my_abs(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(tensor_res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, 
StableIValue device) { auto mf = aoti_torch_memory_format_contiguous_format(); - stack[0] = from(t); - stack[1] = from(std::optional(t.scalar_type())); // dtype - stack[2] = from(std::nullopt); // layout - stack[3] = from(std::optional(device)); // device - stack[4] = from(std::optional(false)); // pin_memory - stack[5] = from(std::optional(mf)); // memory_format + stack[0] = torch::stable::detail::from(t); + stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype + stack[2] = torch::stable::detail::from(std::nullopt); // layout + stack[3] = torch::stable::detail::from(std::optional(device)); // device + stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory + stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format aoti_torch_call_dispatcher("aten::ones_like", "", stack); - return to(stack[0]); + return torch::stable::detail::to(stack[0]); } void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor res = my_ones_like(to(stack[0]), stack[1]); - stack[0] = from(res); + Tensor res = my_ones_like(torch::stable::detail::to(stack[0]), stack[1]); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { std::tuple exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) { StableIValue stack_exp[1]; - stack_exp[0] = from(t1); + stack_exp[0] = torch::stable::detail::from(t1); aoti_torch_call_dispatcher("aten::exp", "", stack_exp); StableIValue stack_neg[1]; - stack_neg[0] = from(t2); + stack_neg[0] = torch::stable::detail::from(t2); aoti_torch_call_dispatcher("aten::neg", "", stack_neg); StableIValue stack_is_leaf[1]; - stack_is_leaf[0] = from(t3); + stack_is_leaf[0] = torch::stable::detail::from(t3); aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf); return std::make_tuple( - to(stack_exp[0]), - to(stack_neg[0]), - to(stack_is_leaf[0])); + torch::stable::detail::to(stack_exp[0]), + torch::stable::detail::to(stack_neg[0]), + torch::stable::detail::to(stack_is_leaf[0])); } void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto tuple = exp_neg_is_leaf(to(stack[0]), to(stack[1]), to(stack[2])); - stack[0] = from(std::get<0>(tuple)); - stack[1] = from(std::get<1>(tuple)); - stack[2] = from(std::get<2>(tuple)); + auto tuple = exp_neg_is_leaf(torch::stable::detail::to(stack[0]), torch::stable::detail::to(stack[1]), torch::stable::detail::to(stack[2])); + stack[0] = torch::stable::detail::from(std::get<0>(tuple)); + stack[1] = torch::stable::detail::from(std::get<1>(tuple)); + stack[2] = torch::stable::detail::from(std::get<2>(tuple)); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { Tensor neg_exp(Tensor t) { StableIValue stack[1]; - stack[0] = from(t); + stack[0] = torch::stable::detail::from(t); aoti_torch_call_dispatcher("aten::exp", "", stack); aoti_torch_call_dispatcher("aten::neg", "", stack); - return to(stack[0]); + return torch::stable::detail::to(stack[0]); } void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor res = neg_exp(to(stack[0])); - stack[0] = from(res); + Tensor res = neg_exp(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -214,10 +214,10 @@ 
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { Tensor divide_neg_exp(Tensor t) { StableIValue stack_neg[1]; - stack_neg[0] = from(t); + stack_neg[0] = torch::stable::detail::from(t); StableIValue stack_exp[1]; - stack_exp[0] = from(t); + stack_exp[0] = torch::stable::detail::from(t); aoti_torch_call_dispatcher("aten::exp", "", stack_exp); aoti_torch_call_dispatcher("aten::neg", "", stack_neg); @@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) { stack_div[0] = stack_neg[0]; stack_div[1] = stack_exp[0]; aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div); - return to(stack_div[0]); + return torch::stable::detail::to(stack_div[0]); } void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor res = divide_neg_exp(to(stack[0])); - stack[0] = from(res); + Tensor res = divide_neg_exp(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) { } void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - bool res = is_contiguous(to(stack[0])); - stack[0] = from(res); + bool res = is_contiguous(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) { } void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_transpose(to(stack[0]), to(stack[1]), to(stack[2])); + auto res = my_transpose(torch::stable::detail::to(stack[0]), torch::stable::detail::to(stack[1]), torch::stable::detail::to(stack[2])); - stack[0] = from(res); + stack[0] = torch::stable::detail::from(res); } Tensor my_empty_like(Tensor t) { @@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) { } void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_empty_like(to(stack[0])); - stack[0] = from(res); + auto res = my_empty_like(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } bool my_is_cpu(Tensor t) { @@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) { void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_is_cpu(to(stack[0])); - stack[0] = from(res); + auto res = my_is_cpu(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor fill_infinity(Tensor t) { @@ -296,8 +296,8 @@ void boxed_fill_infinity( StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = fill_infinity(to(stack[0])); - stack[0] = from(res); + auto res = fill_infinity(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor my_pad(Tensor t) { @@ -310,8 +310,8 @@ void boxed_my_pad( StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_pad(to(stack[0])); - stack[0] = from(res); + auto res = my_pad(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) { @@ -323,11 +323,11 @@ void boxed_my_narrow( uint64_t num_args, uint64_t num_outputs) { auto res = my_narrow( - to(stack[0]), - to(stack[1]), - to(stack[2]), - to(stack[3])); - stack[0] = from(res); + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1]), + torch::stable::detail::to(stack[2]), + 
torch::stable::detail::to(stack[3])); + stack[0] = torch::stable::detail::from(res); } Tensor my_new_empty_dtype_variant(Tensor t) { @@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) { } void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_new_empty_dtype_variant(to(stack[0])); - stack[0] = from(res); + auto res = my_new_empty_dtype_variant(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor my_new_zeros_dtype_variant(Tensor t) { @@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) { } void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_new_zeros_dtype_variant(to(stack[0])); - stack[0] = from(res); + auto res = my_new_zeros_dtype_variant(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) { @@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) { } void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor tensor_res = my_copy_(to(stack[0]), to(stack[1]), to(stack[2])); - stack[0] = from(tensor_res); + Tensor tensor_res = my_copy_(torch::stable::detail::to(stack[0]), torch::stable::detail::to(stack[1]), torch::stable::detail::to(stack[2])); + stack[0] = torch::stable::detail::from(tensor_res); } Tensor my_clone(Tensor t) { @@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) { } void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - Tensor tensor_res = my_clone(to(stack[0])); - stack[0] = from(tensor_res); + Tensor tensor_res = my_clone(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(tensor_res); } @@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) { } void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_zero_(to(stack[0])); - stack[0] = from(res); + auto res = my_zero_(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor my_amax(Tensor t) { @@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) { } void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_amax(to(stack[0])); - stack[0] = from(res); + auto res = my_amax(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } Tensor my_amax_vec(Tensor t) { @@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) { } void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = my_amax_vec(to(stack[0])); - stack[0] = from(res); + auto res = my_amax_vec(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -464,8 +464,8 @@ void boxed_test_default_constructor( StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - bool res = test_default_constructor(to(stack[0])); - stack[0] = from(res); + bool res = test_default_constructor(torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -479,26 +479,26 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { } std::vector my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) { - std::array stack = {from(self), from(other)}; + std::array stack = 
{torch::stable::detail::from(self), torch::stable::detail::from(other)}; aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data()); - return to>(stack[0]); + return torch::stable::detail::to>(stack[0]); } void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - // Why is the following NOT to>(stack[0])? Because calling `to` + // Why is the following NOT torch::stable::detail::to>(stack[0])? Because calling `to` // on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef - // is not owning, so it cannot safely steward the result of the to<>. - auto res = my__foreach_mul(to>(stack[0]), to>(stack[1])); - stack[0] = from(res); + // is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>. + auto res = my__foreach_mul(torch::stable::detail::to>(stack[0]), torch::stable::detail::to>(stack[1])); + stack[0] = torch::stable::detail::from(res); } void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef self, torch::headeronly::HeaderOnlyArrayRef other) { - std::array stack = {from(self), from(other)}; + std::array stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)}; aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data()); } void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - my__foreach_mul_(to>(stack[0]), to>(stack[1])); + my__foreach_mul_(torch::stable::detail::to>(stack[0]), torch::stable::detail::to>(stack[1])); } std::vector make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) { @@ -512,8 +512,8 @@ std::vector make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) { } void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - auto res = make_tensor_clones_and_call_foreach(to(stack[0]), to(stack[1])); - stack[0] = from(res); + auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to(stack[0]), torch::stable::detail::to(stack[1])); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { @@ -550,8 +550,8 @@ void boxed_test_device_guard( StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - int res = test_device_guard(static_cast(to(stack[0]))); - stack[0] = from(res); + int res = test_device_guard(static_cast(torch::stable::detail::to(stack[0]))); + stack[0] = torch::stable::detail::from(res); } int64_t test_device_guard_set_index() { @@ -570,7 +570,7 @@ void boxed_test_device_guard_set_index( uint64_t num_args, uint64_t num_outputs) { int64_t res = test_device_guard_set_index(); - stack[0] = from(res); + stack[0] = torch::stable::detail::from(res); } int64_t test_stream(int32_t device_index) { @@ -586,8 +586,8 @@ void boxed_test_stream( StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { - int64_t res = test_stream(static_cast(to(stack[0]))); - stack[0] = from(res); + int64_t res = test_stream(static_cast(torch::stable::detail::to(stack[0]))); + stack[0] = torch::stable::detail::from(res); } int64_t test_get_current_device_index() { @@ -599,7 +599,7 @@ void boxed_test_get_current_device_index( uint64_t num_args, uint64_t num_outputs) { int64_t res = test_get_current_device_index(); - stack[0] = from(res); + stack[0] = torch::stable::detail::from(res); } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { From 32d30d96cf209809a8ddabe5e219ea915bfe52a3 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Fri, 7 Nov 2025 21:47:59 +0000 
Subject: [PATCH 230/651] [ROCm][CI] unconditionally add gfx950, gfx115x to PYTORCH_ROCM_ARCH (#167299) Included gfx950, gfx1150, and gfx1151 unconditionally in PYTORCH_ROCM_ARCH. Removed the ROCm 7.0 version check and refactored the architecture list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167299 Approved by: https://github.com/jeffdaily --- .ci/docker/almalinux/build.sh | 6 +----- .ci/docker/libtorch/build.sh | 6 +----- .ci/docker/manywheel/build.sh | 6 +----- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/.ci/docker/almalinux/build.sh b/.ci/docker/almalinux/build.sh index 885c4440e0e6f..468f9b06418f7 100755 --- a/.ci/docker/almalinux/build.sh +++ b/.ci/docker/almalinux/build.sh @@ -36,11 +36,7 @@ case ${DOCKER_TAG_PREFIX} in ;; rocm*) BASE_TARGET=rocm - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" - # add gfx950, gfx115x conditionally starting in ROCm 7.0 - if [[ "$ROCM_VERSION" == *"7.0"* ]]; then - PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151" - fi + PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}" ;; *) diff --git a/.ci/docker/libtorch/build.sh b/.ci/docker/libtorch/build.sh index c40896cb5499f..76d3e01e1c38f 100755 --- a/.ci/docker/libtorch/build.sh +++ b/.ci/docker/libtorch/build.sh @@ -49,11 +49,7 @@ case ${DOCKER_TAG_PREFIX} in fi BASE_TARGET=rocm GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" - # add gfx950, gfx115x conditionally starting in ROCm 7.0 - if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then - PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151" - fi + PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}" ;; *) diff --git a/.ci/docker/manywheel/build.sh b/.ci/docker/manywheel/build.sh index ac385ce4b29fd..8f9059dc0cc12 100755 --- a/.ci/docker/manywheel/build.sh +++ b/.ci/docker/manywheel/build.sh @@ -87,11 +87,7 @@ case ${image} in MANY_LINUX_VERSION="2_28" DEVTOOLSET_VERSION="11" GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" - # add gfx950, gfx115x conditionally starting in ROCm 7.0 - if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then - PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151" - fi + PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" ;; manylinux2_28-builder:xpu) From 6392b986e7e72dee831696d0bdf543ca9f8609b8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 22:25:01 +0000 Subject: [PATCH 231/651] Revert "[13/N] Apply ruff UP035 rule (#167048)" This reverts commit ea44f12bce3eb05eaa9fa34943a3ffae04647fa5. 
Reverted https://github.com/pytorch/pytorch/pull/167048 on behalf of https://github.com/donigian due to breaking internal tests D86342860 ([comment](https://github.com/pytorch/pytorch/pull/167048#issuecomment-3505232522)) --- test/dynamo/test_install_free_tensors.py | 4 ++-- test/dynamo/test_python_autograd.py | 6 +----- test/typing/pass/arithmetic_ops.py | 4 ++-- torch/_C/_distributed_c10d.pyi | 3 +-- torch/_dynamo/variables/ctx_manager.py | 4 ++-- torch/_inductor/codegen/pallas.py | 4 +--- torch/_inductor/runtime/caching/config.py | 2 +- torch/distributed/_local_tensor/_c10d.py | 3 ++- 8 files changed, 12 insertions(+), 18 deletions(-) diff --git a/test/dynamo/test_install_free_tensors.py b/test/dynamo/test_install_free_tensors.py index fd9e14c4c3f76..3858b827bd598 100644 --- a/test/dynamo/test_install_free_tensors.py +++ b/test/dynamo/test_install_free_tensors.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] import unittest -from collections.abc import Callable, Sequence -from typing import Any, Union +from collections.abc import Sequence +from typing import Any, Callable, Union import torch import torch._dynamo diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index a6117bb4093a7..a615c653f56c3 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,5 +1,5 @@ # Owner(s): ["module: dynamo"] -from typing import NamedTuple, Optional, TYPE_CHECKING +from typing import Callable, NamedTuple, Optional import torch import torch._dynamo @@ -7,10 +7,6 @@ from torch._dynamo.testing import CompileCounter, same -if TYPE_CHECKING: - from collections.abc import Callable - - """ This is an example of a pure-python version of autograd implemented by @zdevito. It represents a rather challenging test case for TorchDynamo diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index 14dda1cf39772..f0d6cc6fd9f97 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,5 +1,5 @@ -from typing import TypeAlias, Union -from typing_extensions import assert_type +from typing import Union +from typing_extensions import assert_type, TypeAlias from torch import randn, Tensor diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index b659be9ee119e..f3d96860f5584 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,9 +1,8 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" -from collections.abc import Callable from datetime import timedelta from enum import Enum -from typing import Any, Optional, overload, Union +from typing import Any, Callable, Optional, overload, Union import torch from torch import Tensor diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 318d0e91a0700..b019296d98fcd 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -21,9 +21,9 @@ import inspect import sys import warnings -from collections.abc import Callable, Sequence, Sized +from collections.abc import Callable, Sequence from contextlib import ExitStack -from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union import torch._C from torch._guards import Guard diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 8587368407323..6ee901d19b14f 100644 --- a/torch/_inductor/codegen/pallas.py +++ 
b/torch/_inductor/codegen/pallas.py @@ -2,7 +2,7 @@ from __future__ import annotations import hashlib -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING import sympy # noqa: TC002 @@ -17,8 +17,6 @@ if TYPE_CHECKING: - from collections.abc import Callable, Sequence - from ..ir import IRNode from ..scheduler import BaseSchedulerNode diff --git a/torch/_inductor/runtime/caching/config.py b/torch/_inductor/runtime/caching/config.py index 14e13f937dbb7..748715d1631ad 100644 --- a/torch/_inductor/runtime/caching/config.py +++ b/torch/_inductor/runtime/caching/config.py @@ -1,6 +1,6 @@ import os -from collections.abc import Callable from functools import cache, partial +from typing import Callable import torch from torch._environment import is_fbcode diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 0b63330dfafce..c9256543e8977 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -1,8 +1,9 @@ import functools import math import operator -from collections.abc import Callable, Sequence +from collections.abc import Sequence from datetime import timedelta +from typing import Callable import torch from torch._C import ScriptObject From bbf852d87ff527a5cdd9b9ca999356062eadf575 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 7 Nov 2025 22:32:44 +0000 Subject: [PATCH 232/651] Revert "Remove python workaround for ContextDecorator (#167049)" This reverts commit 13d2cc7bd26e32cafff0377dda1c5ddc8d04c4ce. Reverted https://github.com/pytorch/pytorch/pull/167049 on behalf of https://github.com/donigian due to breaking internal tests D86342845 ([comment](https://github.com/pytorch/pytorch/pull/167049#issuecomment-3505251296)) --- torch/autograd/profiler.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 9e2a7b5046dee..fa43af2701171 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -52,7 +52,26 @@ "MemRecordsAcc", ] -from contextlib import ContextDecorator +try: + # Available in Python >= 3.2 + from contextlib import ContextDecorator as _ContextDecorator +except ImportError: + import functools + + class _ContextDecorator: # type: ignore[no-redef] + def __enter__(self): + raise NotImplementedError + + def __exit__(self, exc_type, exc_val, exc_tb): + raise NotImplementedError + + def __call__(self, func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapped # global python state - whether profiler is currently enabled @@ -725,7 +744,8 @@ def createFunctionEventForMemoryEvents(evt): return all_function_events -class record_function(ContextDecorator): +# pyrefly: ignore [invalid-inheritance] +class record_function(_ContextDecorator): """Context manager/function decorator that adds a label to a code block/function when running autograd profiler. Label will only appear if CPU activity tracing is enabled. 
From ea6b0b5d0fb3a0a223b1070197bb57bde2e0e564 Mon Sep 17 00:00:00 2001 From: Sumantro Mukherjee Date: Fri, 7 Nov 2025 23:00:04 +0000 Subject: [PATCH 233/651] add missing cpp standard lib in HeaderOnlyArrayRef.h (#167337) Fixes #167315 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167337 Approved by: https://github.com/janeyx99 --- torch/headeronly/util/HeaderOnlyArrayRef.h | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/headeronly/util/HeaderOnlyArrayRef.h b/torch/headeronly/util/HeaderOnlyArrayRef.h index 2387578ab8f5f..751ffef32bb1d 100644 --- a/torch/headeronly/util/HeaderOnlyArrayRef.h +++ b/torch/headeronly/util/HeaderOnlyArrayRef.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include From 09705ca9b21eb4b98c0f77f9703ac78f332134fb Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 7 Nov 2025 11:46:14 -0800 Subject: [PATCH 234/651] [dynamo][guards] Fix mem leak in tensor subclass metadata guard (#167352) Use cls instead of the object. Earlier the metadata guard was holding on to the Dtensor causing mem leak. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167352 Approved by: https://github.com/Skylion007 --- torch/_dynamo/guards.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 2792ce512d8a1..67995e93bab77 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2141,9 +2141,10 @@ def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None: original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1]) if hasattr(value, "__metadata_guard__"): verify_guard_fn_signature(value) + cls = type(value) def metadata_checker(x: Any) -> bool: - return value.__metadata_guard__( + return cls.__metadata_guard__( original_metadata, x.__tensor_flatten__()[1] ) From c7007e758478fcac4ed9bb0479d73d6e397e8b8a Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Fri, 7 Nov 2025 23:06:56 +0000 Subject: [PATCH 235/651] Update Kineto Submodule (#167343) Summary: Title Test Plan: CI Differential Revision: D86538778 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167343 Approved by: https://github.com/Skylion007, https://github.com/aaronenyeshi --- third_party/kineto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/kineto b/third_party/kineto index 6fcbc53d33dd2..57c561f4ca89e 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 6fcbc53d33dd275c0aba1e5d7701d471b7f6eeb3 +Subproject commit 57c561f4ca89ecd8dec1ce4fa1fa60e7cfdb555b From 5062abe4e7e0896ddfce2a5b7efe0136bb6df193 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 7 Nov 2025 12:20:16 -0800 Subject: [PATCH 236/651] [CI][serialization] Fix exception regexes with Python-3.14 (#167333) Not sure why, but running some tests (for example `test_weights_only_safe_globals_build`) with `pytest` in 3.14 makes global name `test_serialization.ClassThatUsesBuildInstruction` instead of expected `__main__.ClassThatUsesBuildInstruction` Also, change expected exception type from `AttributeError` to `PicklingError` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167333 Approved by: https://github.com/atalman --- test/test_serialization.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 20f74b6dc6a21..d292c13993cfe 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1281,7 +1281,7 @@ def 
test_weights_only_safe_globals_newobj(self): torch.save(p, f) f.seek(0) with self.assertRaisesRegex(pickle.UnpicklingError, - "GLOBAL __main__.Point was not an allowed global by default"): + f"GLOBAL {__name__}.Point was not an allowed global by default"): torch.load(f, weights_only=True) f.seek(0) with torch.serialization.safe_globals([Point]): @@ -1300,7 +1300,7 @@ def fake_set_state(obj, *args): torch.save(c, f) f.seek(0) with self.assertRaisesRegex(pickle.UnpicklingError, - "GLOBAL __main__.ClassThatUsesBuildInstruction was not an allowed global by default"): + f"GLOBAL {__name__}.ClassThatUsesBuildInstruction was not an allowed global by default"): torch.load(f, weights_only=True) try: with torch.serialization.safe_globals([ClassThatUsesBuildInstruction]): @@ -1330,7 +1330,7 @@ def test_weights_only_safe_globals_build_with_slots(self, slots): torch.save(obj, f) f.seek(0) with self.assertRaisesRegex(pickle.UnpicklingError, - f"GLOBAL __main__.{obj_cls.__name__} was not an allowed global by default"): + f"GLOBAL {__name__}.{obj_cls.__name__} was not an allowed global by default"): torch.load(f, weights_only=True) f.seek(0) @@ -4501,9 +4501,10 @@ def fn(t): # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx if not materialize_fake: ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) + exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError with self.assertRaisesRegex( - AttributeError, - "Can't (get|pickle) local object 'WeakValueDictionary.__init__..remove'" + exc, + "Can't (get|pickle) local object (.remove" ): with skip_data(), BytesIOContext() as f: torch.save(ft, f) From 69ecb562e734c0729e7e0581565312fc6487e682 Mon Sep 17 00:00:00 2001 From: Malay Bag Date: Fri, 7 Nov 2025 23:27:57 +0000 Subject: [PATCH 237/651] [PT2 Compiler] Add annotation for dynamo disabled callables (#166341) Summary: To make torch.export compatible with PT2 compile (which is done on top of exported model) we need to store torch._dynamo.disable attributes in exported model and later restore this after unflattening of exported model. This diff will add annotations to all nodes with torch._dynamo.disable, which will be preserved during exporting. 
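Note: a minimal sketch of how the metadata written by the `fx_traceback.annotate(...)` call in eval_frame.py below can be read back from an exported graph (exported under `torch.fx.traceback.preserve_node_meta()`, as in the test plan). The metadata keys come from this diff; the helper name `collect_disabled_methods` is illustrative only and is not an API added by this patch.

```python
import torch


def collect_disabled_methods(exported_program):
    # Scan the exported graph for the custom metadata stamped on nodes that
    # were traced under a torch._dynamo.disable()'d callable.
    disabled = set()
    for node in exported_program.graph_module.graph.nodes:
        custom = node.meta.get("custom") or {}
        if custom.get("_torchdynamo_disable"):
            disabled.add(custom.get("_torchdynamo_disable_method"))
    return disabled
```
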
Test Plan: ``` buck test mode/opt caffe2/test:test_export -- 'test_dynamo_disable_annotations' ``` https://www.internalfb.com/intern/testinfra/testrun/6473924770741560 Differential Revision: D85302730 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166341 Approved by: https://github.com/williamwen42, https://github.com/angelayi --- test/dynamo/test_decorators.py | 37 +++++++++++++++++++++++++++++++++ test/export/test_export.py | 13 +++++++----- torch/_dynamo/eval_frame.py | 13 +++++++++++- torch/_dynamo/external_utils.py | 4 ++++ torch/_export/utils.py | 14 +++++++++++-- 5 files changed, 73 insertions(+), 8 deletions(-) diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 68a10360284dc..09936044bd450 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -2155,6 +2155,43 @@ def forward(self, inp): torch.compile(model) torch.compile(other_model) + def test_dynamo_disable_annotations(self): + class SimpleModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("buffer", torch.rand(2, 2)) + + @torch._dynamo.disable() + def f1(self, x) -> torch.Tensor: + return x + self.buffer + 1 + + @torch._dynamo.disable() + def f2(self, x) -> torch.Tensor: + return x + self.buffer + 2 + + def forward(self, x) -> torch.Tensor: + return self.f1(x) + self.f2(x) + + model = SimpleModel() + inp = torch.rand(2, 2) + with torch.fx.traceback.preserve_node_meta(): + exported_model = torch.export.export(model, (inp,)) + graph = exported_model.graph_module.graph + found_f1 = False + found_f2 = False + for node in graph.nodes: + if "custom" in node.meta: + if "_torchdynamo_disable_method" in node.meta["custom"]: + if node.meta["custom"]["_torchdynamo_disable_method"] == "f1": + found_f1 = True + elif node.meta["custom"]["_torchdynamo_disable_method"] == "f2": + found_f2 = True + self.assertTrue(found_f1) + self.assertTrue(found_f2) + model.forward = torch._dynamo.disable(model.forward, recursive=False) + with self.assertRaises(RuntimeError): + exported_model = torch.export.export(model, (inp,)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/export/test_export.py b/test/export/test_export.py index 25f1cec03bd7c..c7848eb3d69de 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -742,11 +742,14 @@ def forward(self, x, y): self.assertExpectedInline( str(custom_metadata), """\ -('call_function', 'cat', {'moo': 0}) -('call_function', 'item', {'moo': 0}) -('call_function', 'ge_1', {'moo': 0}) -('call_function', '_assert_scalar_default', {'moo': 0}) -('call_function', 'mul', {'moo': 0})""", +('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'}) +('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'}) +('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0}) +('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0}) +('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0}) +('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': 
True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0}) +('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0}) +('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""", ) @requires_gpu diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e93e7ace7395e..9c9076f5a99c0 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -78,7 +78,7 @@ _RelaxedConstraint, Constraint, ) -from torch.fx import GraphModule +from torch.fx import GraphModule, traceback as fx_traceback from torch.fx.experimental._dynamism import ( clone_and_convert_to_meta, track_dynamism_across_examples, @@ -1134,6 +1134,17 @@ def _fn(*args: Any, **kwargs: Any) -> Any: try: _maybe_set_eval_frame(_callback_from_stance(self.callback)) try: + if torch.compiler.is_exporting(): + with fx_traceback.annotate( + { + "_torchdynamo_disable": True, + "_torchdynamo_disable_recursive": True, + "_torchdynamo_disable_method": getattr( + fn, "__name__", type(fn).__name__ + ), + } + ): + return fn(*args, **kwargs) return fn(*args, **kwargs) finally: set_eval_frame(None) diff --git a/torch/_dynamo/external_utils.py b/torch/_dynamo/external_utils.py index fd21f57d8b865..10422a3e2b82b 100644 --- a/torch/_dynamo/external_utils.py +++ b/torch/_dynamo/external_utils.py @@ -196,6 +196,10 @@ def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]: # this function is in external_utils so that convert_frame doesn't skip it. @functools.wraps(fn) def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + if torch.compiler.is_exporting(): + raise RuntimeError( + "Non-recursive torch.compiler.disable is not supported with torch.export." + ) return fn(*args, **kwargs) return nonrecursive_disable_wrapper diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 648e32758e5fa..74230e4a5ed55 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -1097,9 +1097,19 @@ def _extract_pytree_key(x): node.name = node.target = name_map[node.name] if node.name in custom_meta: if node.meta.get("custom") is None: - node.meta["custom"] = custom_meta[node.name] + node.meta["custom"] = {} else: - assert node.meta["custom"] == custom_meta[node.name] + # Assert if any existing key has different value + for k, v in node.meta["custom"].items(): + if ( + k in custom_meta[node.name] + and v != custom_meta[node.name][k] + ): + raise AssertionError( + f"Mismatch in custom metadata for key {k}. Value in " + f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}." + ) + node.meta["custom"].update(custom_meta[node.name]) # if the constant obj is an input, we also need to update meta["val"] # because this is created before the placeholder naming pass if isinstance(node.meta["val"], CustomObjArgument): From 70f5f55abfa5853165cf27f15d557e5c56978ec5 Mon Sep 17 00:00:00 2001 From: Blaine Burton Rister <145300525+blaine-rister@users.noreply.github.com> Date: Fri, 7 Nov 2025 23:48:51 +0000 Subject: [PATCH 238/651] [Inductor-FX] Allocate tensors on device type instead of indexed device (#167358) # Problem The FX backend currently allocates tensors on an exact device index, such as `"cuda:0"`. In contrast, the Python backend allocates on a device type, such as `"cuda"`. This avoids edge cases where fake tensor propagation can fail due to mismatched devices. 
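For context, the distinction in plain `torch.device` terms (ordinary PyTorch behavior, not something introduced by this patch):

```python
import torch

indexed = torch.device("cuda", 0)     # an indexed device, as the sample inputs may carry
print(indexed.type, indexed.index)    # "cuda" 0

by_type = torch.device(indexed.type)  # the device type alone
print(by_type.index)                  # None -- no pinned index
```
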
# Fix Allocate tensors on `device.type` instead of the device. # Test plan Added a CI test passing in sample inputs on an indexed device, and checking that the output device in the generated FX graph is not indexed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167358 Approved by: https://github.com/mlazos, https://github.com/nandesuka, https://github.com/eellison --- test/inductor/test_fxir_backend.py | 18 ++++++++++++++++++ torch/_inductor/codegen/wrapper_fxir.py | 3 ++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index 3ebc7f04b3887..f9e84284f0d8d 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -148,6 +148,24 @@ def test_basic(self): args = [torch.randn(8, device=self.device) for _ in range(2)] self._compile_and_check(torch.add, args) + def test_device_type(self): + """ + Test that we allocate on a device type instead of a specific index. + """ + # Pass in a tensor on an indexed device. + device_runtime = getattr(torch, self.device) + indexed_device = torch.device(self.device, device_runtime.current_device()) + args = [torch.randn(8, device=indexed_device) for _ in range(2)] + (gm,) = self._compile_and_check(torch.add, args) + (empty_strided,) = gm.graph.find_nodes( + op="call_function", target=torch.empty_strided + ) + + # Check that the device of the output allocation is not indexed. + output_device = torch.device(empty_strided.kwargs["device"]) + self.assertIs(output_device.index, None) + self.assertEqual(output_device.type, indexed_device.type) + def test_multiple_kernels(self): def foo(x, y): return x.sum() + y.sum() diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py index 0659bee04d689..02c498d6debce 100644 --- a/torch/_inductor/codegen/wrapper_fxir.py +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -678,6 +678,7 @@ def _generate_allocate(self, line: WrapperLine) -> None: assert name not in V.graph.removed_buffers device = buffer.get_device() + assert device dtype = buffer.get_dtype() shape = self._generate_sym_nodes(buffer.get_size()) stride = self._generate_sym_nodes(buffer.get_stride()) @@ -685,7 +686,7 @@ def _generate_allocate(self, line: WrapperLine) -> None: node = self.gm.graph.call_function( torch.empty_strided, args=(shape, stride), - kwargs={"dtype": dtype, "device": device}, + kwargs={"dtype": dtype, "device": device.type}, ) assert name node.name = name From fbc0bd2e90f9a92802e9704613f92d99d6bccff1 Mon Sep 17 00:00:00 2001 From: Anshul Sinha Date: Thu, 6 Nov 2025 13:29:34 -0800 Subject: [PATCH 239/651] [DTensor][be] getting rid of unneccesary Partial check for norm functions (#167247) **Summary:** While the implementation is correct, these checks are just a subset of the Partial placement checks that are done in https://github.com/pytorch/pytorch/pull/165962. This means for ops aten.linalg_vector_norm.default and aten._foreach_norm.Scalar, we're unnecessarily checking for Partial placements twice. **Test Cases** 1. pytest test/distributed/tensor/test_math_ops.py -k test_vector_norm_partial 2. pytest test/distributed/tensor/test_math_ops.py -k test_foreach_norm_partial 3. 
pytest test/distributed/tensor/test_math_ops.py -k test_partial_reduction_ops Pull Request resolved: https://github.com/pytorch/pytorch/pull/167247 Approved by: https://github.com/XilunWu --- torch/distributed/tensor/_ops/_math_ops.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 45a786b9058e2..545895c83b6eb 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -441,15 +441,11 @@ def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy: keepdim = args_schema[3] if len(args_schema) > 3 else False dims = _infer_reduction_dims(dim, input_strategy.ndim) reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims - reduction_linear = all( - all(not p.is_partial() for p in op_spec.output_spec.placements) - for op_spec in input_strategy.strategies - ) return common_reduction_strategy( input_strategy, reduce_dims, keep_dim=cast(bool, keepdim), - reduction_linear=reduction_linear, + reduction_linear=True, reduction_op=NormReduction(norm_type), ) @@ -472,14 +468,10 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy: if not isinstance(op_strategy, OpStrategy): raise AssertionError(f"Expected OpStrategy, got {type(op_strategy)}") reduce_dims = list(range(op_strategy.ndim)) - reduction_linear = all( - all(not p.is_partial() for p in op_spec.output_spec.placements) - for op_spec in op_strategy.strategies - ) output_strategy = common_reduction_strategy( op_strategy, reduce_dims, - reduction_linear=reduction_linear, + reduction_linear=True, reduction_op=NormReduction(norm_type), ) output_tuple_strategy_children.append(output_strategy) From d865156967db65625bfe1d5474bab377713d41cc Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 7 Nov 2025 15:21:33 -0800 Subject: [PATCH 240/651] [dynamo][hops] Overwrite proxy of the original VT to the subgraph outputs (#167160) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167160 Approved by: https://github.com/zou3519 --- test/dynamo/test_activation_checkpointing.py | 55 +++++++++++++++++++ torch/_dynamo/variables/higher_order_ops.py | 58 +++++++++++++++++--- 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index df5906c091ad3..d6c0feac19ae1 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1672,6 +1672,61 @@ def fn(x): # The mutation is not reapplied in the backward because the flag was on. 
self.assertEqual(counter, 1) + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_nonlocal_list_mutation(self): + def gn(x, z): + out = x.sin() + z.append(out) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out + + def fn(x): + z = [] + + out1, out2 = torch.utils.checkpoint.checkpoint( + gn, + x, + z, + use_reentrant=False, + ) + + return out1, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + + @unittest.expectedFailure + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_nonlocal_list_mutation_hidden(self): + def gn(x, z): + out = x.sin() + z.append(out) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)) + + def fn(x): + z = [] + + out1 = torch.utils.checkpoint.checkpoint( + gn, + x, + z, + use_reentrant=False, + ) + + return out1, z[0] + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + devices = ["cuda", "hpu"] instantiate_device_type_tests( diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 15f88f45bf7c5..dc7a7d13908a8 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -247,7 +247,7 @@ def inline_call(*args, **kwargs): def _call_function_and_unflatten_output( - tx, fn, args, kwargs, flat_example_value, ret_spec + tx, fn, args, kwargs, flat_example_value, ret_spec, body_r ): from .builder import wrap_fx_proxy @@ -263,6 +263,21 @@ def _call_function_and_unflatten_output( example_value=flat_example_value, ) + # wrap_fx_proxy creates fresh variable trackers. However, the main program + # after the speculate subgraph can still use the original tensor vts that + # are still pointing to the nodes present in the subgraph. So, we reproxify + # the original tensor vts with the subgraph outputs. This way, whenever the + # outer graph uses an original vt, it uses the subgraph output. 
+ if body_r is not None: + for orig_vt, subgraph_vt in zip(body_r.items, flat_variable.items): + if isinstance( + orig_vt, (variables.SymNodeVariable, variables.TensorVariable) + ): + assert isinstance( + subgraph_vt, (variables.SymNodeVariable, variables.TensorVariable) + ) + orig_vt.proxy = subgraph_vt.proxy + if ret_spec.masks_to_filter_const_values: from torch._dynamo.external_utils import insert_const_values_with_mask @@ -572,6 +587,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: {}, None, body_treespec, + body_r, ) @@ -1535,6 +1551,7 @@ def speculate_branch(branch): {}, None, true_spec, + true_r, ) @@ -1858,6 +1875,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): {}, None, OutputSpec(xs_treespec), + None, ) @@ -2090,7 +2108,13 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) return _call_function_and_unflatten_output( - tx, torch.ops.higher_order.scan, p_args, {}, None, _combine_spec + tx, + torch.ops.higher_order.scan, + p_args, + {}, + None, + _combine_spec, + None, ) @@ -2213,7 +2237,7 @@ def _call_function( ) return _call_function_and_unflatten_output( - tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec + tx, torch.ops.higher_order.map_impl, p_args, {}, None, body_spec, body_r ) @@ -2419,7 +2443,13 @@ def _call_function( ) return _call_function_and_unflatten_output( - tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec + tx, + self.value, + tuple(p_args), + p_kwargs, + flat_example_value, + treespec, + body_r, ) @@ -2506,7 +2536,7 @@ def call_function( body_r.as_proxy(), ) return _call_function_and_unflatten_output( - tx, self.value, proxy_args, {}, example_value, treespec + tx, self.value, proxy_args, {}, example_value, treespec, body_r ) @@ -2601,7 +2631,7 @@ def call_function( ) return _call_function_and_unflatten_output( - tx, self.value, proxy_args, {}, example_value, treespec + tx, self.value, proxy_args, {}, example_value, treespec, body_r ) @@ -2674,7 +2704,7 @@ def _call_function( ) return _call_function_and_unflatten_output( - tx, self.value, p_args, p_kwargs, flat_example_value, treespec + tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r ) @@ -2793,6 +2823,7 @@ def _call_function( {}, flat_example_value, ret_spec, + ret_val, ) @@ -2860,6 +2891,7 @@ def _call_function( checkpoint_kwargs, example_value, out_spec, + _body_r, ) @@ -2913,6 +2945,7 @@ def _call_function( {}, example_value, out_spec, + _body_r, ) @@ -3652,7 +3685,13 @@ def _call_function( p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} return _call_function_and_unflatten_output( - tx, self.value, p_args, p_kwargs, flat_example_value, treespec + tx, + self.value, + p_args, + p_kwargs, + flat_example_value, + treespec, + body_r, ) @@ -3768,6 +3807,7 @@ def _call_function( p_kwargs, flat_example_value, treespec, + body_r, ) @@ -3991,7 +4031,7 @@ def make_error_msg(*args): # Step 5: Install local_map subgraph p_kwargs = {key: value.as_proxy() for key, value in kwargs.items()} out = _call_function_and_unflatten_output( - tx, self.value, p_args, p_kwargs, flat_example_value, treespec + tx, self.value, p_args, p_kwargs, flat_example_value, treespec, body_r ) # Step 6: Restore inputs and outputs to global shapes From 2325c511e7c4150c0f127485dbde152caf0cf2db Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 7 Nov 2025 15:21:34 -0800 Subject: [PATCH 241/651] [dynamo] Make sym node vt creation via SymNodeVariable create (#167189) This will help in the next PRs. 
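Concretely, the hunks below route direct `SymNodeVariable(proxy, sym_num, ...)` construction through the `SymNodeVariable.create(tx, proxy, sym_num, ...)` classmethod, giving the factory access to the current translator `tx`, which is presumably what the follow-up PRs build on. A toy sketch of the shape of that change, with assumed, simplified names rather than the real dynamo internals:

```python
class SymNodeVariable:
    def __init__(self, proxy, sym_num):
        self.proxy = proxy
        self.sym_num = sym_num

    @classmethod
    def create(cls, tx, proxy, sym_num):
        # The factory can consult `tx` (e.g. to record the produced symint or
        # attach an example value) before handing back the variable tracker,
        # which a bare constructor call at each call site cannot do.
        return cls(proxy, sym_num)
```
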
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167189 Approved by: https://github.com/williamwen42, https://github.com/zou3519 ghstack dependencies: #167160 --- torch/_dynamo/variables/builder.py | 6 +++--- torch/_dynamo/variables/constant.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9733bc946c308..d54d586cd08fc 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1169,7 +1169,7 @@ def build_key_value(i, k, v): f"{sym_expr} is not a basic Symbol." ) self.tx.output.tracked_fakes.append(TrackedFake(node, source, None)) - return SymNodeVariable(sym_node_proxy, node) + return SymNodeVariable.create(self.tx, sym_node_proxy, node) elif is_torch_sym(value): # Note: this doesn't handle nested symints. # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo. @@ -2454,7 +2454,7 @@ def wrap_symint( sym_expr = wrapped_value.node.expr assert isinstance(sym_expr, sympy.Symbol), f"{sym_expr} is not a basic Symbol." self.tx.output.root_tracer.bound_symbols[sym_expr] = proxy - unspec_var = SymNodeVariable(proxy, wrapped_value, **options) + unspec_var = SymNodeVariable.create(self.tx, proxy, wrapped_value, **options) self.tx.output.unspec_variable_map[self.name] = unspec_var if not is_constant_source(self.get_source()): @@ -3002,7 +3002,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): tx.output.current_tracer.track_produced_symints(example_value, proxy) set_example_value(proxy.node, example_value) - return SymNodeVariable(proxy, example_value, **options) + return SymNodeVariable.create(tx, proxy, example_value, **options) elif ( isinstance(example_value, torch.Stream) and proxy.node.target is get_external_object_by_index diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 86d3d87e1f8be..73d53cbca402b 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -182,9 +182,9 @@ def call_method( if any(isinstance(x, SymNodeVariable) for x in args): # Promote to SymNodeVariable for operations involving dynamic shapes. 
- return variables.SymNodeVariable(self.as_proxy(), self.value).call_method( - tx, name, args, kwargs - ) + return variables.SymNodeVariable.create( + tx, self.as_proxy(), self.value + ).call_method(tx, name, args, kwargs) try: const_args = [a.as_python_constant() for a in args] From 31d6d3ef5cd85ea326090079abfae6c2e66021b2 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Fri, 7 Nov 2025 13:28:44 -0800 Subject: [PATCH 242/651] [easy] Add new torch/csrc/stable/c/shim.h to existing nitpick (#167367) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167367 Approved by: https://github.com/janeyx99, https://github.com/malfet --- .github/nitpicks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/nitpicks.yml b/.github/nitpicks.yml index 1d08a36abf1d5..e3fe5d4725587 100644 --- a/.github/nitpicks.yml +++ b/.github/nitpicks.yml @@ -10,3 +10,4 @@ pathFilter: - 'torch/csrc/inductor/aoti_torch/c/*' - 'torch/csrc/inductor/aoti_torch/generated/*' + - 'torch/csrc/stable/c/*' From 4957ae5838d0ca1506c68b4c243f2fadfb0d89d2 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 7 Nov 2025 11:56:49 -0800 Subject: [PATCH 243/651] Add API to annotate disjoint backward and handle in AC (#166536) This adds zero-bubble / DualPipeV support for (S)AC Before: - AC will always retrigger recompute upon every distinct backward. After: - Any checkpointed regions encountered by backward under the same instance of this context manager will only trigger recompute at most once, even if there are multiple calls to backward. - Backward calls under the same instance of this context manager must execute over non-overlapping regions of the backward graph even if retain_graph=True. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166536 Approved by: https://github.com/albanD --- c10/core/AutogradState.h | 14 +++++- test/inductor/test_compiled_autograd.py | 1 + test/test_autograd.py | 56 +++++++++++++++++++++++ torch/_C/__init__.pyi.in | 3 ++ torch/csrc/autograd/init.cpp | 29 ++++++++++++ torch/utils/checkpoint.py | 60 ++++++++++++++++++++++--- 6 files changed, 155 insertions(+), 8 deletions(-) diff --git a/c10/core/AutogradState.h b/c10/core/AutogradState.h index ad168b8c05987..d2b9cc080413d 100644 --- a/c10/core/AutogradState.h +++ b/c10/core/AutogradState.h @@ -1,6 +1,8 @@ #pragma once +#include #include +#include namespace c10 { @@ -15,7 +17,8 @@ struct C10_API AutogradState { bool inference_mode, bool fw_grad_mode, bool multithreading_enabled) - : grad_mode_(grad_mode), + : graph_exec_group_(std::nullopt), + grad_mode_(grad_mode), inference_mode_(inference_mode), fw_grad_mode_(fw_grad_mode), multithreading_enabled_(multithreading_enabled), @@ -41,6 +44,10 @@ struct C10_API AutogradState { view_replay_enabled_ = view_replay_enabled; } + void set_graph_exec_group(std::optional group) { + graph_exec_group_ = std::move(group); + } + bool get_grad_mode() const { return grad_mode_; } @@ -61,7 +68,12 @@ struct C10_API AutogradState { return view_replay_enabled_; } + const std::optional& get_graph_exec_group() const { + return graph_exec_group_; + } + private: + std::optional graph_exec_group_; bool grad_mode_ : 1; bool inference_mode_ : 1; bool fw_grad_mode_ : 1; diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 3001f86f4cfce..bfd789136d70d 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -5222,6 +5222,7 @@ def wrap_test_class(orig_cls): "test_reentrant_with_callbacks_both_depths", # 
queue_callback "test_reentrant_with_callbacks_depth_0", # queue_callback "test_reentrant_with_callbacks_depth_1", # queue_callback + "test_checkpoint_graph_execution_group", # Attempted to call function marked as skipped "test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd "test_post_accumulate_grad_hook_ordering", # accuracy error diff --git a/test/test_autograd.py b/test/test_autograd.py index 6c3e250df7c7c..4926697d1d1be 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -7364,6 +7364,62 @@ def test_checkpoint_sequential_warns_if_use_reentrant_not_passed_explcitly(self) ): checkpoint_sequential(modules_list, 3, a) + @skipIfTorchDynamo("GraphExecGroup does not support compile") + def test_checkpoint_graph_execution_group(self): + def run(use_graph_execution_group): + counter = [0] + + def fn(x): + counter[0] += 1 + y = x.sin().cos() + z = y.sin().cos() + return y, z + + x = torch.randn(3, 3, requires_grad=True) + + y, z = checkpoint(fn, x, use_reentrant=False) + + group = torch.utils.checkpoint.GraphExecGroup() + + ctx = contextlib.nullcontext() + if use_graph_execution_group: + ctx = group + + with ctx: + (grad_y,) = torch.autograd.grad( + z, inputs=(y,), grad_outputs=(torch.ones(3, 3),) + ) + + (grad_x,) = torch.autograd.grad( + y, + inputs=(x,), + grad_outputs=(grad_y,), + ) + + if use_graph_execution_group: + self.assertEqual(counter[0], 2) + else: + self.assertEqual(counter[0], 3) + + run(use_graph_execution_group=True) + run(use_graph_execution_group=False) + + # Test the not actually disjoint case (using retain_graph=True since + # otherwise autograd itself will catch this) + def fn(x): + return x.sin().cos() + + x = torch.randn(3, 3, requires_grad=True) + out = checkpoint(fn, x, use_reentrant=False) + with torch.utils.checkpoint.GraphExecGroup(): + # Under this context, we will enforce that two backward are disjoint + # even if retain_graph=True. + out.sum().backward(retain_graph=True) + with self.assertRaisesRegex( + RuntimeError, "Performing two backward calls that overlap" + ): + out.sum().backward() + def test_checkpoint_detects_non_determinism(self): def save_3_tensors(x): out = x.sin().exp() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 4acffdb1997f9..16d71cd0abb2e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -69,6 +69,7 @@ from torch.types import ( Storage, ) from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.checkpoint import GraphExecGroup # This module is defined in torch/csrc/Module.cpp @@ -1491,6 +1492,8 @@ def _is_multithreading_enabled() -> _bool: ... def _set_multithreading_enabled(enabled: _bool) -> None: ... def _set_view_replay_enabled(enabled: _bool) -> None: ... def _is_view_replay_enabled() -> _bool: ... +def _set_graph_exec_group(group: GraphExecGroup | None) -> None: ... +def _get_graph_exec_group() -> GraphExecGroup | None: ... def _enter_dual_level() -> _int: ... def _exit_dual_level(level: _int) -> None: ... def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ... 
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 7cfb935942046..a13cc70270ccb 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -1218,6 +1218,33 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) { END_HANDLE_TH_ERRORS } +static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) { + HANDLE_TH_ERRORS + if (obj == Py_None) { + c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt); + } else { + Py_INCREF(obj); + c10::AutogradState::get_tls_state().set_graph_exec_group( + c10::SafePyObject(obj, getPyInterpreter())); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* get_graph_exec_group(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + const auto& group = + c10::AutogradState::get_tls_state().get_graph_exec_group(); + if (group.has_value()) { + PyObject* obj = group->ptr(getPyInterpreter()); + Py_INCREF(obj); + return obj; + } else { + Py_RETURN_NONE; + } + END_HANDLE_TH_ERRORS +} + static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS if (c10::InferenceMode::is_enabled()) { @@ -1598,6 +1625,8 @@ static PyMethodDef methods[] = { castPyCFunctionWithKeywords(set_view_replay_enabled), METH_VARARGS | METH_KEYWORDS, nullptr}, + {"_set_graph_exec_group", set_graph_exec_group, METH_O, nullptr}, + {"_get_graph_exec_group", get_graph_exec_group, METH_NOARGS, nullptr}, {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr}, {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 9b10c4d192d4e..b74e4d01da060 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -33,6 +33,7 @@ "SelectiveCheckpointContext", "create_selective_checkpoint_contexts", "SAC_IGNORED_OPS", + "GraphExecGroup", ] _DEFAULT_DETERMINISM_MODE = "default" @@ -1072,7 +1073,7 @@ class _StopRecomputationError(Exception): class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): - def __init__(self, target_frame_ref: ReferenceType, gid: int) -> None: + def __init__(self, target_frame_ref: ReferenceType, gid: Union["GraphExecGroup", int]) -> None: def pack_hook(x): x = x.detach() if x.requires_grad else x target_frame = target_frame_ref() @@ -1145,10 +1146,14 @@ def pack_hook(x): return holder def unpack_hook(holder): - gid = torch._C._current_graph_task_id() - if gid == -1: - # generate a temporary id if we trigger unpack outside of a backward call - gid = int(uuid.uuid4()) + # First check if we're inside a GraphExecGroup context + gid: Union[GraphExecGroup, None, int] = GraphExecGroup._get_current_group() + if gid is None: + # Fallback to using the current graph task id + gid = torch._C._current_graph_task_id() + if gid == -1: + # generate a temporary id if we trigger unpack outside of a backward call + gid = int(uuid.uuid4()) if not frame.is_recomputed[gid]: ctx = frame.input_saver.grad_fn @@ -1168,10 +1173,17 @@ def unpack_hook(holder): _internal_assert(gid in holder.handles) if holder.handles[gid] is None: + extra = "" + if torch._C._get_graph_exec_group() is not None: + extra = ( + "Performing two backward calls that overlap (i.e. require the same " + "saved activation in order to compute gradients) is not allowed while " + "under the torch.utils.checkpoint.GraphExecGroup context. " + ) raise CheckpointError( "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " - "unpacked once. 
If you are calling ctx.saved_tensors in backward, make sure to do " - "so only once. Otherwise please open an issue with details on your use case." + f"unpacked once. {extra}If you are calling ctx.saved_tensors in backward, make sure " + "to do so only once. Otherwise please open an issue with details on your use case." ) _internal_assert(holder.handles[gid] in frame.recomputed[gid]) ret = frame.recomputed[gid][holder.handles[gid]] @@ -1594,6 +1606,40 @@ def recompute_fn(*inputs) -> None: return + +class GraphExecGroup: + """Any checkpointed regions encountered by backward under the same instance + of this context manager will trigger recompute at most once, even if + there are multiple calls to backward. + + Backward calls under the same instance of this context manager must execute + over non-overlapping regions of the backward graph even if retain_graph=True. + In particular, any two backward call cannot use the same saved activation for + gradient computation. + + .. note:: + This context manager only affects checkpoint with use_reentrant=False, and + is a no-op otherwise. + """ + + def __enter__(self) -> "GraphExecGroup": + if torch._C._get_graph_exec_group() is not None: + raise RuntimeError( + "GraphExecGroup contexts cannot be nested. " + f"Already inside group {torch._C._get_graph_exec_group()}" + ) + torch._C._set_graph_exec_group(self) + return self + + def __exit__(self, *args: object) -> None: + torch._C._set_graph_exec_group(None) + + @classmethod + def _get_current_group(cls) -> Optional["GraphExecGroup"]: + # Private API to be used by utils like AC + return torch._C._get_graph_exec_group() + + # Note: [compiled autograd and checkpoint unpack hook] # When tracing via compiled autograd, this hook will be visible to the # compiler if the forward of this checkpointed region ran in eager. 
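A minimal usage sketch of the new GraphExecGroup API, modeled on the
test_checkpoint_graph_execution_group test added above (the checkpointed
function and the tensor shapes are illustrative only, not part of the patch):

    import torch
    from torch.utils.checkpoint import checkpoint, GraphExecGroup

    def fn(x):
        y = x.sin().cos()
        z = y.sin().cos()
        return y, z

    x = torch.randn(3, 3, requires_grad=True)
    y, z = checkpoint(fn, x, use_reentrant=False)

    # The two grad calls cover disjoint regions of the backward graph
    # (z -> y, then y -> x), so the checkpointed region is recomputed at
    # most once while the group is active.
    with GraphExecGroup():
        (grad_y,) = torch.autograd.grad(
            z, inputs=(y,), grad_outputs=(torch.ones(3, 3),)
        )
        (grad_x,) = torch.autograd.grad(
            y, inputs=(x,), grad_outputs=(grad_y,)
        )

Only checkpoint(use_reentrant=False) is affected; with use_reentrant=True the
context manager is a no-op, as noted in its docstring.
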
From d18c742779114c577ad01eca78c72613a9f41216 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Fri, 7 Nov 2025 10:29:15 -0800 Subject: [PATCH 244/651] [HOP][print]Add make_fx for the proxy with graph module print (#166920) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166920 Approved by: https://github.com/angelayi ghstack dependencies: #166660 --- test/higher_order_ops/test_print.py | 51 +++++++++++++++++++++++++++++ torch/_higher_order_ops/print.py | 16 ++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/test/higher_order_ops/test_print.py b/test/higher_order_ops/test_print.py index aef538854864f..28d70dfd121c6 100644 --- a/test/higher_order_ops/test_print.py +++ b/test/higher_order_ops/test_print.py @@ -4,6 +4,7 @@ import torch from torch._dynamo.utils import counters +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import run_tests, TestCase @@ -39,6 +40,56 @@ def f(x): self.assertEqual(printed_output, "moo 1 2") + fx_f = make_fx(f)(x) + new_inp = torch.randn(3, 3) + + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + fx_f(new_inp) + ori_printed_output = mock_stdout.getvalue().strip() + + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(new_inp) + fx_printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(ori_printed_output, fx_printed_output) + + def test_print_with_proxy_graph(self): + class M(torch.nn.Module): + def forward(self, x): + torch._higher_order_ops.print("moo {x} {y}", x=1, y=2) + torch._higher_order_ops.print("moo {x}", x=x) + res = x + x + torch._higher_order_ops.print("moo {x} {y}", x=1, y=2) + torch._higher_order_ops.print("yeehop {x}", x=x.shape[0]) + return (res,) + + inputs = (torch.randn(3),) + + # Without functionalization, print should just appear in the graph directly + gm = make_fx(M(), tracing_mode="symbolic")(*inputs) + + self.assertExpectedInline( + str(gm.code).strip(), + """\ +def forward(self, arg0_1): + print_1 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_1 = None + print_2 = torch.ops.higher_order.print('moo {x}', x = arg0_1); print_2 = None + add = torch.ops.aten.add.Tensor(arg0_1, arg0_1) + print_3 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_3 = None + sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0); arg0_1 = None + print_4 = torch.ops.higher_order.print('yeehop {x}', x = sym_size_int); sym_size_int = print_4 = None + return (add,)""", + ) + + new_inp = torch.randn(4) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + gm( + new_inp, + ) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4") + if __name__ == "__main__": run_tests() diff --git a/torch/_higher_order_ops/print.py b/torch/_higher_order_ops/print.py index 5a14ef23aa24e..16e62532d8ecd 100644 --- a/torch/_higher_order_ops/print.py +++ b/torch/_higher_order_ops/print.py @@ -3,6 +3,7 @@ import torch import torch.utils._pytree as pytree from torch._ops import HigherOrderOperator +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode class Print(HigherOrderOperator): @@ -26,6 +27,15 @@ def __call__(self, format_str: str, **kwargs: object) -> object: print = Print() +@print.py_impl(ProxyTorchDispatchMode) +# pyre-ignore +def print_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, format_str: str, **kwargs: object +) -> None: + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) # 
type: ignore[union-attr] # noqa: F841 + mode.tracer.create_proxy("call_function", print, (format_str,), proxy_kwargs) + + @print.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) # pyre-ignore def print_cpu(format_str: str, **kwargs: object) -> None: @@ -40,5 +50,9 @@ def print_cpu(format_str: str, **kwargs: object) -> None: kwargs, lambda a: isinstance(a, tuple(map_types.keys())), ) - # Use built-in print to avoid recursion with the HOP print + # Use built-in print to avoid recursion with the HOP print builtins.print(format_str.format(**new_kwargs)) + + +print.fallthrough(torch._C.DispatchKey.AutogradCPU) +print.fallthrough(torch._C.DispatchKey.AutogradCUDA) From c45c966031e87214d6641d90fd1985a72d56a297 Mon Sep 17 00:00:00 2001 From: Colin L Reliability Rice Date: Sat, 8 Nov 2025 01:12:18 +0000 Subject: [PATCH 245/651] subproc_pool: Fix quiesce waitcounter (#167350) Summary: I was inspecting running jobs, and the quiesce waitcounter wasn't showing up. Turns out this was a bad copy paste. Test Plan: Primarily inspection Reviewed By: masnesral Differential Revision: D86457409 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167350 Approved by: https://github.com/aorenste, https://github.com/masnesral --- torch/_inductor/compile_worker/subproc_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 2bc87e6f3eb95..b0e0d4ba58495 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -321,7 +321,7 @@ def quiesce(self) -> None: self._send(MsgHeader.QUIESCE) if self.quiesce_waitcounter is None: self.quiesce_waitcounter = _WaitCounter( - "pytorch.wait_counter.subproc_pool.running" + "pytorch.wait_counter.subproc_pool.quiesced" ).guard() self.quiesce_waitcounter.__enter__() From 7fd15aa2bd1c9a85c0887f6cebbfd5ec4023e102 Mon Sep 17 00:00:00 2001 From: Shyamal Shah Date: Sat, 8 Nov 2025 01:28:46 +0000 Subject: [PATCH 246/651] Additional fix on top of D85172267 (#167267) (#167279) Summary: It seems D80948073 has caused some issue on a lowering pkg built on trunk: https://fburl.com/mlhub/o6p60pno error log: P2001933683 which we were able to lower successfully in older ien pkg: https://fburl.com/mlhub/1ro094zo D85172267 fixed this issue for the if conditional, but issue still exists for the else conditional. Logic is moved right before if-else to cover both cases Test Plan: checkout D85605372 buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.nvcc_arch=a100,h100 -c fbcode.split-dwarf=true -c fbcode.dwp=true -c fbcode.enable_distributed_thinlto=true -c fbcode.use_link_groups=true fbcode//inference_enablement/model_processing/infra/components/lowering/re:re_cinder -- -r "$(cat ./fbcode/minimal_viable_ai/umia_v1/ig/ss_omni_exp/re_lower_aoti.json)" with the diff, no issue was encountered. 
Reviewed By: tissue3 Differential Revision: D86474796 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167279 Approved by: https://github.com/pianpwk --- torch/_inductor/lowering.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 3e6ffd46f80f1..7946f9ae67ad8 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1300,8 +1300,11 @@ def compute_slice_index(index, size, default=None): V.graph.register_operation(b_size) new_size = sym_size - if start_index is not None: + if x.maybe_get_layout() is None: + # realize tensor before accessing layout x.realize() + + if start_index is not None: # we shouldn't have allocated storage offset symbol if start index was determinable assert sym_storage is None new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim] From c131e4b390ae779320c3a069e426b478fab46529 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 8 Nov 2025 01:33:12 +0000 Subject: [PATCH 247/651] Revert "[CP] Correctly compile create_cp_block_mask (#167153)" This reverts commit 5a9ae7cefe679ff925a0aa7b9f5782fc93d4ef29. Reverted https://github.com/pytorch/pytorch/pull/167153 on behalf of https://github.com/donigian due to breaking internal tests D86529123 ([comment](https://github.com/pytorch/pytorch/pull/167153#issuecomment-3505563239)) --- .../experimental/_context_parallel/_attention.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torch/distributed/tensor/experimental/_context_parallel/_attention.py b/torch/distributed/tensor/experimental/_context_parallel/_attention.py index b1903e211a1c1..09a86081df522 100644 --- a/torch/distributed/tensor/experimental/_context_parallel/_attention.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_attention.py @@ -1032,7 +1032,9 @@ def _disable_context_parallel_dispatcher_impl() -> None: _disable_cp_dtensor_dispatcher() -_compiled_create_block_mask = None +_compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True +) def _context_parallel_buffers( @@ -1185,12 +1187,9 @@ def _create_cp_block_mask( f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. " ) - global _compiled_create_block_mask - if _compiled_create_block_mask is None: - _compiled_create_block_mask = torch.compile( - create_block_mask, dynamic=False, fullgraph=True - ) - compiled_create_block_mask = _compiled_create_block_mask + compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True + ) def _rewrite_mask_mod( mask_mod: _mask_mod_signature, From ba5ffa2dcad8508a5d4bca414752f4effa71da04 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sat, 8 Nov 2025 01:43:09 +0000 Subject: [PATCH 248/651] [5/N] Use key in dict for existence checks (#167311) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. 
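A minimal sketch of the preferred idiom (the dict and keys here are
illustrative only):

    d = {"a": 1, "b": 2}

    # Preferred: membership tests go directly against the dict.
    if "a" in d:
        print("found")

    # Equivalent but slower, and flagged by flake8-simplify as SIM118:
    if "a" in d.keys():
        print("found")

A few call sites in the diff intentionally keep `.keys()` and silence the lint
with `# noqa: SIM118`; only pure existence checks are rewritten.
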
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167311 Approved by: https://github.com/janeyx99, https://github.com/Lucaskabela --- test/ao/sparsity/test_data_sparsifier.py | 2 +- test/ao/sparsity/test_sparsifier.py | 2 +- test/distributed/_tools/test_sac_ilp.py | 2 +- test/distributed/argparse_util_test.py | 2 +- .../checkpoint/_experimental/test_staging.py | 2 +- .../checkpoint/test_hf_safetensor_e2e.py | 8 ++++---- test/distributed/checkpoint/test_state_dict.py | 2 +- test/distributed/fsdp/test_fsdp_state_dict.py | 10 +++------- test/distributed/launcher/api_test.py | 2 +- test/distributed/launcher/test_run.py | 2 +- .../test_upgrader_models_generation.py | 2 +- test/jit/test_list_dict.py | 2 +- test/jit/test_module_containers.py | 4 ++-- test/jit/test_pdt.py | 2 +- test/jit/test_typing.py | 2 +- test/nn/test_load_state_dict.py | 2 +- test/onnx/test_onnx_opset.py | 2 +- test/profiler/test_profiler.py | 4 ++-- test/quantization/ao_migration/common.py | 2 +- test/quantization/fx/test_model_report_fx.py | 14 +++++++------- test/quantization/fx/test_numeric_suite_fx.py | 12 ++++++------ test/quantization/fx/test_quantize_fx.py | 4 ++-- .../pt2e/test_x86inductor_quantizer.py | 2 +- test/run_test.py | 4 ++-- torch/_dynamo/variables/builder.py | 2 +- torch/_inductor/codecache.py | 2 +- torch/_inductor/fuzzer.py | 2 +- 27 files changed, 47 insertions(+), 51 deletions(-) diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index fa08e8c90ac2f..46587833acb4d 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -208,7 +208,7 @@ def check_state_dict(self, data_list, data_with_config, defaults, **kwargs): assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups) state1 = state_dict1["state"] - for name in state1.keys(): + for name in state1: # compare mask assert name in sparsifier2.state assert "mask" in sparsifier2.state[name] diff --git a/test/ao/sparsity/test_sparsifier.py b/test/ao/sparsity/test_sparsifier.py index a940a3e9febab..776a36d029b54 100644 --- a/test/ao/sparsity/test_sparsifier.py +++ b/test/ao/sparsity/test_sparsifier.py @@ -119,7 +119,7 @@ def test_state_dict(self): for idx in range(len(sparsifier0.groups)): mg0 = sparsifier0.groups[idx] mg1 = sparsifier1.groups[idx] - for key in mg0.keys(): + for key in mg0: assert key in mg1 if key == "module": # We cannot compare modules as they are different diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py index 9afb267ed1675..555b0efb9f89e 100644 --- a/test/distributed/_tools/test_sac_ilp.py +++ b/test/distributed/_tools/test_sac_ilp.py @@ -80,7 +80,7 @@ def _run_and_get_memTracker( # postprocessing due to the fact that for ModTracker, the post backward hook # is not being called for modules whose inputs don't require gradients # TODO: fix this in ModTracker and ensure it does not lead to any perf regression - if _ModState.POST_BW not in mod_stats.snapshots.keys(): + if _ModState.POST_BW not in mod_stats.snapshots: mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( copy.deepcopy(last_snapshot) ) diff --git a/test/distributed/argparse_util_test.py b/test/distributed/argparse_util_test.py index 1902faf992734..a3b3ef2bc717e 100644 --- a/test/distributed/argparse_util_test.py +++ b/test/distributed/argparse_util_test.py @@ -16,7 +16,7 @@ class ArgParseUtilTest(unittest.TestCase): def setUp(self): # remove any lingering environment variables - for e in os.environ.keys(): + 
for e in os.environ.keys(): # noqa: SIM118 if e.startswith("PET_"): del os.environ[e] diff --git a/test/distributed/checkpoint/_experimental/test_staging.py b/test/distributed/checkpoint/_experimental/test_staging.py index 5c4a1733fde03..c9be4fe43f49d 100644 --- a/test/distributed/checkpoint/_experimental/test_staging.py +++ b/test/distributed/checkpoint/_experimental/test_staging.py @@ -207,7 +207,7 @@ def test_multiple_staging_operations(self) -> None: for i, result in enumerate(staged_results): self.assertIsInstance(result, dict) # Verify the result contains the expected keys - for key in state_dicts[i].keys(): + for key in state_dicts[i]: self.assertIn(key, result) stager.close() diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index 1aaaf645c58df..b9979da8a97f1 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -60,7 +60,7 @@ def test_save(self) -> None: self.assertEqual( sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()) ) - for key in state_dict_to_save.keys(): + for key in state_dict_to_save: self.assertTrue( torch.equal(state_dict_to_save[key], state_dict_loaded[key]) ) @@ -89,7 +89,7 @@ def test_load(self) -> None: self.assertEqual( sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()) ) - for key in state_dict_to_save.keys(): + for key in state_dict_to_save: self.assertTrue( torch.equal(state_dict_to_save[key], state_dict_to_load[key]) ) @@ -116,7 +116,7 @@ def test_load_into_empty_dict(self) -> None: self.assertEqual( sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()) ) - for key in state_dict_to_save.keys(): + for key in state_dict_to_save: self.assertTrue( torch.equal(state_dict_to_save[key], state_dict_loaded[key]) ) @@ -156,7 +156,7 @@ def test_load_with_multiple_threads(self) -> None: self.assertEqual( sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()) ) - for key in state_dict_to_save.keys(): + for key in state_dict_to_save: self.assertTrue( torch.equal(state_dict_to_save[key], state_dict_to_load[key]) ) diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index 1206f13213108..03bcf7ce5e03e 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -769,7 +769,7 @@ def _test_deprecate_partial(self) -> None: model_state_dict3 = copy.deepcopy(model_state_dict3) self.assertEqual(len(model_state_dict2), 2) self.assertEqual(len(model_state_dict3), 2) - for key in model_state_dict3.keys(): + for key in model_state_dict3: full_fqn = f"l.{key}" value1 = model_state_dict1[full_fqn] value2 = model_state_dict2[full_fqn] diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index b0677655186a6..50e9e6a798681 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -587,9 +587,7 @@ def test_basic_save_and_load_state_dict( model, cpu_offload.offload_params, fp16 ) - ignore_keys = [ - k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k - ] + ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k] self._validate_state_dict_contents( model, @@ -910,7 +908,7 @@ def test_state_dict_load_into_local_module( with sd_mgr: fsdp_state_dict = model.state_dict() - ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX 
in k] + ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k] self._validate_state_dict_contents( model, fsdp_state_dict, @@ -959,9 +957,7 @@ def _create_module(wrap_fsdp=True): # Full name of linear_skip param tensors in SkipModel, as would be # stored in checkpoint. linear_skip_tensor_names = [ - k - for k in dict(module.named_parameters()).keys() - if LINEAR_SKIP in k + k for k in dict(module.named_parameters()) if LINEAR_SKIP in k ] # skip SkipModule linear_skip = getattr(module, LINEAR_SKIP) diff --git a/test/distributed/launcher/api_test.py b/test/distributed/launcher/api_test.py index 48465516a913b..330fd302bbd45 100644 --- a/test/distributed/launcher/api_test.py +++ b/test/distributed/launcher/api_test.py @@ -137,7 +137,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() # remove any lingering environment variables. - for env in os.environ.keys(): + for env in os.environ.keys(): # noqa: SIM118 if env.startswith("PET_"): del os.environ[env] diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py index d271e60954ae7..50e2d53928c04 100644 --- a/test/distributed/launcher/test_run.py +++ b/test/distributed/launcher/test_run.py @@ -69,7 +69,7 @@ def setUp(self): self.test_dir = tempfile.mkdtemp() # remove any lingering environment variables - for env in os.environ.keys(): + for env in os.environ.keys(): # noqa: SIM118 if env.startswith("PET_"): del os.environ[env] diff --git a/test/jit/fixtures_srcs/test_upgrader_models_generation.py b/test/jit/fixtures_srcs/test_upgrader_models_generation.py index a23b95af9dfcf..028244ac89583 100644 --- a/test/jit/fixtures_srcs/test_upgrader_models_generation.py +++ b/test/jit/fixtures_srcs/test_upgrader_models_generation.py @@ -7,7 +7,7 @@ class TestUpgraderModelGeneration(TestCase): def test_all_modules(self): - for a_module in ALL_MODULES.keys(): + for a_module in ALL_MODULES: module_name = type(a_module).__name__ self.assertTrue( isinstance(a_module, torch.nn.Module), diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index b8853d2e6f5f4..90dbc30d5d790 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -2979,7 +2979,7 @@ def __init__(self) -> None: self.col2 = "b" def forward(self): - if self.col1 in self.segments_groupby_col.keys(): + if self.col1 in self.segments_groupby_col: return 1 else: return 2 diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index 7a8bbf58224bb..31254be34d671 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -78,7 +78,7 @@ def forward(self, x, skip_name): x = mod(x) values.append(x) - for key in self.moduledict.keys(): + for key in self.moduledict: names.append(key) return x, names @@ -306,7 +306,7 @@ def forward(self, inputs): assert "submod" in self.moduledict, "__contains__ fails for ModuleDict" - for key in self.moduledict.keys(): + for key in self.moduledict: assert key == "submod", "keys() fails for ModuleDict" for item in self.moduledict.items(): diff --git a/test/jit/test_pdt.py b/test/jit/test_pdt.py index 0ac620b368b6e..eaff742f55591 100644 --- a/test/jit/test_pdt.py +++ b/test/jit/test_pdt.py @@ -276,7 +276,7 @@ def test_substring(self, a, b): def test_multiple_class_with_same_method(self): class PDTModelOne: def test_find(self, a, b): - return b in a.keys() + return b in a class PDTModelTwo: def test_find(self, a, b): diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index c1a010dcfb94d..714fa6768958e 100644 --- 
a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -342,7 +342,7 @@ def test_dict_keys_values(x): # type: (Dict[str, int]) -> Tuple[str, int] key_str = "" sum = 0 - for key in x.keys(): + for key in x: key_str += key for val in x.values(): sum += val diff --git a/test/nn/test_load_state_dict.py b/test/nn/test_load_state_dict.py index 3d20787ac4456..48d6c6b8009ee 100644 --- a/test/nn/test_load_state_dict.py +++ b/test/nn/test_load_state_dict.py @@ -310,7 +310,7 @@ def forward(self, input): # Make sure parameters and persistent buffers were assigned net_meta_state_dict = net_meta.state_dict(keep_vars=True) - for key in state_dict.keys(): + for key in state_dict: if key in net_meta._parameters: if keep_vars and not swap: # state_dict[key] is an nn.Parameter diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 16ca93dbfe2c5..f6e33fe599817 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -42,7 +42,7 @@ def check_onnx_opset_operator( attributes = ops[i]["attributes"] assert len(attributes) == len(graph.node[i].attribute) for j in range(len(attributes)): - for attribute_field in attributes[j].keys(): + for attribute_field in attributes[j]: assert attributes[j][attribute_field] == getattr( graph.node[i].attribute[j], attribute_field ) diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index fc128ba61907a..43216274f9271 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -910,7 +910,7 @@ def judge(expected_event_count, prof): for e in prof.function_events: if "#" in e.name: key = e.name - if key in expected_event_count.keys(): + if key in expected_event_count: actual_event_count[key] = ( actual_event_count.setdefault(key, 0) + 1 ) @@ -3094,7 +3094,7 @@ def test_profiler_pattern_matcher_json_report(self): report = json.load(f) # It is platform dependent whether the path will include "profiler/" - keys = [k for k in report.keys() if k.endswith("test_profiler.py")] + keys = [k for k in report if k.endswith("test_profiler.py")] self.assertEqual(len(keys), 1, f"{keys}") entry = report[keys[0]] diff --git a/test/quantization/ao_migration/common.py b/test/quantization/ao_migration/common.py index 5797b4bab1d44..acfc8065de846 100644 --- a/test/quantization/ao_migration/common.py +++ b/test/quantization/ao_migration/common.py @@ -46,7 +46,7 @@ def _test_dict_import( old_dict = getattr(old_location, dict_name) new_dict = getattr(new_location, dict_name) assert old_dict == new_dict, f"Dicts don't match: {dict_name}" - for key in new_dict.keys(): + for key in new_dict: assert old_dict[key] == new_dict[key], ( f"Dicts don't match: {dict_name} for key {key}" ) diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index 78408c1b5a36d..adf1fee586723 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -205,7 +205,7 @@ def test_multi_linear_model_without_per_channel(self): self.assertEqual(len(per_channel_info), 2) # for each linear layer, should be supported but not used - for linear_key in per_channel_info.keys(): + for linear_key in per_channel_info: module_entry = per_channel_info[linear_key] self.assertEqual(module_entry["per_channel_quantization_supported"], True) @@ -277,7 +277,7 @@ def forward(self, x): self.assertEqual(len(per_channel_info), 4) # for each layer, should be supported but not used - for key in per_channel_info.keys(): + for key in per_channel_info: module_entry = 
per_channel_info[key] self.assertEqual(module_entry["per_channel_quantization_supported"], True) @@ -327,7 +327,7 @@ def test_sequential_model_format(self): self.assertEqual(len(per_channel_info), 4) # for each layer, should be supported but not used - for key in per_channel_info.keys(): + for key in per_channel_info: module_entry = per_channel_info[key] self.assertEqual(module_entry["per_channel_quantization_supported"], True) @@ -371,7 +371,7 @@ def test_conv_sub_class_considered(self): self.assertEqual(len(per_channel_info), 4) # for each layer, should be supported but not used - for key in per_channel_info.keys(): + for key in per_channel_info: module_entry = per_channel_info[key] self.assertEqual(module_entry["per_channel_quantization_supported"], True) @@ -415,7 +415,7 @@ def test_fusion_layer_in_sequential(self): self.assertEqual(len(per_channel_info), 4) # for each layer, should be supported but not used - for key in per_channel_info.keys(): + for key in per_channel_info: module_entry = per_channel_info[key] self.assertEqual(module_entry["per_channel_quantization_supported"], True) self.assertEqual(module_entry["per_channel_quantization_used"], True) @@ -482,7 +482,7 @@ def forward(self, x): self.assertEqual(len(per_channel_info), 1) # for the one conv, it should still give advice to use different qconfig - for key in per_channel_info.keys(): + for key in per_channel_info: module_entry = per_channel_info[key] self.assertEqual(module_entry["per_channel_quantization_supported"], True) self.assertEqual(module_entry["per_channel_quantization_used"], False) @@ -974,7 +974,7 @@ def test_prepare_model_callibration(self): # there should be two entries self.assertEqual(len(model_report.get_observers_of_interest()), 2) for detector in test_detector_set: - self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest().keys()) + self.assertTrue(detector.get_detector_name() in model_report.get_observers_of_interest()) # get number of entries for this detector detector_obs_of_interest_fqns = model_report.get_observers_of_interest()[detector.get_detector_name()] diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index 2b8afe1c7c8d8..75e4ebffbdf42 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -1787,7 +1787,7 @@ def test_layer_names(self): # extract weights results = extract_weights('fp32', mp, 'int8', mq) mq_node_names = [node.name for node in mq.graph.nodes] - for layer_name in results.keys(): + for layer_name in results: self.assertTrue(layer_name in mq_node_names) # match activations @@ -1799,7 +1799,7 @@ def test_layer_names(self): mq_ns(data) results = extract_logger_info(mp_ns, mq_ns, OutputLogger, 'int8') mq_node_names = [node.name for node in mq_ns.graph.nodes] - for layer_name in results.keys(): + for layer_name in results: self.assertTrue(layer_name in mq_node_names) # match shadow activations @@ -1810,7 +1810,7 @@ def test_layer_names(self): results = extract_shadow_logger_info( mp_shadows_mq, OutputLogger, 'int8') mq_node_names = [node.name for node in mp_shadows_mq.graph.nodes] - for layer_name in results.keys(): + for layer_name in results: self.assertTrue(layer_name in mq_node_names) @skipIfNoFBGEMM @@ -1834,11 +1834,11 @@ def test_extend_logger_results_with_comparison(self): for layer_results in results.values(): assert 'sqnr_int8_vs_fp32' in \ - layer_results['weight']['int8'][0].keys() + layer_results['weight']['int8'][0] assert 
'l2_error_int8_vs_fp32' in \ - layer_results['weight']['int8'][0].keys() + layer_results['weight']['int8'][0] assert 'cosine_similarity_int8_vs_fp32' in \ - layer_results['weight']['int8'][0].keys() + layer_results['weight']['int8'][0] @skipIfNoFBGEMM def test_int8_shadows_fp32_simple(self): diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 9c0526fde6987..f2b3091b75d6c 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -3476,7 +3476,7 @@ def forward(self, x0): def test_non_traceable_module(self): class NonTraceable(torch.nn.Module): def forward(self, x): - for k in x.keys(): + for k in x: print(x[k]) return x @@ -5000,7 +5000,7 @@ def from_observed(cls, observed_lstm): self.assertTrue(all(arg.target == "dequantize" for arg in node.args)) # Match following quantize with the specific qparams and dtypes expected_scale, expected_zp, expected_dtype = node_name_to_expected_quantize_args[node.name] - for user in node.users.keys(): + for user in node.users: self.assertEqual(user.target, torch.quantize_per_tensor) if expected_scale is not None: self.assertEqual(getattr(cell, user.args[1].target), expected_scale) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 5b9aa34158b5e..41b2351997d47 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1548,7 +1548,7 @@ def _check_annotation(node): return annot._annotated, annot._is_output_of_quantized_pattern for node in gm.graph.nodes: - if node.target in expected_stat_dict.keys(): + if node.target in expected_stat_dict: annotated, is_quant_out = _check_annotation(node) expected_stat_dict[node.target]["annotated"] -= annotated expected_stat_dict[node.target]["is_quant_out"] -= is_quant_out diff --git a/test/run_test.py b/test/run_test.py index 764b20dc9adc2..2abf324ad43d6 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -808,8 +808,8 @@ def print_to_file(s): print_to_file("Retrying single test...") print_items = [] # do not continue printing them, massive waste of space - consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3] - flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3] + consistent_failures = [x[1:-1] for x in num_failures if num_failures[x] >= 3] + flaky_failures = [x[1:-1] for x in num_failures if 0 < num_failures[x] < 3] if len(flaky_failures) > 0: print_to_file( "The following tests failed and then succeeded when run in a new process" diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d54d586cd08fc..085927db57fe4 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -591,7 +591,7 @@ def wrap_mapping_proxy(self, value): if not all_const: unimplemented_v2( gb_type="non-const keys in mappingproxy", - context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", + context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", # noqa: SIM118 explanation="Dynamo expects mappingproxy keys to be constants.", hints=[ "Ensure your mappingproxy keys are constants (e.g. 
int, float, strings)", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 1d985d6aa35da..9e57c498abbf1 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1681,7 +1681,7 @@ def set( if config.aot_inductor.emit_multi_arch_kernel: bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv", "hsaco": ".hsaco"} - assert bin_type in bin_type_to_ext.keys(), ( + assert bin_type in bin_type_to_ext, ( "multi_arch_kernel_binary only supported in CUDA/XPU/ROCm" ) base_path, _ = os.path.splitext(bin_path) diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 9565c76b2dde4..152dce2026766 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -912,7 +912,7 @@ def visualize_results( assert len(results) > 0 input_set: OrderedSet[str] = OrderedSet({}) - for key in results.keys(): + for key in results.keys(): # noqa: SIM118 input_set.add(key[0]) input_set.add(key[1]) input_list = sorted(input_set) From a2f109dcc33cd228eafc0100a23022299c17b44e Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 13:34:47 -0800 Subject: [PATCH 249/651] [dynamo] rename unimplemented_v2 -> unimplemented (#167150) Also force the new `unimplemented`/old `unimplemented_v2` to explicitly specify the `gb_type`, `context`, `explanation`, and `hints` args. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167150 Approved by: https://github.com/mlazos, https://github.com/zou3519 --- test/dynamo/test_error_messages.py | 2 +- tools/dynamo/gb_id_mapping.py | 12 +-- tools/linter/adapters/gb_registry_linter.py | 4 +- tools/test/test_gb_registry_linter.py | 20 ++-- torch/_dynamo/codegen.py | 6 +- torch/_dynamo/comptime.py | 4 +- torch/_dynamo/convert_frame.py | 8 +- torch/_dynamo/exc.py | 19 +++- torch/_dynamo/graph_bytecode_inputs.py | 8 +- torch/_dynamo/guards.py | 2 +- torch/_dynamo/output_graph.py | 16 +-- torch/_dynamo/side_effects.py | 14 +-- torch/_dynamo/symbolic_convert.py | 94 ++++++++-------- torch/_dynamo/utils.py | 36 +++---- torch/_dynamo/variables/base.py | 18 ++-- torch/_dynamo/variables/builder.py | 42 ++++---- torch/_dynamo/variables/builtin.py | 68 ++++++------ torch/_dynamo/variables/constant.py | 4 +- torch/_dynamo/variables/ctx_manager.py | 10 +- torch/_dynamo/variables/dicts.py | 10 +- torch/_dynamo/variables/distributed.py | 8 +- torch/_dynamo/variables/functions.py | 24 ++--- torch/_dynamo/variables/higher_order_ops.py | 112 ++++++++++---------- torch/_dynamo/variables/iter.py | 14 +-- torch/_dynamo/variables/lists.py | 8 +- torch/_dynamo/variables/misc.py | 66 ++++++------ torch/_dynamo/variables/nn_module.py | 20 ++-- torch/_dynamo/variables/script_object.py | 8 +- torch/_dynamo/variables/streams.py | 4 +- torch/_dynamo/variables/tensor.py | 62 +++++------ torch/_dynamo/variables/torch.py | 70 ++++++------ torch/_dynamo/variables/torch_function.py | 10 +- torch/_dynamo/variables/user_defined.py | 28 ++--- torch/_subclasses/meta_utils.py | 8 +- 34 files changed, 424 insertions(+), 415 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 995c733716f1b..e8c53832568e7 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -1050,7 +1050,7 @@ def gn(): msg = re.sub(r"line (\d+)", "line N", msg) msg = re.sub( r"""(?s)Traceback \(most recent call last\):.* - File "exc.py", line N, in unimplemented_v2 + File "exc.py", line N, in unimplemented raise Unsupported\(msg\)""", "\n", msg, diff --git a/tools/dynamo/gb_id_mapping.py 
b/tools/dynamo/gb_id_mapping.py index 541189eb66792..1333e6d28cf1b 100644 --- a/tools/dynamo/gb_id_mapping.py +++ b/tools/dynamo/gb_id_mapping.py @@ -115,7 +115,7 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any: return clean_string(param_source) -def find_unimplemented_v2_calls( +def find_unimplemented_calls( path: str, dynamo_dir: Optional[str] = None ) -> list[dict[str, Any]]: results = [] @@ -135,15 +135,15 @@ def find_unimplemented_v2_calls( for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): if node.name in ( - "unimplemented_v2", - "unimplemented_v2_with_warning", + "unimplemented", + "unimplemented_with_warning", ): continue if ( isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id - in ("unimplemented_v2", "unimplemented_v2_with_warning") + in ("unimplemented", "unimplemented_with_warning") ): info: dict[str, Any] = { "gb_type": None, @@ -180,7 +180,7 @@ def find_unimplemented_v2_calls( def create_registry(dynamo_dir: str, registry_path: str) -> None: - calls = find_unimplemented_v2_calls(dynamo_dir) + calls = find_unimplemented_calls(dynamo_dir) registry = {} gb_types = {} @@ -224,7 +224,7 @@ def main() -> None: "--dynamo_dir", type=str, default=default_dynamo_dir, - help="Directory to search for unimplemented_v2 calls.", + help="Directory to search for unimplemented calls.", ) parser.add_argument( diff --git a/tools/linter/adapters/gb_registry_linter.py b/tools/linter/adapters/gb_registry_linter.py index 508fe2f9d1471..ac6bfc3264d51 100644 --- a/tools/linter/adapters/gb_registry_linter.py +++ b/tools/linter/adapters/gb_registry_linter.py @@ -15,7 +15,7 @@ from tools.dynamo.gb_id_mapping import ( - find_unimplemented_v2_calls, + find_unimplemented_calls, load_registry, next_gb_id, ) @@ -50,7 +50,7 @@ def _collect_all_calls( gb_type_calls: dict[str, list[tuple[dict[str, Any], Path]]] = {} for py_file in dynamo_dir.rglob("*.py"): - for call in find_unimplemented_v2_calls(py_file, dynamo_dir): + for call in find_unimplemented_calls(py_file, dynamo_dir): gb_type = call["gb_type"] if gb_type not in gb_type_calls: gb_type_calls[gb_type] = [] diff --git a/tools/test/test_gb_registry_linter.py b/tools/test/test_gb_registry_linter.py index 10f4a701b2c37..837e5910a4abb 100644 --- a/tools/test/test_gb_registry_linter.py +++ b/tools/test/test_gb_registry_linter.py @@ -27,10 +27,10 @@ def setUp(self): json.dump({}, f) self.callsite_file = self.test_data_dir / "callsite_test.py" - callsite_content = """from torch._dynamo.exc import unimplemented_v2 + callsite_content = """from torch._dynamo.exc import unimplemented def test(self): - unimplemented_v2( + unimplemented( gb_type="testing", context="testing", explanation="testing", @@ -101,9 +101,9 @@ def test_case2_rename_gb_type(self): with open(self.registry_path, "w") as f: json.dump(registry_data, f, indent=2) - renamed_callsite_content = """from torch._dynamo.exc import unimplemented_v2 + renamed_callsite_content = """from torch._dynamo.exc import unimplemented def test(self): - unimplemented_v2(gb_type="renamed_testing", context="testing", explanation="testing", hints=["testing"]) + unimplemented(gb_type="renamed_testing", context="testing", explanation="testing", hints=["testing"]) """ with open(self.callsite_file, "w") as f: f.write(renamed_callsite_content) @@ -168,9 +168,9 @@ def test_case3_content_change(self): with open(self.registry_path, "w") as f: json.dump(registry_data, f, indent=2) - updated_callsite_content = """from torch._dynamo.exc import unimplemented_v2 + 
updated_callsite_content = """from torch._dynamo.exc import unimplemented def test(self): - unimplemented_v2(gb_type="testing", context="new_context", explanation="new_explanation", hints=["new_hint"]) + unimplemented(gb_type="testing", context="new_context", explanation="new_explanation", hints=["new_hint"]) """ with open(self.callsite_file, "w") as f: f.write(updated_callsite_content) @@ -255,9 +255,9 @@ def test_case5_new_gbid_on_full_change(self): with open(self.registry_path, "w") as f: json.dump(registry_data, f, indent=2) - new_callsite_content = """from torch._dynamo.exc import unimplemented_v2 + new_callsite_content = """from torch._dynamo.exc import unimplemented def test(self): - unimplemented_v2( + unimplemented( gb_type="completely_new_testing", context="completely_new_context", explanation="completely_new_explanation", @@ -330,11 +330,11 @@ def test_case6_dynamic_hints_from_variable(self): init_py.touch() - dynamic_hints_callsite = """from torch._dynamo.exc import unimplemented_v2 + dynamic_hints_callsite = """from torch._dynamo.exc import unimplemented from torch._dynamo import graph_break_hints def test(self): - unimplemented_v2( + unimplemented( gb_type="testing_with_graph_break_hints", context="testing_with_graph_break_hints", explanation="testing_with_graph_break_hints", diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index cf76243b98ddc..8c19cb8b61e27 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -38,7 +38,7 @@ create_rot_n, Instruction, ) -from .exc import IncorrectUsage, unimplemented_v2 +from .exc import IncorrectUsage, unimplemented from .source import AttrSource, ChainedSource, DictGetItemSource, Source from .utils import is_safe_constant, rot_n_helper from .variables.base import ValueMutationExisting, VariableTracker @@ -215,7 +215,7 @@ def __call__( try: self.call_reconstruct(source) except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="Reconstruction failure: source.reconstruct not implemented", context=str(source), explanation=f"Dynamo has no bytecode reconstruction implemented for {type(source)} variable {source}.", @@ -359,7 +359,7 @@ def gen_fn() -> None: try: self.call_reconstruct(value) except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="Reconstruction failure", context=str(value), explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.", diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 65690dc446a24..34eec572ce550 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -47,7 +47,7 @@ def my_model(x): from torch._subclasses.fake_tensor import FakeTensor from torch.fx.experimental.symbolic_shapes import free_symbols -from .exc import unimplemented_v2 +from .exc import unimplemented from .variables import CellVariable from .variables.constant import ConstantVariable from .variables.tensor import SymNodeVariable @@ -193,7 +193,7 @@ def graph_break(self, msg: str = "ComptimeContext.graph_break") -> None: """ Manually trigger a graph break """ - unimplemented_v2( + unimplemented( gb_type="ComptimeContext graph break", context=msg, explanation=f"Manually triggered ComptimeContext graph break with message {msg}.", diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 8cf4ab8954d5a..d46a6e5919353 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -114,7 +114,7 @@ SkipCodeRecursiveException, TorchRuntimeError, UncapturedHigherOrderOpError, 
- unimplemented_v2, + unimplemented, Unsupported, ) from .graph_bytecode_inputs import reset_user_object_tracking @@ -646,7 +646,7 @@ def __call__( return ConvertFrameReturn() if is_generator(code): - unimplemented_v2( + unimplemented( gb_type="Attempt to trace generator", context="", explanation="Generators cannot be compiled directly with `torch.compile`.", @@ -1241,7 +1241,7 @@ def transform( # We now have a new "last attempt", reset the clock last_attempt_start_time = time.time() if attempt > 100: - unimplemented_v2( + unimplemented( gb_type="Excessive RestartAnalysis() calls", context="", explanation="Dynamo attempted to trace the same frame 100+ times. " @@ -1576,7 +1576,7 @@ def format_func_info(code: CodeType) -> str: raise RecompileLimitExceeded(f"{limit_type} reached") else: # do not recursively skip frames - unimplemented_v2( + unimplemented( gb_type="Dynamo cache limit exceeded", context=f"Limit type: {limit_type}", explanation="Dynamo attempted to recompile the code object too many times, " diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index d252f7c2a3b36..483d9e57e7fde 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -450,9 +450,10 @@ def handle_observed_exception(tx: Any) -> None: ) -def unimplemented_v2_with_warning( +def unimplemented_with_warning( e: Exception, code: types.CodeType, + *, gb_type: str, context: str, explanation: str, @@ -475,7 +476,16 @@ def unimplemented_v2_with_warning( payload_fn=lambda: graph_break_msg, ) graph_breaks_log.debug("%s", graph_break_msg) - unimplemented_v2(gb_type, context, explanation, hints, from_exc=e, log_warning=True) + _unimplemented = unimplemented + # to prevent a graph break registry entry + _unimplemented( + gb_type=gb_type, + context=context, + explanation=explanation, + hints=hints, + from_exc=e, + log_warning=True, + ) def format_graph_break_message( @@ -553,13 +563,12 @@ def get_gbid_documentation_link(gb_type: str) -> Optional[str]: _NOTHING = object() -# TODO replace old unimplemented later -def unimplemented_v2( +def unimplemented( + *, gb_type: str, context: str, explanation: str, hints: list[str], - *, from_exc: Any = _NOTHING, log_warning: bool = False, ) -> NoReturn: diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 16583b89201ec..d10b749ae1d63 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -59,9 +59,9 @@ def register_graph_created_object( try: index_to_external_object_weakref[index] = weakref.ref(example_value) except TypeError as e: - from .exc import unimplemented_v2 + from .exc import unimplemented - unimplemented_v2( + unimplemented( gb_type="Failed to make weakref to graph-created external object", context=f"user_object: {example_value}", explanation="Object does not allow us to make a weakref to it", @@ -79,9 +79,9 @@ def register_user_object(value: Any, source: Source) -> int: try: index_to_external_object_weakref[index] = weakref.ref(value) except TypeError as e: - from .exc import unimplemented_v2 + from .exc import unimplemented - unimplemented_v2( + unimplemented( gb_type="Failed to make weakref to User Object", context=f"user_object: {value}", explanation="Object does not allow us to make a weakref to it", diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 67995e93bab77..4e7d83357d88d 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2284,7 +2284,7 @@ def NN_MODULE(self, guard: Guard) -> None: # If guard_nn_modules is true, we will guard on 
the right set of guards self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type] else: - exc.unimplemented_v2( + exc.unimplemented( gb_type="Attempted to guard on uninitialized nn.Module", context="", explanation="Attempted to setup an NN_MODULE guard on uninitialized " diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 1c6661e53a777..e45fa5f25138d 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -98,8 +98,8 @@ BackendCompilerFailed, exceptions_allowed_to_be_fallback, SkipFrame, - unimplemented_v2, - unimplemented_v2_with_warning, + unimplemented, + unimplemented_with_warning, ) from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor from .graph_deduplication import apply_graph_deduplication @@ -762,7 +762,7 @@ def add_backward_state_hook( def get_backward_state_proxy(self) -> torch.fx.Proxy: if self.backward_state_proxy is None: if self.export: - unimplemented_v2( + unimplemented( gb_type="backward_state does not support export", context="", explanation="Compiled autograd doesn't work with `torch.export`.", @@ -2403,7 +2403,7 @@ def _call_user_compiler( raise BackendCompilerFailed( self.compiler_fn, e, inspect.currentframe() ).with_traceback(e.__traceback__) from None - unimplemented_v2_with_warning( + unimplemented_with_warning( e, self.root_tx.f_code, gb_type="Backend compiler exception", @@ -2806,7 +2806,7 @@ def encountered_compliant_op(target: torch._ops.OpOverload) -> None: def encountered_non_compliant_op(target: torch._ops.OpOverload, msg: str) -> None: output_graph.non_compliant_ops.add(target) if config.only_allow_pt2_compliant_ops: - unimplemented_v2( + unimplemented( gb_type="Encountered non-PT2-compliant op", context="", explanation=msg + " " + err_epilogue, @@ -2848,7 +2848,7 @@ def encountered_non_compliant_op(target: torch._ops.OpOverload, msg: str) -> Non target._qualified_op_name, *args, **kwargs ) except RuntimeError as e: - unimplemented_v2( + unimplemented( gb_type="Error when attempting to resolve op packet", context="", explanation=str(e), @@ -3147,7 +3147,7 @@ def get_trace_call_log_str() -> str: elif kind == "call_module": if self.parent is not None: # TODO can remove once inline_inbuilt_nn_modules is always True - unimplemented_v2( + unimplemented( gb_type="Invoking an nn.Module inside a higher order operator", context=f"Higher order op name: {self.source_target}", explanation="This is not supported.", @@ -3181,7 +3181,7 @@ def get_trace_call_log_str() -> str: elif kind == "call_module": if self.parent is not None: # TODO can remove once inline_inbuilt_nn_modules is always True - unimplemented_v2( + unimplemented( gb_type="Invoking an nn.Module inside a HigherOrderOperator", context="", explanation="This is not supported.", diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 688a05f26ae64..95ebeeb7f0a6d 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -41,7 +41,7 @@ create_instruction, ) from .codegen import PyCodegen -from .exc import SideEffectsError, unimplemented_v2 +from .exc import SideEffectsError, unimplemented from .source import GlobalSource, LocalCellSource, Source, TempLocalSource from .utils import is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( @@ -261,7 +261,7 @@ def check_allowed_side_effect(self, item: VariableTracker) -> bool: assert item.mutation_type is not None if not is_side_effect_safe(item.mutation_type): # TODO plumb HOP 
information here - unimplemented_v2( + unimplemented( gb_type="HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)", context="", explanation="This is not supported.", @@ -289,7 +289,7 @@ def load_attr( assert self.is_attribute_mutation(item) result = self.store_attr_mutations[item][name] if not deleted_ok and isinstance(result, variables.DeletedVariable): - unimplemented_v2( + unimplemented( gb_type="Attempted to read a deleted variable", context=f"item: {item}, name: {name}", explanation="", @@ -299,7 +299,7 @@ def load_attr( def store_cell(self, cellvar: VariableTracker, value: VariableTracker) -> None: if cellvar.is_immutable(): - unimplemented_v2( + unimplemented( gb_type="Write to immutable cell", context=f"cellvar: {cellvar}, value: {value}", explanation="Dynamo doesn't support writing to immutable/sourceless cell variables.", @@ -315,7 +315,7 @@ def load_cell(self, cellvar: VariableTracker) -> VariableTracker: return self.load_attr(cellvar, "cell_contents", check=False) if cellvar.pre_existing_contents: return cellvar.pre_existing_contents - unimplemented_v2( + unimplemented( gb_type="Read uninitialized cell", context=str(cellvar), explanation="Attempted to read a cell variable that has not been populated yet.", @@ -731,7 +731,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: cg.clear_tos() var.source = TempLocalSource(cg.tempvars[var]) elif isinstance(var, variables.AutogradFunctionContextVariable): - unimplemented_v2( + unimplemented( gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", context="", explanation="We cannot reconstruct a torch.autograd.Function's context object.", @@ -889,7 +889,7 @@ def codegen_update_mutated(self, cg: PyCodegen) -> None: isinstance(var.maxlen, variables.ConstantVariable) and var.maxlen.value is None ): - unimplemented_v2( + unimplemented( gb_type="Side effect on existing deque with limited maxlen", context="", explanation="This is not supported.", diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index f7903b198bcc4..179f0ed067552 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -97,7 +97,7 @@ get_stack_above_dynamo, ResumePrologueTracingError, StepUnsupported, - unimplemented_v2, + unimplemented, Unsupported, ) from .funcname_cache import get_funcname @@ -657,7 +657,7 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: elif self.should_compile_partial_graph(): jump_graph_break(self, inst, value) else: - unimplemented_v2( + unimplemented( gb_type="Data-dependent assertion failed (cannot compile partial graph)", context=f"value: {value}", explanation="Dynamo has determined when encountering a data-dependent assert failure " @@ -696,7 +696,7 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr) if not result: - unimplemented_v2( + unimplemented( gb_type="Assertion failed on symbolic shapes", context=str(sym_expr), explanation="", @@ -772,7 +772,7 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: self.push(value) self.jump(inst) else: - unimplemented_v2( + unimplemented( gb_type="Data-dependent branching with non-constant __bool__", context=f"method: {x}, result: {result}", explanation="Attempted to perform data-dependent branching on a user-defined " @@ -825,7 +825,7 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: self.push(value) self.jump(inst) else: - 
unimplemented_v2( + unimplemented( gb_type="Data-dependent branching", context=f"attempted to jump with {value}", explanation=_explanation, @@ -859,7 +859,7 @@ def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None: # We don't support graph break under GenericContextWrappingVariable, # If there is, we roll back to the checkpoint and fall back. excp.remove_from_stats() - unimplemented_v2( + unimplemented( gb_type="Graph break under GenericContextWrappingVariable", context=f"Active generic context managers: {self.active_generic_context_managers}", explanation="Attempted to graph break in an active context manager(s) that doesn't support graph breaking.", @@ -983,7 +983,7 @@ def __init__(cls: type, name: str, bases: Any, dct: Any) -> None: super().__init__(name, bases, dct) # type: ignore[misc] def _missing(opname: str, *args: Any) -> None: - unimplemented_v2( + unimplemented( gb_type="Missing bytecode handler", context=f"{opname} with args {args}", explanation=f"Dynamo does not know how to handle the bytecode instruction `{opname}`.", @@ -1337,7 +1337,7 @@ def step(self) -> bool: or self.is_tracing_resume_prologue ): if isinstance(e, StepUnsupported): - unimplemented_v2( + unimplemented( gb_type="cannot resume from torch._dynamo.step_unsupported()", context="", explanation="traced torch._dynamo.step_unsupported(), but Dynamo is instructed " @@ -1352,7 +1352,7 @@ def step(self) -> bool: if self.current_speculation is None: log.debug("empty checkpoint - cannot resume from graph break") if isinstance(e, StepUnsupported): - unimplemented_v2( + unimplemented( gb_type="torch._dynamo.step_unsupported() with empty checkpoint", context="", explanation="traced torch._dynamo.step_unsupported(), but there is no checkpoint " @@ -1709,7 +1709,7 @@ def LOAD_FAST(self, inst: Instruction) -> None: new_name = name.replace(".", "implicit") self.push(self.symbolic_locals[new_name]) except KeyError: - unimplemented_v2( + unimplemented( gb_type="Attempted to read undefined local variable (implicit)", context=f"LOAD_FAST {name}", explanation=f"Could not find an implicit local variable with name `{name}`", @@ -1719,7 +1719,7 @@ def LOAD_FAST(self, inst: Instruction) -> None: ], ) else: - unimplemented_v2( + unimplemented( gb_type="Attempted to read undefined local variable", context=f"LOAD_FAST {name}", explanation=f"Could not find a local variable with name `{name}`", @@ -1824,7 +1824,7 @@ def STORE_GLOBAL(self, inst: Instruction) -> None: source, self.symbolic_globals[name] ) if isinstance(value, RemovableHandleVariable): - unimplemented_v2( + unimplemented( gb_type="Storing Tensor hook handle in globals", context=name, explanation="This is not supported.", @@ -1920,7 +1920,7 @@ def IMPORT_NAME(self, inst: Instruction) -> None: globals=self.f_globals, ) except ImportError: - unimplemented_v2( + unimplemented( gb_type="Import failure", context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}", explanation="Failure when attempting to import.", @@ -1951,7 +1951,7 @@ def IMPORT_NAME(self, inst: Instruction) -> None: # pyrefly: ignore [unbound-name] self.push(PythonModuleVariable(value, source=source)) else: - unimplemented_v2( + unimplemented( gb_type="Bad import result", # pyrefly: ignore [unbound-name] context=typestr(value), @@ -2092,7 +2092,7 @@ def _raise_exception_variable(self, val: VariableTracker) -> NoReturn: if self._isinstance_exception(val): observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined, union-attr] raise 
observed_exception_type(f"raised exception {val}") - unimplemented_v2( + unimplemented( gb_type="Failed to raise exception", context=str(exc), explanation="Attempted to raise a non-Exception type/value.", @@ -2132,7 +2132,7 @@ def CLEANUP_THROW(self, inst: Instruction) -> None: tos = self.stack[-1] assert isinstance(tos, ExceptionVariable) if tos.exc_type is StopIteration: - unimplemented_v2( + unimplemented( gb_type="CLEANUP_THROW with StopIteration", context="", explanation="Received StopIteration when handling generator.throw/close. This is not supported.", @@ -2218,7 +2218,7 @@ def bubble_exception_to_interpreter() -> None: curr_exc = self.exn_vt_stack.get_current_exception() dynamo_exc = exc.get_dynamo_observed_exception(curr_exc.python_type()) assert isinstance(raised_exception, dynamo_exc) # sanity check - unimplemented_v2( + unimplemented( gb_type="Observed exception", context=f"raised exception {curr_exc.python_type_name()}({curr_exc.args})", # type: ignore[union-attr] explanation=observed_exn_gb_explanation, @@ -2273,7 +2273,7 @@ def bubble_exception_to_interpreter() -> None: # instruction translator. self.stack.clear() if type(self) is InstructionTranslator: - unimplemented_v2( + unimplemented( gb_type="Observed exception (EXCEPT_HANDLER)", context=str(raised_exception), explanation=observed_exn_gb_explanation @@ -2411,7 +2411,7 @@ def check_if_exc_matches(self) -> bool: UserDefinedExceptionObjectVariable, ), ): - unimplemented_v2( + unimplemented( gb_type="Exception with bad expected type", context=str(expected_exc_types), explanation=f"`except ...` has unsupported type {expected_exc_types}.", @@ -2420,7 +2420,7 @@ def check_if_exc_matches(self) -> bool: if sys.version_info >= (3, 11): if not self._isinstance_exception(exc_instance): - unimplemented_v2( + unimplemented( gb_type="Caught non-Exception value", context=str(exc_instance), explanation=f"Except expects to receive an object of Exception type but received {exc_instance}.", @@ -2443,7 +2443,7 @@ def check_if_exc_matches(self) -> bool: UserDefinedExceptionClassVariable, ), ): - unimplemented_v2( + unimplemented( gb_type="Exception with non-type expectation", context=str(expected_type), explanation=f"`except ...` expects a non-type: {expected_type}.", @@ -2498,7 +2498,7 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None: kwargsvars = ConstDictVariable({}) argsvars = self.pop() else: - unimplemented_v2( + unimplemented( gb_type="Variadic function call with bad flags", context=f"flags: {inst.argval}", explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}", @@ -2536,7 +2536,7 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None: kwargsvars, ConstDictVariable, ): - unimplemented_v2( + unimplemented( gb_type="Variadic function call with bad args/kwargs type", # pyrefly: ignore [unbound-name] context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}", @@ -2652,7 +2652,7 @@ def STORE_ATTR(self, inst: Instruction) -> None: def store_attr_graph_break(self, inst: Instruction) -> None: if not self.should_compile_partial_graph(): - unimplemented_v2( + unimplemented( gb_type="Should not compile partial graph (STORE_ATTR)", context="", explanation="Dynamo has determined when encountering an unsupported " @@ -3236,7 +3236,7 @@ def BUILD_LIST(self, inst: Instruction) -> None: def BUILD_SET(self, inst: Instruction) -> None: if config.inject_BUILD_SET_unimplemented_TESTING_ONLY: - unimplemented_v2( + unimplemented( gb_type="missing BUILD_SET handler", context="", 
explanation="Missing BUILD_SET bytecode handler (for testing purposes).", @@ -3253,7 +3253,7 @@ def BUILD_LIST_UNPACK(self, inst: Instruction, cls: type = ListVariable) -> None try: items.extend(seq.force_unpack_var_sequence(self)) except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="Failed to unpack object for BUILD_LIST_UNPACK", context=str(seq), explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK " @@ -3391,7 +3391,7 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None: elif seq.has_force_unpack_var_sequence(self): val = seq.force_unpack_var_sequence(self) else: - unimplemented_v2( + unimplemented( gb_type="Failed to unpack object for UNPACK_SEQUENCE", context=str(seq), explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode " @@ -3400,7 +3400,7 @@ def UNPACK_SEQUENCE(self, inst: Instruction) -> None: ) # pyrefly: ignore [unbound-name] if len(val) != inst.argval: - unimplemented_v2( + unimplemented( gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE", # pyrefly: ignore [unbound-name] context=f"expected length: {inst.argval}, actual: {len(val)}", @@ -3429,7 +3429,7 @@ def UNPACK_EX(self, inst: Instruction) -> None: for item in reversed(vals_prefix): self.push(item) else: - unimplemented_v2( + unimplemented( gb_type="Failed to unpack object for UNPACK_EX", context=str(seq), explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.", @@ -3439,7 +3439,7 @@ def UNPACK_EX(self, inst: Instruction) -> None: @break_graph_if_unsupported(push=0) def graph_break_on_leaf_function(self, inst: Instruction) -> None: if self.is_leaf_tracer: - unimplemented_v2( + unimplemented( gb_type="Forced graph break on leaf function", context="", explanation="Forced graph break for nested graph break testing purposes", @@ -3545,7 +3545,7 @@ def BUILD_STRING(self, inst: Instruction) -> None: format_string_parts.append(part.format_string) args.extend(part.sym_args) if set(kwargs.keys()) & set(part.sym_kwargs.keys()): - unimplemented_v2( + unimplemented( gb_type="BUILD_STRING key conflict", context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}", explanation="Failed to build format string due to key conflict", @@ -3553,7 +3553,7 @@ def BUILD_STRING(self, inst: Instruction) -> None: ) kwargs.update(part.sym_kwargs) else: - unimplemented_v2( + unimplemented( gb_type="BUILD_STRING type error", context=str(part), explanation="Format string part type is not correct - expected constant or format string.", @@ -3867,7 +3867,7 @@ def enter_ctx( @staticmethod def unsupported_ctx_graph_break(ctx: VariableTracker) -> NoReturn: - unimplemented_v2( + unimplemented( gb_type="Unsupported context manager", context=f"Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on {ctx}", explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.", @@ -3930,7 +3930,7 @@ def END_FOR(self, inst: Instruction) -> None: def LOAD_FAST_CHECK(self, inst: Instruction) -> None: if istype(self.symbolic_locals.get(inst.argval, None), NullVariable): - unimplemented_v2( + unimplemented( gb_type="LOAD_FAST_CHECK on uninitialized variable", context=inst.argval, explanation=f"Attempted to load uninitialized local variable {inst.argval}", @@ -3964,7 +3964,7 @@ def CALL_INTRINSIC_1(self, inst: Instruction) -> None: # INTRINSIC_LIST_TO_TUPLE self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) else: - unimplemented_v2( + unimplemented( gb_type="Missing 
CALL_INTRINSIC_1 handler", context=f"CALL_INTRINSIC_1 operand: {inst.argval}", explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.", @@ -4561,7 +4561,7 @@ def _throw_if_in_functorch(self) -> None: # if it reaches here, it means Dynamo failed to inline a functorch function f"- torch.func.{name}(fn) requires the function to be inlined by dynamo" ) - unimplemented_v2( + unimplemented( gb_type="Unsupported functorch tracing attempt", context="", explanation=msg, @@ -4669,7 +4669,7 @@ def inline_call(cls, parent: Any, func: Any, args: Any, kwargs: Any) -> Any: @staticmethod def check_inlineable(func: Any) -> trace_rules.SkipResult: if func.has_self(): - unimplemented_v2( + unimplemented( gb_type="Inline attempt with __self__", context=str(func), explanation="Attempted to inline a function with the `__self__` attribute. " @@ -4683,7 +4683,7 @@ def check_inlineable(func: Any) -> trace_rules.SkipResult: msg = inspect.getattr_static( func.get_function(), "_torchdynamo_disable_msg", None ) - unimplemented_v2( + unimplemented( gb_type="Skip inlining `torch.compiler.disable()`d function", context=str(func.get_function()), explanation=f"Skip inlining function {func.get_function()} since it was wrapped " @@ -4719,7 +4719,7 @@ def check_inlineable(func: Any) -> trace_rules.SkipResult: "More graph breaks may occur as a result of attempting to trace into the function.", "Please file an issue to PyTorch.", ] - unimplemented_v2( + unimplemented( gb_type="Attempted to inline function marked as skipped", context=f"qualname: {fn_qualname}, name: {func.get_name()}, " f"filename: `{func.get_filename()}`, skip reason: {result.reason}", @@ -4761,7 +4761,7 @@ def build_inline_tracer( if result is None: if isinstance(func, SkipFunctionVariable): - unimplemented_v2( + unimplemented( gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)", context=f"Attempted to inline a SkipFunctionVariable {func}", explanation=( @@ -4792,7 +4792,7 @@ def build_inline_tracer( for v in itertools.chain(sub_locals.values()): if not isinstance(v, VariableTracker): - unimplemented_v2( + unimplemented( gb_type="Encountered unconverted argument when attempting to inline", context=f"func: {func}, arg: {v}", explanation="An argument to an inlined function was not successfully converted to a VariableTracker.", @@ -4802,7 +4802,7 @@ def build_inline_tracer( if code.co_name in ("__setitem__", "__setattr__") and not ( args and isinstance(args[0], variables.UserDefinedObjectVariable) ): - unimplemented_v2( + unimplemented( gb_type="Unsupported __setitem__/__setattr__ inline attempt", context=f"code name: {code.co_name}, args: {args}", explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.", @@ -5025,7 +5025,7 @@ def create_call_resume_at( ) -> list[Instruction]: if config.nested_graph_breaks: return super().create_call_resume_at(inst, all_stack_locals_metadata) - unimplemented_v2( + unimplemented( gb_type="Graph break in inlined function", context="", explanation="Graph breaks in an inlined call are not supported.", @@ -5106,7 +5106,7 @@ def STORE_GLOBAL(self, inst: Instruction) -> None: else: value = self.pop() if isinstance(value, RemovableHandleVariable): - unimplemented_v2( + unimplemented( gb_type="Storing Tensor hook handle in globals (inline call)", context=inst.argval, explanation="This is not supported.", @@ -5173,7 +5173,7 @@ def YIELD_FROM(self, inst: Instruction) -> None: # lifted the `unimplemented("generator")` in frame 
conversion. This codepath handles # subgenerator and lines up with this line in Python 3.10 # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599 - unimplemented_v2( + unimplemented( gb_type="Unreachable sub-generator code", context="", explanation="Should only be encountered while implementing generator support.", @@ -5231,14 +5231,14 @@ def SEND(self, inst: Instruction) -> None: # lifted the `unimplemented("generator")` in frame conversion. This codepath handles # subgenerator and lines up with this line in Python 3.11 # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597 - unimplemented_v2( + unimplemented( gb_type="Unreachable sub-generator code", context="", explanation="Should only be encountered while implementing generator support.", hints=[], ) else: - unimplemented_v2( + unimplemented( gb_type="SEND with bad type", context=f"TOS type: {typestr(tos)}", explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.", diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ca56d9785febe..f72795039786a 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1251,10 +1251,10 @@ def proxy_args_kwargs(args: Any, kwargs: Any) -> tuple[tuple[Any, ...], dict[str proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} return proxy_args, proxy_kwargs except NotImplementedError as e: - from .exc import unimplemented_v2 + from .exc import unimplemented from .variables.base import typestr - unimplemented_v2( + unimplemented( gb_type="Failed to convert args/kwargs to proxy", context=f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}", explanation="Missing `as_proxy()` implementation for some arg/kwarg.", @@ -2756,9 +2756,9 @@ def _get_fake_tensor(vt: VariableTracker) -> Any: fake_tensor = vt.as_proxy().node.meta.get("example_value") if not is_fake(fake_tensor): from . import graph_break_hints - from .exc import unimplemented_v2 + from .exc import unimplemented - unimplemented_v2( + unimplemented( gb_type="Cannot check Tensor object identity without its fake value", context=str(fake_tensor), explanation="TensorVariable is missing a fake example_value.", @@ -2929,11 +2929,11 @@ def wrap_fake_exception(fn: Callable[[], Any]) -> Any: try: return fn() except UnsupportedFakeTensorException as e: - from .exc import unimplemented_v2 + from .exc import unimplemented msg = f"Encountered exception ({e.reason}) during fake tensor propagation." log.warning(msg) - unimplemented_v2( + unimplemented( gb_type="Fake tensor propagation exception", context=str(e.reason), explanation=msg, @@ -3326,11 +3326,11 @@ def extract_fake_example_value(node: torch.fx.Node, required: bool = True) -> An if "example_value" in node.meta and is_fake(node.meta["example_value"]): return node.meta["example_value"] elif required: - from torch._dynamo.exc import unimplemented_v2 + from torch._dynamo.exc import unimplemented from . 
import graph_break_hints - unimplemented_v2( + unimplemented( gb_type="Missing FakeTensor example value", context=str(node), explanation=f"`FakeTensor` example value was required for {node} but not available.", @@ -3385,7 +3385,7 @@ def get_fake_value( from .exc import ( TorchRuntimeError, - unimplemented_v2, + unimplemented, Unsupported, UserError, UserErrorType, @@ -3479,7 +3479,7 @@ def get_fake_value( "Consider wrapping the operator into a PyTorch-understood custom operator " "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)", ] - unimplemented_v2( + unimplemented( gb_type="Data dependent operator", context=str(cause.func), explanation=f"Operator `{cause.func}` has a non-Tensor output " @@ -3490,7 +3490,7 @@ def get_fake_value( cause, torch._subclasses.fake_tensor.DynamicOutputShapeException ): if not torch._dynamo.config.capture_dynamic_output_shape_ops: - unimplemented_v2( + unimplemented( gb_type="Dynamic shape operator", context=str(cause.func), explanation=f"Operator `{cause.func}`'s output shape depends on input Tensor data.", @@ -3500,7 +3500,7 @@ def get_fake_value( ], ) else: - unimplemented_v2( + unimplemented( gb_type="Dynamic shape operator (no meta kernel)", context=str(cause.func), explanation=f"Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes", @@ -3524,7 +3524,7 @@ def get_fake_value( f"module `{module}` and you may need to `import {module}`" f"({ctx}), otherwise " ) - unimplemented_v2( + unimplemented( gb_type="Operator does not support running with fake tensors", context=f"unsupported operator: {cause.func}", explanation="", @@ -3545,7 +3545,7 @@ def get_fake_value( elif isinstance(cause, ValueRangeError): raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e elif isinstance(cause, TypeError) and "argument" in str(cause): - unimplemented_v2( + unimplemented( gb_type="TypeError when making fake tensor call", context=f"TypeError {node.target}: {cause}", explanation="", @@ -3623,9 +3623,9 @@ def make_error_message(e: Any) -> str: return node.target(*args, **kwargs) # type: ignore[operator] elif op == "call_method": if not hasattr(args[0], node.target): # type: ignore[arg-type] - from .exc import unimplemented_v2 + from .exc import unimplemented - unimplemented_v2( + unimplemented( gb_type="Missing attribute when running call_method node", context="", explanation=make_error_message("attribute not defined"), @@ -3643,7 +3643,7 @@ def make_error_message(e: Any) -> str: except (NotImplementedError, UnsupportedFakeTensorException) as e: # NB: mimic how wrap_fake_exception does it - from .exc import unimplemented_v2 + from .exc import unimplemented hints = [] if isinstance(e, NotImplementedError): @@ -3651,7 +3651,7 @@ def make_error_message(e: Any) -> str: "If the op is a PyTorch op, please file an issue to PyTorch.", ] - unimplemented_v2( + unimplemented( gb_type="NotImplementedError/UnsupportedFakeTensorException when running FX node", context="", explanation=make_error_message(e), diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 0abf2cc91e784..896ae3b5e53d6 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -23,7 +23,7 @@ from .. 
import graph_break_hints, variables from ..current_scope_id import current_scope_id -from ..exc import raise_observed_exception, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, Source from ..utils import cmp_name_to_op_mapping, istype @@ -90,7 +90,7 @@ def __init__(self, typ: SourceType) -> None: elif typ is SourceType.New: self.scope = current_scope_id() else: - unimplemented_v2( + unimplemented( gb_type="Unsupported SourceType", context=f"MutationType.__init__ {self} {typ}", explanation=f"Dynamo does not support the type `{typ}`", @@ -349,7 +349,7 @@ def guard_as_python_constant(self) -> Any: try: return self.as_python_constant() except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="Not a Python constant", context=f"guard_as_python_constant {self}", explanation=f"Failed to convert {self} into a Python constant.", @@ -444,7 +444,7 @@ def force_apply_to_var_sequence( fn(v) def inspect_parameter_names(self) -> list[str]: - unimplemented_v2( + unimplemented( gb_type="Unsupported inspect call", context=f"inspect_parameter_names {self}", explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", @@ -452,7 +452,7 @@ def inspect_parameter_names(self) -> list[str]: ) def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker": - unimplemented_v2( + unimplemented( gb_type="Unsupported hasattr call", context=f"call_obj_hasattr {self} {name}", explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", @@ -468,7 +468,7 @@ def call_function( args: Sequence["VariableTracker"], kwargs: dict[str, "VariableTracker"], ) -> "VariableTracker": - unimplemented_v2( + unimplemented( gb_type="Unsupported function call", context=f"call_function {self} {args} {kwargs}", explanation=f"Dynamo does not know how to trace the function `{self.debug_repr()}`", @@ -514,7 +514,7 @@ def call_method( or tx.output.side_effects.has_pending_mutation(self) or tx.output.side_effects.has_pending_mutation(other) ): - unimplemented_v2( + unimplemented( gb_type="Builtin `operator.*` comparison with constant `self` failed", context=f"call_method {self} {name} {args} {kwargs}", explanation=f"Failed to compare {self} with {other}, " @@ -560,7 +560,7 @@ def call_method( "(2) fix any graph breaks in the function above the comprehension, (3) wrap the comprehension in a " "function, or (4) use Python 3.12+." ) - unimplemented_v2( + unimplemented( gb_type="Unsupported method call", context=f"call_method {self} {name} {args} {kwargs}", explanation=f"Dynamo does not know how to trace method `{name}` of class `{self.python_type_name()}`", @@ -583,7 +583,7 @@ def is_realized(self) -> bool: return True def next_variable(self, tx: Any) -> "VariableTracker": - unimplemented_v2( + unimplemented( gb_type="Unsupported next() call", context=f"next({self})", explanation=f"Dynamo does not know how to trace calling `next()` on variable `{self}`.", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 085927db57fe4..aa1d0f04d2040 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -87,7 +87,7 @@ from .. 
import config, graph_break_hints, mutation_guard, replay_record, trace_rules from ..device_interface import get_registered_device_interfaces -from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented_v2 +from ..exc import InternalTorchDynamoError, raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard, make_dupe_guard from ..pgo import ( auto_dynamic, @@ -567,7 +567,7 @@ def wrap_removable_handle(self, value): # Our current infra requires the hook to be registered and removed in # the same frame. So graph break. # Related test - PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py -k TestAutograd.test_hooks - unimplemented_v2( + unimplemented( gb_type="Attempted to represent unregistered RemovableHandle", context="", explanation="Dynamo attempted to build a representation of a torch.utils.hooks.RemovableHandle, " @@ -589,7 +589,7 @@ def wrap_mapping_proxy(self, value): all_const = all(ConstantVariable.is_literal(k) for k in value) if not all_const: - unimplemented_v2( + unimplemented( gb_type="non-const keys in mappingproxy", context=f"non-const keys: {[k for k in value.keys() if not ConstantVariable.is_literal(k)]}", # noqa: SIM118 explanation="Dynamo expects mappingproxy keys to be constants.", @@ -807,7 +807,7 @@ def build_key_value(i, k, v): return var elif istype(value, set): if any(isinstance(x, torch.Tensor) for x in value): - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap a set with tensors", context="Python set containing torch.Tensor elements", explanation=( @@ -888,7 +888,7 @@ def build_key_value(i, k, v): keywords_source = AttrSource(self.get_source(), "keywords") for k, v in value.keywords.items(): if not ConstantVariable.is_literal(k): - unimplemented_v2( + unimplemented( gb_type="functools.partial() with non-literal keyword", context=f"non-literal keyword: {k}", explanation="functools.partial() expects literal/string keywords", @@ -1039,7 +1039,7 @@ def build_key_value(i, k, v): return self.wrap_unspecialized_primitive(value) elif isinstance(value, HigherOrderOperator): if value is torch._higher_order_ops.invoke_subgraph: - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap torch._higher_order_ops.invoke_subgraph", context="", explanation="Directly using invoke_subgraph is not supported. Use nested_compile_region", @@ -1202,7 +1202,7 @@ def build_key_value(i, k, v): # this is automatically done by evaluating the guards once but this # will cause data-dependent error when we evaluate the outer unbacked symints. 
# The test case that triggers this graph break is test_cond_unbacked_symint_closure - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap unbacked SymInt", context="", explanation="Unbacked SymInt input is not supported yet.", @@ -1616,7 +1616,7 @@ def build_key_value(i, k, v): ) return DictKeySetVariable(items, source=self.source) else: - unimplemented_v2( + unimplemented( gb_type="non-const keys in dict_keys", context=f"non-const keys: {[k for k in value if not ConstantVariable.is_literal(k)]}", explanation="Dynamo expects dict_keys keys to be constants.", @@ -1665,7 +1665,7 @@ def wrap_user_defined(self, value: Any): def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): for item in value: if item is value: - unimplemented_v2( + unimplemented( gb_type="list elements are pointing to the list itself", context="", explanation="Dynamo does not support lists whose items reference to itself", @@ -1834,7 +1834,7 @@ def wrap_module(self, value: torch.nn.Module): from ..eval_frame import OptimizedModule if len(value.__dict__) == 0: - unimplemented_v2( + unimplemented( gb_type="Uninitialized nn.Module", context=typestr(value), explanation=f"Attempted to trace an uninitialized nn.Module of type {typestr(value)}.", @@ -1866,7 +1866,7 @@ def wrap_module(self, value: torch.nn.Module): isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) and not config.allow_rnn ): - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap RNN, GRU, or LSTM", context=str(value), explanation="Dynamo does not support RNN, GRU, or LSTM.", @@ -1880,7 +1880,7 @@ def wrap_module(self, value: torch.nn.Module): # we can't do this assert inside FSDP constructor, # since we don't know yet whether dynamo will be used if not getattr(value, "_fsdp_use_orig_params", False): - unimplemented_v2( + unimplemented( gb_type="FSDP with use_orig_params=False", context="", explanation="Dynamo only supports FSDP with use_orig_params=True", @@ -2145,7 +2145,7 @@ def wrap_tensor(self, value: torch.Tensor): and value.is_nested and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) ): - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap strided NestedTensor", context="", explanation="torch.compile does not support strided NestedTensor", @@ -2161,7 +2161,7 @@ def wrap_tensor(self, value: torch.Tensor): # A hot fix for sparse tensors + torch.compile. Support for # export + sparsity is being added but we need to create # SPARSE_TENSOR_GUARDS for guards to work properly. - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap sparse Tensor", context="", explanation="torch.compile does not support sparse Tensors", @@ -2173,7 +2173,7 @@ def wrap_tensor(self, value: torch.Tensor): and safe_grad(value) is not None and value.dtype != safe_grad(value).dtype ): - unimplemented_v2( + unimplemented( gb_type="dtype mismatch between tensor and its gradient", context=f"tensor dtype: {value.dtype}; grad dtype: {safe_grad(value).dtype}", explanation="Inconsistent dtype between tensor and its gradient. 
" @@ -2294,7 +2294,7 @@ def wrap_numpy_ndarray(self, value): tensor_value = clone_preserve_strides(tensor_value) except NotImplementedError as e: # failed to convert to tensor, graph break - unimplemented_v2( + unimplemented( gb_type="failed to convert numpy.ndarray to Tensor", context=str(value), explanation="Exception encountered when attempting to convert numpy.ndarray to Tensor", @@ -2673,7 +2673,7 @@ def _dataclasses_fields_lambda(obj): if isinstance(obj, UserDefinedObjectVariable): value = obj.value else: - unimplemented_v2( + unimplemented( gb_type="dataclass fields failure", context=f"obj: {obj}; variable type: {type(obj)}", explanation=f"Dataclass fields handling fails for {obj}. Expected it to be a user-defined object.", @@ -2901,7 +2901,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe if is_sparse_any(example_value) and ( not tx.export or not config.capture_sparse_compute ): - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap sparse Tensor with VariableTracker", context=str(example_value), explanation="torch.compile does not support sparse Tensors with VariableTracker", @@ -3108,7 +3108,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) else: - unimplemented_v2( + unimplemented( gb_type="torch.* op returned non-Tensor", context=f"example_value type: {typestr(example_value)}; op: {proxy.node.op}; target: {proxy.node.target}", explanation="torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output", @@ -3308,7 +3308,7 @@ def _automatic_dynamic( if e.is_nested and not isinstance( e, torch.nested._internal.nested_tensor.NestedTensor ): - unimplemented_v2( + unimplemented( gb_type="Encountered strided NestedTensor in automatic dynamic dim determination", context="", explanation="torch.compile does not support strided NestedTensor", @@ -3770,7 +3770,7 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker: ): proxy = tx.output.bound_symbols[value.node.expr] return SymNodeVariable.create(tx, proxy) - unimplemented_v2( + unimplemented( gb_type="Unexpected type in sourceless builder", context=f"{value_type.__module__}.{value_type.__qualname__}", explanation=f"SourcelessBuilder.create does not know how to wrap {value_type}", diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index e15eb83c72573..f1d43b6d48995 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -45,7 +45,7 @@ ObservedAttributeError, ObservedUserStopIteration, raise_observed_exception, - unimplemented_v2, + unimplemented, Unsupported, UserError, UserErrorType, @@ -1034,7 +1034,7 @@ def create_exception_class_object( and isinstance(x.value, str) for x in args ): - unimplemented_v2( + unimplemented( gb_type="assert with non-string message", context=str(args), explanation="Dynamo only supports asserts with string messages", @@ -1104,7 +1104,7 @@ def call_self_handler( self_handler, e, ) - unimplemented_v2( + unimplemented( gb_type="invalid call to builtin op handler", context=f"invalid args to {self_handler}: {args} {kwargs}", explanation=f"Encountered TypeError when trying to handle op {fn.__name__}", @@ -1145,7 +1145,7 @@ def constant_fold_handler( args=list(map(ConstantVariable.create, exc.args)), ) except AsPythonConstantNotImplementedError as exc: - unimplemented_v2( + unimplemented( gb_type="constant fold exception", 
context=f"attempted to run function {fn} with arguments {args}", explanation="Encountered exception when attempting to constant fold.", @@ -1172,7 +1172,7 @@ def constant_fold_handler( }, ) except AsPythonConstantNotImplementedError as exc: - unimplemented_v2( + unimplemented( gb_type="constant fold exception", context=f"attempted to run function {fn} with arguments {args}", explanation="Encountered exception when attempting to constant fold.", @@ -1191,9 +1191,9 @@ def constant_fold_handler( handlers.append(constant_fold_handler) - def call_unimplemented_v2(args: Sequence[VariableTracker]) -> None: + def call_unimplemented(args: Sequence[VariableTracker]) -> None: real_arg_types = [arg.python_type_name() for arg in args] - unimplemented_v2( + unimplemented( gb_type="Failed to trace builtin operator", context=f"builtin {fn.__name__} {arg_types} {has_kwargs}", explanation=f"Dynamo does not know how to trace builtin operator `{fn.__name__}` " @@ -1208,7 +1208,7 @@ def call_unimplemented_v2(args: Sequence[VariableTracker]) -> None: ) if len(handlers) == 0: - return lambda tx, args, kwargs: call_unimplemented_v2(args) + return lambda tx, args, kwargs: call_unimplemented(args) elif len(handlers) == 1: (handler,) = handlers @@ -1220,7 +1220,7 @@ def builtin_dispatch( rv = handler(tx, args, kwargs) if rv: return rv - call_unimplemented_v2(args) + call_unimplemented(args) return rv else: @@ -1235,14 +1235,14 @@ def builtin_dispatch( rv = fn(tx, args, kwargs) if rv: return rv - call_unimplemented_v2(args) + call_unimplemented(args) return rv return builtin_dispatch def call_vars(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: if len(args) == 0: - unimplemented_v2( + unimplemented( gb_type="unimplemented builtin op vars() with no arguments", context=f"vars: {self} {args}", explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments", @@ -1394,7 +1394,7 @@ def _handle_insert_op_in_graph( return wrap_fx_proxy(tx, proxy) except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="unimplemented builtin op on tensor arguments", context=f"partial tensor op: {self} {args} {kwargs}", explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with tensor arguments", @@ -1622,7 +1622,7 @@ def call_str( # account for __repr__ functions when __str__ is absent str_method = arg.value.__repr__ else: - unimplemented_v2( + unimplemented( gb_type="failed to call str() on user defined object", context=str(arg), explanation="User defined object has no __str__ or __repr__ method", @@ -1639,7 +1639,7 @@ def call_str( return None # pyrefly: ignore [unbound-name] elif is_wrapper_or_member_descriptor(str_method): - unimplemented_v2( + unimplemented( gb_type="Attempted to a str() method implemented in C/C++", context="", explanation=f"{type(arg.value)} has a C/C++ based str method. 
This is not supported.", @@ -1819,7 +1819,7 @@ def call_index( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: if isinstance(arg, variables.TensorVariable): - unimplemented_v2( + unimplemented( gb_type="unsupported index(Tensor)", context="", explanation="Dynamo does not support tracing builtin index() on a Tensor", @@ -2044,7 +2044,7 @@ def call_cast( if len(args) == 2: return args[1] - unimplemented_v2( + unimplemented( gb_type="bad args to builtin cast()", context=f"got args {args} {kwargs}", explanation="Dynamo expects exactly 2 args to builtin cast().", @@ -2103,7 +2103,7 @@ def call_custom_dict_fromkeys( **kwargs: VariableTracker, ) -> VariableTracker: if user_cls not in {dict, OrderedDict, defaultdict}: - unimplemented_v2( + unimplemented( gb_type="Unsupported dict type for fromkeys()", context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " @@ -2167,7 +2167,7 @@ def call_custom_dict_fromkeys( mutation_type=ValueMutationNew(), ) - unimplemented_v2( + unimplemented( gb_type="failed to call dict.fromkeys()", context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " @@ -2301,7 +2301,7 @@ def call_isinstance( try: arg_type = arg.python_type() except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="builtin isinstance() cannot determine type of argument", context=f"isinstance({arg}, {isinstance_type_var})", explanation=f"Dynamo doesn't have a rule to determine the type of argument {arg}", @@ -2344,7 +2344,7 @@ def check_type(ty: Any) -> bool: if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( arg.value, types.MemberDescriptorType ): - unimplemented_v2( + unimplemented( gb_type="isinstance() called on user defined object with C extensions", context=f"isinstance({arg}, {isinstance_type})", explanation="User-defined object with C extensions can have torch.Tensor " @@ -2412,7 +2412,7 @@ def call_issubclass( left_ty_py = left_ty.as_python_constant() right_ty_py = right_ty.as_python_constant() except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="issubclass() with non-constant arguments", context=f"issubclass({left_ty}, {right_ty})", explanation="issubclass() with non-constant arguments not supported.", @@ -2505,7 +2505,7 @@ def call_getattr( default: VariableTracker | None = None, ) -> VariableTracker | None: if not name_var.is_python_constant(): - unimplemented_v2( + unimplemented( gb_type="getattr() with non-constant name argument", context=f"getattr({obj}, {name_var}, {default})", explanation="getattr() with non-constant name argument is not supported", @@ -2533,7 +2533,7 @@ def call_getattr( and obj.is_state_mutated and tx.output.side_effects.has_pending_mutation(obj) ): - unimplemented_v2( + unimplemented( gb_type="getattr() on nn.Module with pending mutation", context=f"getattr({obj}, {name}, {default})", explanation="Intentionally graph breaking on getattr() on a nn.Module " @@ -2598,7 +2598,7 @@ def call_getattr( "assertWarns", ) ): - unimplemented_v2( + unimplemented( gb_type="Failed to trace unittest method", context=f"function: unittest.TestCase.{name}", explanation=f"Dynamo does not know how to trace unittest method `{name}` ", @@ -2614,7 +2614,7 @@ def call_getattr( and is_sparse_any(fake_val) and (not tx.export or not config.capture_sparse_compute) ): - unimplemented_v2( + unimplemented( gb_type="Attempted to wrap sparse Tensor", context="", 
explanation="torch.compile does not support sparse Tensors", @@ -2691,7 +2691,7 @@ def call_setattr( # Some special handling for tensor attributes. if name == "requires_grad": # TODO(voz): Make it work properly - unimplemented_v2( + unimplemented( gb_type="setattr() on Tensor.requires_grad", context=f"setattr({obj}, {name}, {val})", explanation="setattr() on Tensor.requires_grad not supported. " @@ -2703,7 +2703,7 @@ def call_setattr( # See comments on `test_set_data_on_scoped_tensor` for plans # to support this. if obj.source is None: - unimplemented_v2( + unimplemented( gb_type="Failed to mutate tensor data attribute", context=f"setattr({obj}, {name}, {val})", explanation="Dyanmo only supports mutating `.data`" @@ -2714,7 +2714,7 @@ def call_setattr( ], ) elif obj.dtype != val.dtype: # type: ignore[attr-defined] - unimplemented_v2( + unimplemented( gb_type="Failed to mutate tensor data attribute to different dtype", context=f"setattr({obj}, {name}, {val})", explanation="Dyanmo only supports mutating `.data`" @@ -2780,7 +2780,7 @@ def _lower_version_count_by_1(x: torch.Tensor) -> torch.Tensor: # Attribute like `torch.Tensor.real` has special setters we # don't yet support; it's not as simple adding an entry to # the side effect mapping. - unimplemented_v2( + unimplemented( gb_type="Failed to set tensor attribute", context=f"setattr({obj}, {name}, {val})", explanation="Dyanmo doesn't support setting these tensor attributes", @@ -2940,7 +2940,7 @@ def call_id( elif istype(args[0], variables.FunctoolsPartialVariable): return variables.ConstantVariable.create(id(args[0].fake_value)) else: - unimplemented_v2( + unimplemented( gb_type="id() with unsupported args", context=str(args), explanation=f"Dynamo doesn't know how to trace id() call with args {args}", @@ -2954,7 +2954,7 @@ def call_id( def call_deepcopy( self, tx: "InstructionTranslator", x: VariableTracker ) -> VariableTracker: - unimplemented_v2( + unimplemented( gb_type="copy.deepcopy()", context=f"copy.deepcopy({x})", explanation="Dynamo does not support copy.deepcopy()", @@ -2985,7 +2985,7 @@ def _comparison_with_tensor( return ConstantVariable.create(not is_result) if op not in supported_tensor_comparison_op_values: - unimplemented_v2( + unimplemented( gb_type="unsupported Tensor comparison op", context=f"{op.__name__}({left}, {right})", explanation=f"Dynamo does not support the comparison op {op.__name__} " @@ -3002,7 +3002,7 @@ def _comparison_with_tensor( torch.broadcast_shapes(left.size, right.size) except RuntimeError: # not broadcastable, can't be compared - unimplemented_v2( + unimplemented( gb_type="failed to broadcast when attempting Tensor comparison op", context=f"{op.__name__}({left}, {right})", explanation=f"Dynamo was unable to broad cast the arguments {left}, {right} " @@ -3027,7 +3027,7 @@ def _comparison_with_symnode( op = self.fn if op not in supported_tensor_comparison_op_values: - unimplemented_v2( + unimplemented( gb_type="unsupported SymNode comparison op", context=f"{op.__name__}({left}, {right})", explanation=f"Dynamo does not support the comparison op {op.__name__} " diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 73d53cbca402b..1e886c6ee7ad7 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -14,7 +14,7 @@ from torch._dynamo.source import AttrSource, GetItemSource from .. 
import graph_break_hints, variables -from ..exc import raise_observed_exception, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented from ..utils import ( cmp_name_to_op_mapping, common_constant_types, @@ -292,7 +292,7 @@ def create( for member in list(cls_type): if member.value == value_vt.as_python_constant(): return cls(member, **options) - unimplemented_v2( + unimplemented( gb_type="Failed to construct Enum variable", context=f"value: {value_vt}, allowed enum values: {list(cls_type)}", explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) " diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index b019296d98fcd..59c4bd99e25b8 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -34,7 +34,7 @@ create_instruction, create_setup_with, ) -from ..exc import unimplemented_v2 +from ..exc import unimplemented from ..guards import GuardBuilder, install_guard from ..source import AttrSource, GlobalStateSource from ..utils import _get_error_on_graph_break, _set_error_on_graph_break @@ -1089,7 +1089,7 @@ def fn_name(self) -> str: return "nullcontext" def reconstruct(self, cg: "PyCodegen") -> None: - unimplemented_v2( + unimplemented( gb_type="torch.profiler object escaped from compiled region", context=str(self), explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", @@ -1161,7 +1161,7 @@ def exit( ).call_function(tx, [self.tensors, self.prev_versions], {}) def reconstruct(self, codegen: "PyCodegen") -> None: - unimplemented_v2( + unimplemented( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), explanation=( @@ -1376,7 +1376,7 @@ def fn_name(self) -> str: return "annotate" def reconstruct_type(self, codegen: "PyCodegen") -> None: - unimplemented_v2( + unimplemented( gb_type="torch.fx.traceback.annotate escaped from compiled region", context=str(self), explanation="Dynamo doesn't support graph break on torch.fx.traceback.annotate.", @@ -1467,7 +1467,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}" except NotImplementedError: type_str = str(type(self.ctx)) - unimplemented_v2( + unimplemented( gb_type="Attempted to reconstruct context manager's __enter__ method", context=str(self.ctx), explanation=f"Attempted to reconstruct context manager {type_str} while tracing `with ...:`", diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 1c3a7011d4cfc..24cd5007da37d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -30,7 +30,7 @@ from .. 
import graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function, create_instruction -from ..exc import raise_observed_exception, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..source import is_constant_source, is_from_local_source from ..utils import ( @@ -377,7 +377,7 @@ def getitem_const( key = ConstDictVariable._HashableTracker(arg) if key not in self.items: msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] - unimplemented_v2( + unimplemented( gb_type="key not found in dict", context=f"Key {arg.value}", # type: ignore[attr-defined] explanation=msg, @@ -819,7 +819,7 @@ def call_obj_hasattr( return ConstantVariable.create(False) msg = f"hasattr on {self.user_cls} is not supported" - unimplemented_v2( + unimplemented( gb_type="unsupported hasattr operation", context=f"Class {self.user_cls}", explanation=msg, @@ -854,7 +854,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed " "because the connection to the original dict will be lost." ) - unimplemented_v2( + unimplemented( gb_type="mapping proxy cannot be reconstructed", context=f"Source: {self.source}", explanation=msg, @@ -892,7 +892,7 @@ def call_method( "are trying to access a proxy object." ) - unimplemented_v2( + unimplemented( gb_type="mapping proxy affected by dictionary mutation", context=f"Source: {self.source}, Dict mutation detected", explanation=msg, diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 187055c26cd00..f6faf4414d1da 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -29,7 +29,7 @@ from .. 
import compiled_autograd, variables from .._trace_wrapped_higher_order_op import trace_wrapped from ..bytecode_transformation import create_call_function -from ..exc import unimplemented_v2 +from ..exc import unimplemented from ..external_utils import call_module_hooks_from_backward_state from ..guards import GuardBuilder, install_guard from ..source import AttrSource @@ -57,7 +57,7 @@ class DistributedVariable(VariableTracker): def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) if not DistributedVariable.is_available(): - unimplemented_v2( + unimplemented( gb_type="torch.distributed package is not available!", context="", explanation="The PyTorch package doesn't include torch.distributed when building from source.", @@ -212,7 +212,7 @@ def call_method( try: value_type = type(self.value) if inspect.getattr_static(value_type, "__getattr__", None) is not None: - unimplemented_v2( + unimplemented( gb_type="Placement with custom __getattr__ not supported", context=f"{value_type.__name__} with custom __getattr__", explanation="Dynamo does not support Placement types with custom __getattr__ methods", @@ -394,7 +394,7 @@ def create( user_pre_hooks: VariableTracker, ) -> "BackwardHookVariable": if not compiled_autograd.compiled_autograd_enabled: - unimplemented_v2( + unimplemented( gb_type="Module-level backwards hooks require compiled autograd.", context="", explanation="", diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 2f64c825a07fc..8411441724d3c 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -51,7 +51,7 @@ raise_observed_exception, SkipFrame, StepUnsupported, - unimplemented_v2, + unimplemented, Unsupported, ) from ..guards import GuardBuilder, install_guard @@ -422,7 +422,7 @@ def __init__( # TODO putting this here to avoid duplication, because we could hit this # from several paths (e.g., SuperVariable or `var_getattr`s). if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)): - unimplemented_v2( + unimplemented( gb_type="can't handle functions not implemented in python ", context=f"{fn}", explanation="Dynamo can only handle functions defined in python", @@ -583,7 +583,7 @@ def call_function( if not isinstance(fn_var, BaseUserFunctionVariable): typ = fn_var.python_type() msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>" - unimplemented_v2( + unimplemented( gb_type="TypeError from user code", context=f"call_function({self.value}, {args}, {kwargs})", # type: ignore[attr-defined] explanation=msg, @@ -595,7 +595,7 @@ def call_function( if not isinstance(fn_var, UserFunctionVariable): fn_name = fn_var.get_name() msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." 
# noqa: B950 - unimplemented_v2( + unimplemented( gb_type="Limitation of `nonstrict_trace", context=f"{self}", explanation=msg, @@ -1066,7 +1066,7 @@ def call_function( kwargs: dict[str, VariableTracker], ) -> VariableTracker: if not is_generator(self.vt.get_code()): # type: ignore[attr-defined] - unimplemented_v2( + unimplemented( gb_type="non-generator contextlib.contextmanager", context=str(self.vt.get_code()), # type: ignore[attr-defined] explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator" @@ -1617,7 +1617,7 @@ def call_function( ) -> VariableTracker: if inspect.getattr_static(self.value, "_torchdynamo_disable", False): msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) - unimplemented_v2( + unimplemented( gb_type="Skip calling `torch.compiler.disable()`d function", context=str(self.value), explanation=f"Skip calling function `{self.value}` since it was wrapped " @@ -1630,7 +1630,7 @@ def call_function( graph_break_msg = kwargs.get("msg") if graph_break_msg: graph_break_msg = graph_break_msg.as_python_constant() - unimplemented_v2( + unimplemented( gb_type="Call to `torch._dynamo.graph_break()`", context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`", explanation=f"User-inserted graph break. Message: {graph_break_msg}", @@ -1724,7 +1724,7 @@ def call_function( ) hints = [] reason = self.reason if self.reason else "" - unimplemented_v2( + unimplemented( gb_type="Attempted to call function marked as skipped", context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}", explanation=explanation, @@ -1950,7 +1950,7 @@ def call_function( args = () if "async_op" in kwargs and kwargs["async_op"].as_python_constant(): - unimplemented_v2( + unimplemented( gb_type="async_op=True for distributed collectives", context=f"{self.fn}, {args=}, {kwargs=}", explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}", @@ -1990,7 +1990,7 @@ def call_function( def wraps(fn: Any) -> VariableTracker: if isinstance(fn, variables.NestedUserFunctionVariable): return fn.clone(wrapped_fn=args[0]) - unimplemented_v2( + unimplemented( gb_type="functools.wraps", context=f"{fn}", explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region", @@ -2032,7 +2032,7 @@ def call_function( value, mutation_type=ValueMutationNew(), ) - unimplemented_v2( + unimplemented( gb_type="namedtuple construction", context=f"{args=}, {kwargs=}", explanation="`torch.compile` only support certain input types for namedtuple", @@ -2338,7 +2338,7 @@ def check_grid(self, grid: "BaseListVariable") -> tuple[torch.fx.proxy.Proxy, .. 
if isinstance(grid, BaseListVariable): return grid.as_proxy() else: - unimplemented_v2( + unimplemented( gb_type="unsupported grid type for triton hop check_grid", context=f"grid type = {type(grid)}", explanation="`torch.compile` only supports list-like grid for check_grid", diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index dc7a7d13908a8..b713b02c4e41a 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -50,7 +50,7 @@ from ..exc import ( ObservedException, UncapturedHigherOrderOpError, - unimplemented_v2, + unimplemented, Unsupported, ) from ..source import AttrSource, DictGetItemSource @@ -161,7 +161,7 @@ def _unwrap_var(var): elif isinstance(var, ConstantVariable): return var.as_python_constant() else: - unimplemented_v2( + unimplemented( gb_type="cannot unwrap variable for check_meta_consistency", context=str(var), explanation=f"Expected {var} to be TensorVariable, SymNodeVariable, or ConstantVariable", @@ -313,7 +313,7 @@ def _check_all_tensorvariable(args): from . import TensorVariable if not all(type(a.realize()) is TensorVariable for a in args): - unimplemented_v2( + unimplemented( gb_type="HOP: non torch.Tensor leaf", context=f"args types: {[type(a.realize()) for a in args]}", explanation="Expected all leaves to be of torch.Tensor type.", @@ -328,7 +328,7 @@ def _check_supported_callable_arg( BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant() ) if not is_callable: - unimplemented_v2( + unimplemented( gb_type="HOP: non-callable variable", context=f"arg name: {arg_name}, func_var type: {str(func_var)}", explanation=f"{arg_name} should be a callable but is of type {str(func_var)}.", @@ -359,7 +359,7 @@ def _call_while_loop( args.append(v) if kwargs or len(args) != 4: - unimplemented_v2( + unimplemented( gb_type="torch.while_loop: improper args/kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"torch.while_loop expects 4 positional arguments (got {len(args)}) " @@ -379,7 +379,7 @@ def _call_while_loop( # additional_inputs input check if not isinstance(additional_inputs, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.while_loop: improper additional_inputs", context=str(additional_inputs), explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", @@ -484,7 +484,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: cond_r.proxy.node.meta["example_value"], include_contiguity=False ) if cond_r_meta.dtype != torch.bool or cond_r_meta.shape != torch.Size([]): - unimplemented_v2( + unimplemented( gb_type="torch.while_loop: unsupported cond_fn return type", context=str(cond_r), explanation=f"Expected cond_fn to return a scalar tensor or a bool but got {cond_r_meta.shape}.", @@ -496,7 +496,7 @@ def unspecialize_carried_inputs(tx, carry) -> VariableTracker: # short-circuiting while_loop when cond_fn returns a constant such as 0, 1 True or False pred = cond_r.as_python_constant() if pred: - unimplemented_v2( + unimplemented( gb_type="torch.while_loop: infinite loop detected", context=str(cond_r), explanation=f"Infinite loop detected because while_loop's cond_fn always returns the same value {pred}.", @@ -811,7 +811,7 @@ def validate_args_and_maybe_create_graph_inputs( # If `a` cannot be put into a graph else: # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic"). 
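A minimal sketch, not part of this patch, of a while_loop call that satisfies the checks rewritten above (four HOP-level arguments built by the wrapper, list/tuple additional inputs, and a cond_fn that evaluates to a scalar bool tensor). The import path and the three-argument Python wrapper are assumptions based on the torch._higher_order_ops machinery these hunks reference; adjust to whatever wrapper your build exposes.

import torch
from torch._higher_order_ops.while_loop import while_loop  # assumed wrapper location

def cond_fn(i, x):
    # Must evaluate to a scalar bool (tensor), per the cond_fn check above.
    return i < 3

def body_fn(i, x):
    # Carries keep the same dtype/shape across iterations.
    return i + 1, x * 2.0

@torch.compile(fullgraph=True)
def run(x):
    i0 = torch.zeros((), dtype=torch.int64)
    return while_loop(cond_fn, body_fn, (i0, x))

# run(torch.randn(4)) returns (tensor(3), x * 8.0)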
- unimplemented_v2( + unimplemented( gb_type="HOP body taking non-Tensor as input", context=str(sub_args), explanation=f"{description} with body that accepts non-Tensors as input. " @@ -974,7 +974,7 @@ def speculate_subgraph( # See NOTE [Temporary argument `set_subgraph_inputs`] if sub_kwargs and set_subgraph_inputs != "automatic": - unimplemented_v2( + unimplemented( gb_type="invalid set_subgraph_inputs and sub_kwargs settings", context=f"set_subgraph_inputs: {set_subgraph_inputs}, sub_kwargs: {sub_kwargs}", explanation="`sub_kwargs` cannot be used when `set_subgraph_inputs` is not set to 'automatic'.", @@ -1190,7 +1190,7 @@ def move_lifted_freevars_phs_to_end( mutation_info = subtracer.has_input_mutation() if mutation_info.has_mutation: context = f"{mutation_info.msg} in\n {graph}" - unimplemented_v2( + unimplemented( gb_type="Encountered input mutation during higher order op tracing", context=context, explanation=f"Higher order ops do not support input mutation. Found in {source_target.name()}", @@ -1204,7 +1204,7 @@ def move_lifted_freevars_phs_to_end( aliasing_info = subtracer.has_aliasing() if aliasing_info.has_aliasing: context = f"{aliasing_info.msg} in\n {graph}" - unimplemented_v2( + unimplemented( gb_type="Encountered aliasing during higher order op tracing", context=context, explanation=f"Higher order ops do not support aliasing. Found in {source_target.name()}", @@ -1269,7 +1269,7 @@ def make(value, source=None, **kwargs): if isinstance(value, BaseHOP): return BaseHOPVariable(value, source, **kwargs) - unimplemented_v2( + unimplemented( gb_type="unsupported HigherOrderOperator", context=str(value), explanation=f"Unable to create higher order operator variable for {value.__name__}.", @@ -1297,7 +1297,7 @@ def _call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - unimplemented_v2( + unimplemented( gb_type="unsupported HigherOrderOperator function call", context=str(self.value), explanation=f"Unable to trace calling higher order operator variable for {self.value.__name__}.", @@ -1357,7 +1357,7 @@ def _call_function( # TODO(voz): Support fake tensor dispatch for recursive # ops - see torch/dispatch/_dispatcher.py if len(args) != 4 or kwargs: - unimplemented_v2( + unimplemented( gb_type="torch.cond: improper args/kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) " @@ -1383,7 +1383,7 @@ def _call_function( # predicate if type(pred) not in (ConstantVariable, TensorVariable, SymNodeVariable): - unimplemented_v2( + unimplemented( gb_type="torch.cond: improper predicate", context=str(pred), explanation="Expected `pred` to be a bool or a boolean tensor with a single item " @@ -1395,7 +1395,7 @@ def _call_function( # operands if not isinstance(operands, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.cond: improper operands", context=str(operands), explanation="Expected `operands` to be a list/tuple " @@ -1409,7 +1409,7 @@ def _call_function( if not only_consist_of( operands, (TensorVariable, ConstantVariable, SymNodeVariable) ): - unimplemented_v2( + unimplemented( gb_type="torch.cond: improper operands contents", context=str(operands), explanation="Expected `operands` to be a list/tuple of pytrees that only consists of tensor leaves.", @@ -1463,7 +1463,7 @@ def speculate_branch(branch): tx.fake_mode.epoch += 1 if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)): - unimplemented_v2( + unimplemented( 
gb_type="torch.cond: unsupported branch return type", context=str(ret_val), explanation="Expected branches to return a possibly nested pytree of tensors or constant ints.", @@ -1473,7 +1473,7 @@ def speculate_branch(branch): ) for ret in ret_val.unpack_var_sequence(tx): if isinstance(ret, ConstantVariable) and ret.python_type() is not int: - unimplemented_v2( + unimplemented( gb_type="torch.cond: unsupported branch return type (constant non-int)", context=str(ret_val), explanation="Constants returned from branches must be ints.", @@ -1499,7 +1499,7 @@ def speculate_branch(branch): ).as_python_constant() # 3.14: NotImplemented cannot be converted to bool if same_spec is not NotImplemented and not same_spec: - unimplemented_v2( + unimplemented( gb_type="torch.cond: differing branch outputs", context=f"true_spec: {true_spec.treespec}, false_spec: {false_spec.treespec}, same_spec: {same_spec}", explanation="Expected branches to return the same pytree structure.", @@ -1602,7 +1602,7 @@ def validate_subgraph_output_types(output: VariableTracker): isinstance(out, ConstantVariable) and out.python_type() in (int, bool) ): continue - unimplemented_v2( + unimplemented( gb_type="HOP body output unsupported", context=f"non-tensor outputs: {non_tensor_output}", explanation="HigherOrderOperator body's output must consist of tensors or ints/bools only " @@ -1671,7 +1671,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): # This is the standard case when the user calls the frontend # and the frontend invokes dynamo if len(args) != 2: - unimplemented_v2( + unimplemented( gb_type="torch.associative_scan: improper args", context=f"args: {args}", explanation=f"torch.associative_scan expects 2 positional arguments (got {len(args)}) " @@ -1697,7 +1697,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): # xs input check if not isinstance(xs, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.associative_scan: improper xs", context=str(xs), explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", @@ -1710,7 +1710,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): # additional_inputs input check if not isinstance(additional_inputs, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.associative_scan: improper additional_inputs", context=str(additional_inputs), explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", @@ -1723,7 +1723,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] if scan_length == 0: - unimplemented_v2( + unimplemented( gb_type="torch.associative_scan: zero-sized tensor", context=str(xs_vars[0]), explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", @@ -1776,7 +1776,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): # Check whether the combine_fn returns one child tree for the output. 
if _combine_treespec.as_python_constant().num_leaves < 1: - unimplemented_v2( + unimplemented( gb_type="torch.associative_scan: combine_fn improper number of leaves", context=str(_combine_treespec.as_python_constant()), explanation="combine_fn needs to produce one pytree for the output " @@ -1795,7 +1795,7 @@ def arg_extractor(combine_fn, xs, additional_inputs): ) or not _make_inlined(tx, pytree.TreeSpec.__eq__)( xs_treespec, _combine_treespec ).as_python_constant(): - unimplemented_v2( + unimplemented( gb_type="torch.associative_scan: mismatched input/output tree structure", context=f"xs: {xs_treespec.as_python_constant()}, output: {_combine_treespec.as_python_constant()}", explanation="The tree structure of the xs and the outs of the combine_fn are are expected to be identical, but got " @@ -1907,7 +1907,7 @@ def _check_combine_fn_is_normalized(combine_fn_var): variables.FunctoolsPartialVariable, ), ): - unimplemented_v2( + unimplemented( gb_type="torch.scan: improper combine_fn", context=str(combine_fn_var), explanation="Expected combine_fn to be wrapped as functools.partial in scan user-facing api " @@ -1948,7 +1948,7 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) # xs input check if not isinstance(xs, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.scan: improper xs", context=str(xs), explanation=f"Expected xs to be a list/tuple but got {xs.python_type()}", @@ -1958,7 +1958,7 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) # init input check if not isinstance(init, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.scan: improper init", context=str(init), explanation=f"Expected init to be a list/tuple with at least one element but got {init.python_type()}", @@ -1968,7 +1968,7 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) if len(init_vars) == 0: - unimplemented_v2( + unimplemented( gb_type="torch.scan: no init leaves", context="", explanation="Expected init leaves.", @@ -1979,7 +1979,7 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): # additional_inputs input check if not isinstance(additional_inputs, (ListVariable, TupleVariable)): - unimplemented_v2( + unimplemented( gb_type="torch.scan: improper additional_inputs", context=str(additional_inputs), explanation=f"Expected additional_inputs to be a list/tuple but got {additional_inputs.python_type()}", @@ -1990,7 +1990,7 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): # scan_length check scan_length = get_fake_value(xs_vars[0].as_proxy().node, tx).size()[0] if scan_length == 0: - unimplemented_v2( + unimplemented( gb_type="torch.scan: zero-sized tensor", context=str(xs_vars[0]), explanation="associative_scan() operator doesn't support zero-sized tensors during tracing.", @@ -2047,7 +2047,7 @@ def arg_extractor(combine_fn, init, xs, additional_inputs): ) else: if len(combine_result_vars) != 2: - unimplemented_v2( + unimplemented( gb_type="torch.scan: improper combine_fn number of returns", context=str(combine_result_vars), explanation=f"Expect combine_fn to return a tuple (next_carry, y) but got {combine_result_vars}.", @@ -2143,7 +2143,7 @@ def _call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) if len(kwargs) > 0: - unimplemented_v2( + unimplemented( gb_type="torch.map: kwargs not supported", context=f"args: {args}, kwargs: {kwargs}", explanation=f"torch.map expects no keyword arguments (got {len(kwargs)})", @@ -2163,7 +2163,7 @@ def _call_function( 
sample_shape = get_fake_value(unpacked_xs[0].as_proxy().node, tx).size() if len(sample_shape) < 1 or sample_shape[0] == 0: - unimplemented_v2( + unimplemented( gb_type="torch.map: improper inputs", context=str(sample_shape), explanation="torch.map doesn't support scalar or non-zero sized tensors during tracing.", @@ -2257,7 +2257,7 @@ def _call_function( # executorch_call_delegate sits at a higher level than dynamo, but # there's no real solution to this issue yet. if len(kwargs) > 0: - unimplemented_v2( + unimplemented( gb_type="executorch_call_delegate: kwargs not supported", context=f"args: {args}, kwargs: {kwargs}", explanation=f"executorch_call_delegate expects no keyword arguments (got {len(kwargs)})", @@ -2317,7 +2317,7 @@ def call_function( self, tx, args: list[VariableTracker], kwargs: dict[str, VariableTracker] ) -> VariableTracker: if not torch._dynamo.config.inline_inbuilt_nn_modules: - unimplemented_v2( + unimplemented( gb_type="torch.func.functional_call capture is disabled", context="", explanation="torch.func.functional_call capture is disabled", @@ -2427,7 +2427,7 @@ def _call_function( ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap") if len(p_kwargs) > 0: - unimplemented_v2( + unimplemented( gb_type="WrapHigherOrderVariable: kwargs unexpected", context=f"args: {args}, kwargs: {kwargs}", explanation="kwargs should have been flattened into lifted args.", @@ -2468,7 +2468,7 @@ def call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) if kwargs: - unimplemented_v2( + unimplemented( gb_type="wrap_with_set_grad_enabled: unexpected kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"wrap_with_set_grad_enabled expects no keyword arguments (got {len(kwargs)}).", @@ -2480,7 +2480,7 @@ def call_function( grad_enabled, fn_var, *rest_args = args if not isinstance(grad_enabled, ConstantVariable): - unimplemented_v2( + unimplemented( gb_type="wrap_with_set_grad_enabled: non-constant grad_enabled", context=str(grad_enabled), explanation="wrap_with_set_grad_enabled expects grad_enabled argument to be a constant.", @@ -2508,7 +2508,7 @@ def call_function( ) if len(body_lifted_freevars) > 0: - unimplemented_v2( + unimplemented( gb_type="wrap_with_set_grad_enabled: unexpected freevars", context=str(body_lifted_freevars), explanation="wrap_with_set_grad_enabled expects no freevars.", @@ -2555,7 +2555,7 @@ def call_function( args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) if kwargs: - unimplemented_v2( + unimplemented( gb_type="wrap_with_autocast: unexpected kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"wrap_with_autocast expects no keyword arguments (got {len(kwargs)}).", @@ -2568,7 +2568,7 @@ def call_function( for arg in [device_type, dtype, enabled, cache_enabled]: if not isinstance(arg, ConstantVariable): - unimplemented_v2( + unimplemented( gb_type="wrap_with_autocast: expected constant arg", context=str(args), explanation="wrap_with_autocast expects device_type, dtype, enabled, " @@ -2602,7 +2602,7 @@ def call_function( ) if len(body_lifted_freevars) > 0: - unimplemented_v2( + unimplemented( gb_type="wrap_with_autocast: unexpected freevars", context=str(body_lifted_freevars), explanation="wrap_with_autocast expects no freevars.", @@ -2652,7 +2652,7 @@ def _call_function( or len(kwargs) != 1 or "hints" not in kwargs ): - unimplemented_v2( + unimplemented( gb_type="hints_wrapper: improper args/kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"hints_wrapper expects 3 positional 
arguments (got {len(args)}) " @@ -2718,7 +2718,7 @@ def _call_function( from .builder import wrap_fx_proxy if len(kwargs) > 0: - unimplemented_v2( + unimplemented( gb_type="out_dtype: unexpected kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"out_dtype expects no keyword arguments (got {len(kwargs)}).", @@ -2764,7 +2764,7 @@ def _call_function( # TODO (tmanlaibaatar) support pytree here for arg in unpacked_sequence: if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): - unimplemented_v2( + unimplemented( gb_type="strict_mode: improper args", context=f"args: {args}, kwargs: {kwargs}", explanation="strict_mode higher order op expects flat inputs (list/tuple/dict)", @@ -2774,7 +2774,7 @@ def _call_function( ) if kwargs: - unimplemented_v2( + unimplemented( gb_type="strict_mode: unexpected kwargs", context=f"args: {args}, kwargs: {kwargs}", explanation=f"strict_mode higher order op expects no keyword arguments (got {len(kwargs)}).", @@ -3301,7 +3301,7 @@ def bwd(ctx, grad, x): ) fwd_args = [fwd_fn.obj, ctx, *args] else: - unimplemented_v2( + unimplemented( gb_type="autograd.Function.apply: non-function or method forward", context=str(self.fwd_graph), explanation="Expected forward function to be a function or method.", @@ -3326,7 +3326,7 @@ def bwd(ctx, grad, x): "_materialize_non_diff_grads" in tx.output.side_effects.store_attr_mutations[ctx] ): - unimplemented_v2( + unimplemented( gb_type="autograd.Function.apply: _materialize_non_diff_grads mutation", context="", explanation="Mutations to autograd.Function.ctx._materialize_non_diff_grads are not supported.", @@ -3361,7 +3361,7 @@ def bwd(ctx, grad, x): ) bwd_args = [bwd_fn.obj, *bwd_args] else: - unimplemented_v2( + unimplemented( gb_type="autograd.Function.apply: non-function or method backward", context=str(self.bwd_graph), explanation="Expected backward function to be a function or method.", @@ -3417,7 +3417,7 @@ def is_strict_for(v: VariableTracker): UserDefinedClassVariable(self.bwd_graph.__class__), ) else: - unimplemented_v2( + unimplemented( gb_type="autograd.Function.apply: non-function or method backward (2)", context=str(self.bwd_graph), explanation="Expected backward function to be a function or method.", @@ -3708,7 +3708,7 @@ def install_subgraph_in_output_graph( # using the saved attr name. 
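The autograd.Function.apply hunks above expect forward and backward to be plain functions or methods and reject mutation of ctx._materialize_non_diff_grads. A small sketch, not from this patch, of a subclass that stays within those constraints; it uses only the standard torch.autograd.Function API, and the class itself is illustrative.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stick to save_for_backward-style state on ctx.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2.0 * x * grad_out

@torch.compile
def f(x):
    return Square.apply(x)

# f(torch.randn(3, requires_grad=True))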
if not isinstance(fn_vt, (UnspecializedNNModuleVariable, UserFunctionVariable)): - unimplemented_v2( + unimplemented( gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", context=str(fn_vt), explanation="invoke_subgraph does not support non user function variable", @@ -3780,7 +3780,7 @@ def _call_function( ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "invoke_subgraph") if len(p_kwargs) > 0: - unimplemented_v2( + unimplemented( gb_type="invoke_subgraph: kwargs unexpected", context=f"args: {args}, kwargs: {kwargs}", explanation="kwargs should have been flattened into lifted args.", diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 624844382d53a..162ec02a9a9b7 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -28,7 +28,7 @@ handle_observed_exception, ObservedUserStopIteration, raise_observed_exception, - unimplemented_v2, + unimplemented, UserError, ) from .base import ValueMutationNew, VariableTracker @@ -64,7 +64,7 @@ def call_function( if self.value is itertools.product: if any(kw != "repeat" for kw in kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported kwargs for itertools.product", context=f"call_function {self} {args} {kwargs}", explanation=f"Expected kwargs: 'repeat', but got " @@ -104,7 +104,7 @@ def call_function( ) elif self.value is itertools.groupby: if any(kw != "key" for kw in kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported kwargs for itertools.groupby", context=f"call_function {self} {args} {kwargs}", explanation=f"Expected kwargs: 'key', but got " @@ -118,7 +118,7 @@ def retrieve_const_key(key: VariableTracker) -> Any: elif isinstance(key, variables.ConstantVariable): return key.as_python_constant() else: - unimplemented_v2( + unimplemented( gb_type="Unsupported key type for itertools.groupby", context=f"call_function {self} {args} {kwargs}", explanation="Dynamo does not know how to trace " @@ -130,7 +130,7 @@ def retrieve_const_key(key: VariableTracker) -> Any: if len(args) == 1 and args[0].has_unpack_var_sequence(tx): seq = args[0].unpack_var_sequence(tx) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported arguments for itertools.groupby", context=f"call_function {self} {args} {kwargs}", explanation="Dynamo does not know how to trace " @@ -175,7 +175,7 @@ def keyfunc(x: VariableTracker) -> Any: ) ) except Exception as e: - unimplemented_v2( + unimplemented( gb_type="Unexpected failure during itertools.groupby() iteration", context=f"call_function {self} {args} {kwargs}", explanation="Unexpected failure in invoking function during groupby", @@ -227,7 +227,7 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: - unimplemented_v2( + unimplemented( gb_type="Unimplemented next() call", context=f"next({self})", explanation="This abstract method must be implemented", diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index d6c005bccbda3..3c525312198c8 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -30,7 +30,7 @@ class that handles its unique behaviors while integrating with Dynamo's create_instruction, create_rot_n, ) -from ..exc import raise_observed_exception, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented from ..source import AttrSource, NamedTupleFieldsSource from ..utils import ( cmp_name_to_op_mapping, @@ -162,7 +162,7 @@ def 
call_method( if value.constant is not None and value.constant.numel() == 1: value = variables.ConstantVariable.create(value.constant.item()) else: - unimplemented_v2( + unimplemented( gb_type="Indexing list with non-scalar tensor", context=f"call_method {self} {name} {args} {kwargs}", explanation=( @@ -878,7 +878,7 @@ def call_method( except NotImplementedError: python_type = "unknown" - unimplemented_v2( + unimplemented( gb_type="sort with non-constant keys", context=str(first_non_constant_key), explanation=( @@ -1607,7 +1607,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return variables.GetAttrVariable(self, name) fields = ["start", "stop", "step"] if name not in fields: - unimplemented_v2( + unimplemented( gb_type="Unsupported attribute for slice() object", context=f"var_getattr {self} {name}", explanation=f"Expected attribute to be one of {','.join(fields)} " diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 7942e2fbd7bfa..099498dcf14f9 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -39,7 +39,7 @@ create_instruction, ) from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import raise_observed_exception, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..mutation_guard import unpatched_nn_module_init from ..source import ( @@ -108,7 +108,7 @@ def reconstruct(self, codegen: "PyCodegen"): def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): if not self.objvar: - unimplemented_v2( + unimplemented( gb_type="1-arg super not implemented", context="", explanation=f"Dynamo failed to trace attribute `{name}` accessed " @@ -159,7 +159,7 @@ def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name): ) return resolved_getattr, source - unimplemented_v2( + unimplemented( gb_type="Unable to resolve super getattr", context="", explanation=f"Dynamo failed to trace attribute `{name}` accessed " @@ -220,7 +220,7 @@ def call_method( ) return fn_vt.call_function(tx, [self.objvar] + args, kwargs) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported super().__init__() call", context=f"call_method {self} {name} {args} {kwargs}", explanation="Dynamo encountered a super().__init__() call " @@ -290,7 +290,7 @@ def call_method( try: attr = attr.as_python_constant() except NotImplementedError as exc: - unimplemented_v2( + unimplemented( gb_type="Non-constant attribute given to `super().__delattr__()`", context=f"call_method {self} {name}", explanation="Dynamo requires the attribute name passed to " @@ -301,7 +301,7 @@ def call_method( from_exc=exc, ) if not tx.output.side_effects.is_attribute_mutation(self.objvar): - unimplemented_v2( + unimplemented( gb_type="Attempted super().__delattr__() on an object without mutation tracking", context=f"call_method {self} {name}", explanation="Dynamo needs to track mutations on an object " @@ -392,7 +392,7 @@ def call_method( fn_var = VariableTracker.build(tx, inner_fn, source) return fn_var.call_function(tx, [self.objvar] + args, kwargs) - unimplemented_v2( + unimplemented( gb_type="Attempted to call a super() attribute that is " "not a function or method", context=f"call_method {self} {name}", @@ -414,7 +414,7 @@ def __init__( self.exc_type = exc_type self.args = args if init_kwargs: - unimplemented_v2( + unimplemented( gb_type="Keyword args passed to exception constructor", context=f"{self} with kwargs 
{init_kwargs}", explanation="Dynamo does not know how to handle keyword args passed to an exception constructor", @@ -495,7 +495,7 @@ def raise_error(msg): if isinstance(val, ConstantVariable) and val.value is None: self.__traceback__ = val else: - unimplemented_v2( + unimplemented( gb_type="Set Exception object `__traceback__` attribute to not-`None`", context=f"call_setattr {self} {name}", explanation="Dynamo does not support setting the attribute " @@ -507,7 +507,7 @@ def raise_error(msg): ], ) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported attribute assignment on Exception object", context=f"call_setattr {self} {name}", explanation="Dynamo does not support setting the attribute " @@ -567,7 +567,7 @@ def call_function( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - unimplemented_v2( + unimplemented( gb_type="Unsupported function call (delayed)", context=f"source: {self.source}", explanation="Dynamo determined that a graph break should occur " @@ -722,7 +722,7 @@ def visit(vt): vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined] if vjp_fn is not torch.autograd.Function.vjp: - unimplemented_v2( + unimplemented( gb_type="Unsupported custom vjp", context=f"call_apply {self} {args} {kwargs}", explanation="Dynamo does not support tracing " @@ -737,7 +737,7 @@ def visit(vt): jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined] if jvp_fn is not torch.autograd.Function.jvp: - unimplemented_v2( + unimplemented( gb_type="Unsupported custom jvp", context=f"call_apply {self} {args} {kwargs}", explanation="Dynamo does not support tracing " @@ -798,7 +798,7 @@ def visit(vt): source=source, ).call_function(tx, args, kwargs) else: - unimplemented_v2( + unimplemented( gb_type="Non-function or method in subclass of torch.autograd.Function", context=f"call_apply {self} {args} {kwargs}", explanation="Dynamo requires the `forward` attribute of a " @@ -873,7 +873,7 @@ def call_method( obj.__func__, self, source=source ).call_function(tx, args, kwargs) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported autograd.Function method", context=f"call_method {self} {name}", explanation="Dynamo does not support calling the method " @@ -943,7 +943,7 @@ def create(tx: "InstructionTranslator", args=None, kwargs=None): def as_proxy(self): if self.proxy is None: - unimplemented_v2( + unimplemented( gb_type="proxy not set", context=f"as_proxy {self}", explanation="Dynamo requires the autograd.Function context " @@ -968,7 +968,7 @@ def call_method( return variables.ConstantVariable.create(None) if name != "save_for_backward": - unimplemented_v2( + unimplemented( gb_type="Unsupported autograd.Function context method", context=f"call_method {self} {name}", explanation="Dynamo does not support calling the method " @@ -978,7 +978,7 @@ def call_method( hints=[*graph_break_hints.SUPPORTABLE], ) if self.saved_tensors is None: - unimplemented_v2( + unimplemented( gb_type="Unsupported autograd.Function context `save_for_backward`", context=f"call_method {self} {name}", explanation="Dynamo requires the `saved_tensors` attribute " @@ -1057,7 +1057,7 @@ def call_method( kwargs, ) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()", context=f"call_method {self} {name}", explanation="queue_callback() is only supported when " @@ -1065,7 +1065,7 @@ def call_method( hints=[], ) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported torch._C._ImperativeEngine method", context=f"call_method {self} 
{name}", explanation="Dynamo only supports the `queue_callback` method " @@ -1283,7 +1283,7 @@ def call_function( except AsPythonConstantNotImplementedError: pass - unimplemented_v2( + unimplemented( gb_type="unsupported type.__dict__['__annotations__'].__get__ call", context=f"call_function {self}, args: {args}, kwargs: {kwargs}", explanation="`torch.compile` only supports calling type.__dict__['__annotations__'].__get__ " @@ -1382,7 +1382,7 @@ def call_method( if name == "__getitem__" and len(args) == 1: new_typing = self.value[args[0].as_python_constant()] return TypingVariable(new_typing) - unimplemented_v2( + unimplemented( gb_type="unsupported method call on `typing` variable", context=f"typing variable: {self.value}, method name: {name}, args: {args}, kwargs: {kwargs}", explanation=f"`torch.compile` does not support method call `{name}` on `typing` variable f{self.value}.", @@ -1501,7 +1501,7 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if not config.trace_numpy: - unimplemented_v2( + unimplemented( gb_type="attempted to trace numpy function with config.trace_numpy=False", context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", explanation=f"Attempted to trace numpy function {self.value} " @@ -1516,7 +1516,7 @@ def call_function( func = get_np_to_tnp_map().get(self.value) if func is None: - unimplemented_v2( + unimplemented( gb_type="attempted to trace numpy function unsupported by PyTorch", context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", explanation=f"Can't find numpy numpy function {self.value} in torch._numpy.", @@ -1537,7 +1537,7 @@ def call_function( ) ) except AsPythonConstantNotImplementedError: - unimplemented_v2( + unimplemented( gb_type="numpy function that produces a const collection type encountered non-const arguments", context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", explanation=f"numpy function {self.value} that produces a const collection type " @@ -1552,7 +1552,7 @@ def call_function( func.__module__ == "torch._numpy.random" and config.use_numpy_random_stream ): - unimplemented_v2( + unimplemented( gb_type="attempted to trace torch._numpy.random function with config.use_numpy_random_stream=True", context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs} (corresponding torch function: {func})", explanation=f"Attempted to trace {self.value} when `torch._dynamo.config.use_numpy_random_stream` " @@ -1591,7 +1591,7 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - unimplemented_v2( + unimplemented( gb_type="attempted to trace numpy.* function as a method", context=f"numpy function: {self.value}, args: {args}, kwargs: {kwargs}", explanation="Tracing numpy.* functions as methods is not supported.", @@ -1623,7 +1623,7 @@ def __repr__(self) -> str: def reconstruct(self, codegen: "PyCodegen"): if sys.version_info < (3, 11): - unimplemented_v2( + unimplemented( gb_type="cannot reconstruct NullVariable in Python < 3.11", context="", explanation="Attempted to generate PUSH_NULL instruction in Python < 3.11; " @@ -1712,7 +1712,7 @@ def call_function(self, tx: "InstructionTranslator", args, kwargs): return if not self.can_reorder_logs(self.value, args, kwargs): - unimplemented_v2( + unimplemented( gb_type="attempted to reorder a debugging function that can't actually be reordered", context=f"fn: {self.value}, args: {args}, 
kwargs: {kwargs}", explanation="`torch.compile` can only reorder functions where the arguments " @@ -1771,7 +1771,7 @@ def call_method( function = getattr(method, "__func__", None) if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): return variables.ConstantVariable.create(None) - unimplemented_v2( + unimplemented( gb_type="logging.Logger method not supported for non-export cases", context=f"method: {self.value}.{name}, args: {args}, kwargs: {kwargs}", explanation="logging.Logger methods are not supported for non-export cases.", @@ -1814,7 +1814,7 @@ def call_method( cargs = [x.as_python_constant() for x in args] ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()} except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="constant-like method call with non-constant args", context=f"{self._error_prefix}.{name}(*{args}, **{kwargs})", explanation=f"Attempted to call {self._error_prefix}.{name} with non-constant args.", @@ -1830,7 +1830,7 @@ def call_method( if isinstance(result, re.Match): return ConstantRegexMatchVariable(result) - unimplemented_v2( + unimplemented( gb_type="constant-like method call with unsupported return type", context=f"{self._error_prefix}.{name}(*{args}, **{kwargs}) returned {result}", explanation=f"Attempted to call {self._error_prefix}.{name}, got unsupported return value {result}.", @@ -1901,7 +1901,7 @@ def __init__(self, **kwargs) -> None: def call_function(self, tx: "InstructionTranslator", args, kwargs): if len(args) > 1 or kwargs: - unimplemented_v2( + unimplemented( gb_type="random.Random() with improper arguments", context=f"args: {args}, kwargs: {kwargs}", explanation="random.Random() with > 1 arg or with kwargs is not supported.", diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index f6ba0b1a5ffbc..b58580cd61240 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -36,7 +36,7 @@ from .. 
import graph_break_hints, trace_rules, variables from ..exc import ( raise_observed_exception, - unimplemented_v2, + unimplemented, UnspecializeRestartAnalysis, Unsupported, ) @@ -263,7 +263,7 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): base = tx.output.get_submodule(self.module_key) if object_has_getattribute(base): - unimplemented_v2( + unimplemented( gb_type="Custom __getattribute__ in nn.Module dict key check", context=f"has_key_in_generic_dict {self} {key}", explanation="Dynamo does not support checking key existence " @@ -285,7 +285,7 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): def _custom_getattr_fallback(self, base, tx, name, obj_source): """Check for a __getattr__ and handle it specially if it is implemented""" if object_has_getattribute(base): - unimplemented_v2( + unimplemented( gb_type="Custom __getattribute__ in nn.Module attribute access", context=f"var_getattr {self} {name}", explanation="Dynamo does not support checking key existence " @@ -302,7 +302,7 @@ def _custom_getattr_fallback(self, base, tx, name, obj_source): return None if not isinstance(getattr_fn, types.FunctionType): - unimplemented_v2( + unimplemented( gb_type="torch.nn.Module with a non-function custom __getattr__", context=f"var_getattr {self} {name}", explanation=( @@ -336,7 +336,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): all_class_attribute_names.update(x.__dict__.keys()) if not self.source: - unimplemented_v2( + unimplemented( gb_type="getattr with no source", context=f"var_getattr {self} {name}", explanation="Dynamo does not know how to access an attribute " @@ -423,7 +423,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): # Support possibly common cases of class members return VariableTracker.build(tx, subobj, NNModuleSource(source)) else: - unimplemented_v2( + unimplemented( gb_type="Unsupported nn.Module attribute type", context=f"nn.Module subclass: {typestr(base)}, name: {name}, attribute type: {typestr(subobj)}", explanation=f"Dynamo does not support tracing nn.Module attributes of type `{typestr(subobj)}`", @@ -644,7 +644,7 @@ def assert_all_args_kwargs_const(): if not all( x.is_python_constant() for x in itertools.chain(args, kwargs.values()) ): - unimplemented_v2( + unimplemented( gb_type="non-const argument in nn.Module method", context=f"call_method: {self} {name} {args} {kwargs}", explanation="Dynamo does not support calling " @@ -830,7 +830,7 @@ def gen_source(source, name): isinstance(args[0], variables.ConstantVariable) and isinstance(args[0].as_python_constant(), (str, int)) ): - unimplemented_v2( + unimplemented( gb_type="Invalid or non-const argument in nn.Module __getitem__", context=f"call_method: {self} {name} {args} {kwargs}", explanation="Dynamo does not support calling " @@ -893,7 +893,7 @@ def gen_source(source, name): elif args[0].is_python_constant(): key = args[0].as_python_constant() else: - unimplemented_v2( + unimplemented( gb_type="Unsupported key type for nn.Module.__getitem__", context=f"call_method: {self} {name} {args} {kwargs}", explanation="Dynamo does not support getitem on " @@ -1136,7 +1136,7 @@ def call_method( hasattr(method, "__code__") and id(method.__code__) in self._nn_module_method_ids() ): - unimplemented_v2( + unimplemented( gb_type="UnspecializedNNModuleVariable missing method", context=f"call_method: {self} {name} {args} {kwargs}", explanation=f"Dynamo does not support tracing method {name} of nn.Module {self.value}", diff --git 
a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 644c269a23a34..1870d366fe83e 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -28,7 +28,7 @@ from torch.fx.proxy import Proxy from .. import graph_break_hints -from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported +from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported from .base import VariableTracker from .user_defined import UserDefinedObjectVariable @@ -87,7 +87,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker method = getattr(self.value, name, None) if method is None: - unimplemented_v2( + unimplemented( gb_type="FakeScriptObject missing method implementation", context=f"value={self.value}, method={name}", explanation=f"TorchScript object {self.value} doesn't define the method {name}.", @@ -98,7 +98,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ) if not callable(method): - unimplemented_v2( + unimplemented( gb_type="Attempted to access non-callable attribute of TorchScript object", context=f"value={self.value}, method={name}", explanation="Attribute accesses of TorchScript objects to non-callable attributes are not supported.", @@ -128,7 +128,7 @@ def call_method( args: Iterable[Any], kwargs: dict[str, Any], ) -> VariableTracker: - unimplemented_v2( + unimplemented( gb_type="Weird method call on TorchScript object", context=f"value={self.value}, method={name}", explanation=( diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 79a0d0eb9ba23..5c2d00ec01df2 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -9,7 +9,7 @@ from .. import graph_break_hints from ..bytecode_transformation import create_call_function -from ..exc import TYPE_CHECKING, unimplemented_v2 +from ..exc import TYPE_CHECKING, unimplemented from ..graph_bytecode_inputs import get_external_object_by_index from .base import VariableTracker from .constant import ConstantVariable @@ -389,7 +389,7 @@ def call_method( method_name = ( f"{type(self.value).__module__}.{type(self.value).__qualname__}.{name}" ) - unimplemented_v2( + unimplemented( gb_type="Unsupported event method", context=str(name), explanation=f"Dynamo doesn't support tracing the {method_name} method. " diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index d44f5171217d0..ca57d4e7e8783 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -46,7 +46,7 @@ from .. 
import config, graph_break_hints, variables from .._trace_wrapped_higher_order_op import trace_wrapped from ..exc import ( - unimplemented_v2, + unimplemented, UnknownPropertiesDuringBackwardTrace, UserError, UserErrorType, @@ -390,7 +390,7 @@ def method_attr_is_nested(self, tx): return ConstantVariable.create(self.is_nested) def method_attr_retain_grad(self, tx): - unimplemented_v2( + unimplemented( gb_type="Tensor.retain_grad() with AOTDispatcher", context=f"var_getattr {self} retain_grad", explanation="`Tensor.retain_grad()` does not work with AOTDispatcher.", @@ -404,7 +404,7 @@ def method_attr_data(self, tx): def method_attr_grad_fn(self, tx): if self.has_grad_fn: - unimplemented_v2( + unimplemented( gb_type="Tensor with grad_fn()", context=f"var_getattr {self} grad_fn", explanation="Dynamo does not support tracing tensors with a grad_fn directly.", @@ -451,7 +451,7 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): def var_getattr(self, tx: "InstructionTranslator", name): if self.is_strict_mode(tx): if name in self._strict_mode_banned_ops(): - unimplemented_v2( + unimplemented( gb_type="Strict mode banned op", context=f"var_getattr {self} {name}", explanation=f"Getattr invocation '{name}' in strict mode is not supported.", @@ -541,7 +541,7 @@ def try_generic_attr_handling(): def call_id(self, tx): if not self.source: - unimplemented_v2( + unimplemented( gb_type="Unsupported call_id() without source", context=f"call_id {self}", explanation="call_id() not supported for sourceless TensorVariable.", @@ -553,7 +553,7 @@ def call_id(self, tx): try: _input_associated_real_value = eval(self.source.name(), scope) except Exception as exc: - unimplemented_v2( + unimplemented( gb_type="Error getting associated real value", context=f"call_id {self}", explanation="Dynamo encountered an error while trying to " @@ -563,7 +563,7 @@ def call_id(self, tx): ) if _input_associated_real_value is None: - unimplemented_v2( + unimplemented( gb_type="call_id() without associated real value", context=f"call_id {self}", explanation="Dynamo could not find an associated real value for the tensor.", @@ -639,7 +639,7 @@ def call_method( from .torch_function import can_dispatch_torch_function, dispatch_torch_function if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): - unimplemented_v2( + unimplemented( gb_type="Illegal method invocation in strict mode", context=f"call_method {self} {name} {args} {kwargs}", explanation="Dynamo currently does not support this method " @@ -683,7 +683,7 @@ def call_method( # discussions in #151432 for more details. # We graph break for now since this use case is uncommon. 
if name == "random_": - unimplemented_v2( + unimplemented( gb_type="Tensor.random_ op", context=f"Tensor.{name}({args=}, {kwargs=})", explanation="This is currently not supported.", @@ -693,7 +693,7 @@ def call_method( ], ) elif name == "uniform_" and "from" in kwargs: - unimplemented_v2( + unimplemented( gb_type="Tensor.uniform_ op called with `from` keyword", context=f"Tensor.{name}({args=}, {kwargs=})", explanation="This is currently not supported.", @@ -713,7 +713,7 @@ def call_method( if result: return result except TypeError as e: - unimplemented_v2( + unimplemented( gb_type="Unhandled args for method", context=f"call_method {self} {name} {args} {kwargs}", explanation="Dynamo encountered an error while calling " @@ -804,7 +804,7 @@ def method_is_floating_point(self): def method_is_inference(self): if config.fake_tensor_disable_inference_mode: - unimplemented_v2( + unimplemented( gb_type="Encountered tensor.is_inference() during tracing", context="", explanation="tensor.is_inference() is not supported", @@ -890,7 +890,7 @@ def method_as_subclass(self, cls): object(), var, mutation_type_cls=AttributeMutationNew ) return var - unimplemented_v2( + unimplemented( gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", context=f"{self}.as_subclass({cls})", explanation="Currently not supported", @@ -910,7 +910,7 @@ def method_element_size(self): def method_numpy(self, *, force=False): if not config.trace_numpy: - unimplemented_v2( + unimplemented( gb_type="Tensor.numpy() with trace_numpy=False", context=f"call_method {self} numpy", explanation="`Tensor.numpy()` was called, but the `trace_numpy` " @@ -921,7 +921,7 @@ def method_numpy(self, *, force=False): ], ) if not np: - unimplemented_v2( + unimplemented( gb_type="Tensor.numpy() without NumPy installed", context=f"call_method {self} numpy", explanation="`Tensor.numpy()` was called, but the NumPy library " @@ -970,7 +970,7 @@ def wrap(i, sub_proxy): torch.int32, torch.int64, ]: - unimplemented_v2( + unimplemented( gb_type="Tensor.tolist() with non-integer tensor", context=f"call_method {self} to_list", explanation="Dynamo currently does not support tracing " @@ -997,7 +997,7 @@ def wrap(i, sub_proxy): return VariableTracker.build(tx, out) def method_backward(self, *args, **kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.backward() call", context=f"call_method {self} backward {args} {kwargs}", explanation="Dynamo currently does not support tracing `Tensor.backward()`.", @@ -1014,7 +1014,7 @@ def method_item(self, *args, **kwargs): # We enable capture_scalar_outputs when full_graph=True by default. 
if not tx.one_graph and not config.capture_scalar_outputs: self._warn_capture_scalar_outputs() - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False", context=f"call_method {self} item {args} {kwargs}", explanation="Dynamo does not support tracing `Tensor.item()` " @@ -1147,7 +1147,7 @@ def method___setitem__(self, key, value): return ConstantVariable.create(None) def method_resize_(self, *args, **kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.resize_() call", context=f"call_method {self} resize_ {args} {kwargs}", explanation="Dynamo currently does not support tracing `Tensor.resize_()`.", @@ -1155,7 +1155,7 @@ def method_resize_(self, *args, **kwargs): ) def method_resize_as_(self, *args, **kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.resize_as_() call", context=f"call_method {self} resize_as_ {args} {kwargs}", explanation="Dynamo currently does not support tracing `Tensor.resize_as_()`.", @@ -1163,7 +1163,7 @@ def method_resize_as_(self, *args, **kwargs): ) def method_sparse_resize_(self, *args, **kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.sparse_resize_() call", context=f"call_method {self} sparse_resize_ {args} {kwargs}", explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_()`.", @@ -1171,7 +1171,7 @@ def method_sparse_resize_(self, *args, **kwargs): ) def method_sparse_resize_and_clear_(self, *args, **kwargs): - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.sparse_resize_and_clear_() call", context=f"call_method {self} sparse_resize_and_clear_ {args} {kwargs}", explanation="Dynamo currently does not support tracing `Tensor.sparse_resize_and_clear_()`.", @@ -1186,7 +1186,7 @@ def method_set_(self, *args, **kwargs): # overload and is used by FSDP. # graph-breaking on aten::set_source_Tensor_storage_offset for now, # unless we find that we need to make it work. - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.set_() call", context=f"call_method {self} set_ {args} {kwargs}", explanation="Dynamo currently does not support tracing `Tensor.set_()` " @@ -1318,7 +1318,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker): # would have no recourse - their forward traces just fine, but will fail at backwards unless # compiled_autograd is enabled. If compiled_autograd fails (there are a lot of failures today) # then they have nothing they can do except disable compile. 
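The Tensor.item() hunk above only fires when capture_scalar_outputs is off, and the flag named in that gb_type is the usual way to keep .item() inside the graph. A short sketch; the compiled function is illustrative, while the config attribute is taken directly from the hunk.

import torch
import torch._dynamo

# Avoid the "Unsupported Tensor.item() call with capture_scalar_outputs=False" break above.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile
def f(x):
    # With the flag set, .item() is captured as an unbacked scalar rather than graph-breaking.
    return x.sum().item() + 1.0

# f(torch.randn(8))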
- unimplemented_v2( + unimplemented( gb_type="Compilation of intermediate hooks requires compiled autograd", context=f"var_getattr {self} {name}", explanation="Dynamo must be in compiled_autograd to register hooks.", @@ -1368,7 +1368,7 @@ def method_requires_grad_(self, requires_grad=True): requires_grad = requires_grad.as_python_constant() if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: - unimplemented_v2( + unimplemented( gb_type="Unsupported Tensor.requires_grad_() call", context=f"call_method {self} requires_grad_", explanation="Dynamo does not support changes to a Tensor's " @@ -1560,14 +1560,14 @@ def insert_into_graph(): return ConstantVariable.create(int(r)) return insert_into_graph() elif name in ["base", "flags", "dtype"]: - unimplemented_v2( + unimplemented( gb_type="Unsupported ndarray attribute access", context=f"var_getattr {self} {name}", explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", hints=[], ) elif name == "__version__": - unimplemented_v2( + unimplemented( gb_type="Unsupported ndarray.__version__ access", context=f"var_getattr {self} {name}", explanation=f"Dynamo currently does not support tracing `ndarray.{name}`.", @@ -1591,7 +1591,7 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": - from ..exc import unimplemented_v2 + from ..exc import unimplemented from ..utils import numpy_method_wrapper args, kwargs = self.patch_args(name, args, kwargs) @@ -1611,7 +1611,7 @@ def call_method( isinstance(dtype_arg, BuiltinVariable) and dtype_arg.fn is object ) if is_object_str or is_object_type: - unimplemented_v2( + unimplemented( gb_type="ndarray.astype(object)", context=f"call_method {self} {name} {args} {kwargs}", explanation=( @@ -1625,7 +1625,7 @@ def call_method( # delegate back to TensorVariable return super().call_method(tx, name, args, kwargs) if name in ("tostring", "tobytes", "__delattr__"): - unimplemented_v2( + unimplemented( gb_type="Unsupported ndarray method call", context=f"call_method {self} {name} {args} {kwargs}", explanation=f"`ndarray.{name}()` is not modelled in `torch._numpy`.", @@ -1713,7 +1713,7 @@ def call_function( tx, data, self.value, self.source ) else: - unimplemented_v2( + unimplemented( gb_type="Calling subclass default constructor with more than tensor argument", context=f"{self.value}(args={args}, kwargs={kwargs})", explanation="Currently not supported", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index a687d77c186db..895acfd56e80b 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -52,7 +52,7 @@ tracable_create_parameter, ) from ..device_interface import get_registered_device_interfaces -from ..exc import raise_observed_exception, unimplemented_v2 +from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, @@ -605,7 +605,7 @@ def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs): @register(torch.is_inference_mode_enabled) def handle_is_inference_mode_enabled(self, tx: "InstructionTranslator"): - unimplemented_v2( + unimplemented( gb_type="Encountered torch.is_inference_mode_enabled during tracing", context="", explanation="torch.is_inference_mode_enabled() is not supported", @@ -654,7 +654,7 @@ def handle_torch_compile(self, tx: "InstructionTranslator", *args, **kwargs): # torch.compile is a no-op in dynamo return args[0] - unimplemented_v2( + 
unimplemented( gb_type="torch.compile call with > 1 args", context=f"args={args}, kwargs={kwargs}", explanation="Attempted to call `torch.compile` with > 1 args. Dynamo does not support this.", @@ -690,7 +690,7 @@ def handle_use_deterministic_algorithms( ): # pyrefly: ignore [missing-attribute] if warn_only and warn_only.as_python_constant(): - unimplemented_v2( + unimplemented( gb_type="Attempted to use torch.use_deterministic_algorithms(warn_only=True)", context=f"mode={mode}, warn_only={warn_only}", explanation="Dynamo does not support this.", @@ -749,7 +749,7 @@ def handle_device_interface_stream(self, tx: "InstructionTranslator", stream): @register(torch.from_numpy) def handle_from_numpy(self, tx: "InstructionTranslator", *args): if not config.trace_numpy: - unimplemented_v2( + unimplemented( gb_type="call `torch.from_numpy` with `torch._dynamo.config.trace_numpy=False`", context=f"trace_numpy={config.trace_numpy}", explanation=( @@ -761,7 +761,7 @@ def handle_from_numpy(self, tx: "InstructionTranslator", *args): ], ) if not np: - unimplemented_v2( + unimplemented( gb_type="`torch.from_numpy` with NumPy unavailable", context="", explanation="Attempted to call `torch.numpy` but NumPy could not be imported.", @@ -982,7 +982,7 @@ def handle_nested_tensor( from .lists import BaseListVariable if layout and layout.as_python_constant() == torch.strided: - unimplemented_v2( + unimplemented( gb_type="Attempted to use strided NestedTensor", context=f"layout={layout}", explanation="Dynamo does not support this.", @@ -992,7 +992,7 @@ def handle_nested_tensor( ], ) if not isinstance(tensor_list, BaseListVariable): - unimplemented_v2( + unimplemented( gb_type="Attempted to use `nested_tensor` with non-list input", context=f"tensor_list={tensor_list}", explanation="Dynamo does not support this.", @@ -1009,7 +1009,7 @@ def handle_one_hot(self, tx: "InstructionTranslator", *args, **kwargs): and args[1].is_python_constant() and args[1].as_python_constant() == -1 ): - unimplemented_v2( + unimplemented( gb_type="Attempted to use `torch.nn.functional.one_hot` with data-dependent output shape", context=f"args={args}, kwargs={kwargs}", explanation="Dynamo does not support this.", @@ -1187,7 +1187,7 @@ def handle_pop_torch_function( ): assert not args and not kwargs if not tx.symbolic_torch_function_state.mode_stack: - unimplemented_v2( + unimplemented( gb_type="Attempted to pop from empty torch function mode stack", context="", explanation="Called `torch._C._pop_torch_function_stack` when torch function mode stack is empty.", @@ -1236,7 +1236,7 @@ def handle_get_stack_at(self, tx: "InstructionTranslator", *args, **kwargs): @register(torch.get_device_module.__wrapped__) def handle_get_device_module(self, tx, *args, **kwargs): if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): - unimplemented_v2( + unimplemented( gb_type="improper torch.get_device_module arguments", context=f"args={args}, kwargs={kwargs}", explanation="torch.get_device_module accepts 1 optional argument `device`", @@ -1253,7 +1253,7 @@ def handle_get_device_module(self, tx, *args, **kwargs): device = None module = torch.get_device_module(device) except Exception as e: - unimplemented_v2( + unimplemented( gb_type="bad device argument to torch.get_device_module", context=f"args={args}, kwargs={kwargs}", explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", @@ -1278,7 +1278,7 @@ def handle_get_device_module(self, tx, *args, **kwargs): @register(torch.accelerator.current_stream, 
torch.cuda.current_stream) def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): - unimplemented_v2( + unimplemented( gb_type="unsupported arguments to torch.accelerator.current_stream", context=f"args={args}, kwargs={kwargs}", explanation="torch.accelerator.current_stream accepts one optional argument `device`", @@ -1296,7 +1296,7 @@ def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): return tx.symbolic_stream_state.cur_stream(device) except Exception as e: - unimplemented_v2( + unimplemented( gb_type="bad device argument to torch.accelerator.current_stream", context=f"args={args}, kwargs={kwargs}", explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", @@ -1360,7 +1360,7 @@ def handle_check(self, tx: "InstructionTranslator", *args, **kwargs): not isinstance(message_vt, NestedUserFunctionVariable) or message_vt.has_closure() ): - unimplemented_v2( + unimplemented( gb_type="Can't extract message from torch._check()", context=str(message_vt), explanation=( @@ -1446,7 +1446,7 @@ def call_function( arg_type = flat_arg_vt.python_type() if not is_graphable_type(arg_type): type_name = flat_arg_vt.python_type().__qualname__ - unimplemented_v2( + unimplemented( gb_type="Invalid input type for nonstrict_trace-ed function", context=f"Encountered input of type <{type_name}>.", explanation=( @@ -1480,7 +1480,7 @@ def call_function( import torch.utils._pytree as pytree if pytree.is_constant_class(typ): - unimplemented_v2( + unimplemented( gb_type="Input marked with `pytree.register_constant` constructed in the `torch.compile` region", context=f"Input={input_spec_vt}, offending type <{type_name}>.", explanation=( @@ -1495,7 +1495,7 @@ def call_function( from_exc=e, ) else: - unimplemented_v2( + unimplemented( gb_type="Invalid use of pytree_flatten with nonstrict_trace-ed function", context=f"Input={input_spec_vt}, offending type <{type_name}>.", explanation=( @@ -1560,7 +1560,7 @@ def patched_fn(*args, **kwargs): # From `flat_apply` assert on output type. torch._dynamo.exc.TorchRuntimeError, ): - unimplemented_v2( + unimplemented( gb_type="Unsupported output type for nonstrict_trace-ed function", context=f"Function: {fn.__name__}", explanation=( @@ -1612,7 +1612,7 @@ def patched_fn(*args, **kwargs): and torch.Tag.inplace_view in getattr(fn, fn.overloads()[0]).tags ): - unimplemented_v2( + unimplemented( gb_type="Inplace op on input tensor", context="", explanation=f"Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", @@ -1647,7 +1647,7 @@ def patched_fn(*args, **kwargs): For now, dynamo will explicitly graph break when it encounters user code with this behavior. """ log.warning(msg) - unimplemented_v2( + unimplemented( gb_type="Attempted to call torch in-graph function on only torch.SymInt arguments", context=f"fn={self.value}, args={args}, kwargs={kwargs}", explanation=( @@ -1715,7 +1715,7 @@ def patched_fn(*args, **kwargs): and "requires_grad" in kwargs and kwargs["requires_grad"].as_python_constant() ): - unimplemented_v2( + unimplemented( gb_type="Attempted to use tensor creation function with requires_grad=True", context=f"fn={self.value}, args={args}, kwargs={kwargs}", explanation="Dynamo does not support this.", @@ -1755,7 +1755,7 @@ def patched_fn(*args, **kwargs): if saved_out_shape != fake_out.shape: # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. 
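                    # A minimal sketch of user code that can reach this fallback
                    # (torch.topk is used purely as an illustration of an op taking a
                    # tuple of out= tensors; any out= variant whose outputs need
                    # resizing behaves the same way):
                    #
                    #   values = torch.empty(0)
                    #   indices = torch.empty(0, dtype=torch.long)
                    #
                    #   @torch.compile
                    #   def f(x):
                    #       # eager would resize `values`/`indices` in place; Dynamo falls back
                    #       return torch.topk(x, 3, out=(values, indices))
                    #
                    #   f(torch.randn(8))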
- unimplemented_v2( + unimplemented( gb_type="Shape mismatch with out= list of tensor variants", context=f"fn={self.value}, args={args}, kwargs={kwargs}", explanation=( @@ -1769,7 +1769,7 @@ def patched_fn(*args, **kwargs): if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument - unimplemented_v2( + unimplemented( gb_type="Attempted to call op with non-contiguous `out=` list of tensors", context=f"self.value={self.value}, args={args}, kwargs={kwargs}", explanation="Dynamo does not support this.", @@ -1784,7 +1784,7 @@ def patched_fn(*args, **kwargs): if saved_out_shapes != fake_out.shape: # It's hard to get out variants with resizing on graph inputs work # properly across dynamo/aot/inductor, just fall back. - unimplemented_v2( + unimplemented( gb_type="Shape mismatch with out= tensor variant", context=f"fn={self.value}, args={args}, kwargs={kwargs}", explanation=( @@ -1798,7 +1798,7 @@ def patched_fn(*args, **kwargs): if not torch._prims_common.is_contiguous(fake_out): # It's difficult to handle strides correctly in functionalization # when calling an out= op with a non-contiguous out argument - unimplemented_v2( + unimplemented( gb_type="Attempted to call op with non-contiguous `out=` tensor", context=f"self.value={self.value}, args={args}, kwargs={kwargs}", explanation="Dynamo does not support this.", @@ -1829,7 +1829,7 @@ def handle_ntuple(value): torch.nn.modules.utils._ntuple(count)(value.as_python_constant()), ) else: - unimplemented_v2( + unimplemented( gb_type="Attempted to use `torch.nn.modules.utils._ntuple` with unsupported argument type", context=f"value={value}", explanation="Dynamo does not support this.", @@ -1847,7 +1847,7 @@ def handle_ntuple(value): def call_nn_parameter(cls, tx, data=None, requires_grad=True): """A call to torch.nn.Parameter() gets lifted to before the graph""" if tx.export: - unimplemented_v2( + unimplemented( gb_type="Attempted to use `torch.nn.Parameter()` with export", context="", explanation="Dynamo does not support this.", @@ -1861,7 +1861,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): try: requires_grad = requires_grad.as_python_constant() except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="non-constant `requires_grad` argument to `torch.nn.Parameter`", context=f"requires_grad={requires_grad}", explanation="Dynamo does not support this.", @@ -1872,7 +1872,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ) if not isinstance(data, variables.TensorVariable): - unimplemented_v2( + unimplemented( gb_type="`torch.nn.Parameter()` with unsupported data type", context=f"data={data}", explanation="Called `torch.nn.Parameter()` with non-Tensor argument.", @@ -1889,7 +1889,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if config.graph_break_on_nn_param_ctor: # Need user to manually move since we cannot - unimplemented_v2( + unimplemented( gb_type="Attempted to use `torch.nn.Parameter()` constructor with Dynamo", context="", explanation="Dynamo does not support this", @@ -1906,7 +1906,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): TensorWithTFOverrideVariable, # pyrefly: ignore [missing-attribute] ) or is_traceable_wrapper_subclass_type(data.class_type): - unimplemented_v2( + unimplemented( gb_type="Attempted to use torch.nn.Parameter constructor with tensor subclass", context=str(data), explanation="Dynamo does not support this.", @@ 
-1916,7 +1916,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): ) if not can_convert_to_tracable_parameter(): - unimplemented_v2( + unimplemented( gb_type="`torch.nn.Parameter`: cannot convert to traceable tracable", context="", explanation="convert_tracable_parameter is set to False.", @@ -1934,7 +1934,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): # pyrefly: ignore [missing-attribute] device = data.var_getattr(tx, "device").as_python_constant() except NotImplementedError as e: - unimplemented_v2( + unimplemented( gb_type="`torch.nn.Parameter` with non-constant Tensor attributes", context=f"data={data}", explanation="Dynamo does not support this.", @@ -2000,7 +2000,7 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad data_node = data.as_proxy().node if data_node.op not in ("placeholder", "get_attr"): - unimplemented_v2( + unimplemented( gb_type="Unexpected type of data placeholder op for parameter construction", context=f"data_node.op={data_node.op}", explanation="Data node op should be placeholder or get_attr.", diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4d0f0b4fae8ab..c7254afdfebfc 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -44,7 +44,7 @@ from torch.utils._device import DeviceContext from .. import graph_break_hints -from ..exc import unimplemented_v2 +from ..exc import unimplemented from ..guards import GuardBuilder, install_guard from ..polyfills import NoEnterTorchFunctionMode from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource @@ -558,7 +558,7 @@ def dispatch_torch_function( if not (isinstance(res, ConstantVariable) and res.value is NotImplemented): return res - unimplemented_v2( + unimplemented( gb_type="All __torch_function__ overrides returned NotImplemented due to TypeError from user code", context=f"{fn=}, {args=}, {kwargs=}", explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented", @@ -626,7 +626,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker # I think only `_base` is breaking because we aren't modelling view # relationship perfectly in some scenarios. 
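        # A minimal sketch of the pattern this check guards against (the subclass
        # below is hypothetical and only illustrates the attribute access):
        #
        #   class MyTensor(torch.Tensor):
        #       @classmethod
        #       def __torch_function__(cls, func, types, args=(), kwargs=None):
        #           return super().__torch_function__(func, types, args, kwargs or {})
        #
        #   @torch.compile
        #   def f(x):
        #       return x._base is None   # hits the graph break below
        #
        #   f(torch.randn(4).as_subclass(MyTensor))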
if name in banned_attrs: - unimplemented_v2( + unimplemented( gb_type="Unsupported tensor subclass attribute access", context=f"{name}", explanation="`torch.compile` currently can't trace this", @@ -686,7 +686,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ) elif attr_is_overridden: - unimplemented_v2( + unimplemented( gb_type="Unsupported tensor subclass overridden attribute access", context=f"{name}", explanation="`torch.compile` only support tracing certain types of overridden tensor subclass attributes", @@ -734,7 +734,7 @@ def call_method( import torch if _is_attr_overridden(tx, self, name): - unimplemented_v2( + unimplemented( gb_type="Tensor subclass overridden method call", context=f"{name}", explanation="`torch.compile` currently can't trace this", diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index a65ee6b1e0bf6..7709850d22d8b 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -56,7 +56,7 @@ ObservedTypeError, ObservedUserStopIteration, raise_observed_exception, - unimplemented_v2, + unimplemented, ) from ..graph_bytecode_inputs import get_external_object_by_index from ..guards import GuardBuilder, install_guard @@ -459,7 +459,7 @@ def call_method( args[1:], ) elif name == "__setattr__" and self.ban_mutation: - unimplemented_v2( + unimplemented( gb_type="Class attribute mutation when the __dict__ was already materialized", context=str(self.value), explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized", @@ -515,7 +515,7 @@ def call_function( ) elif is_typeddict(self.value): if self.value.__optional_keys__: - unimplemented_v2( + unimplemented( gb_type="TypedDict with optional keys", context=str(self.value), explanation="Dyanmo does not support tracing TypedDict with optional keys", @@ -534,7 +534,7 @@ def deque_signature(iterable=None, maxlen=None): try: bound_args = inspect.signature(deque_signature).bind(*args, **kwargs) except TypeError as e: - unimplemented_v2( + unimplemented( gb_type="collections.deque() with bad arguments", context=f"args={args}, kwargs={kwargs}", explanation="Detected call to collections.deque() with bad arguments.", @@ -549,7 +549,7 @@ def deque_signature(iterable=None, maxlen=None): if not bound_args.arguments["iterable"].has_force_unpack_var_sequence( tx ): - unimplemented_v2( + unimplemented( gb_type="collections.deque() with bad iterable argument", context=f"args={args}, kwargs={kwargs}", explanation="Call to collections.deque() has an iterable argument that Dynamo cannot " @@ -578,7 +578,7 @@ def deque_signature(iterable=None, maxlen=None): return variables.WeakRefVariable(args[0], callback) elif self.value is functools.partial: if not args: - unimplemented_v2( + unimplemented( gb_type="missing args to functools.partial", context="", explanation="functools.partial requires at least one argument", @@ -636,7 +636,7 @@ def deque_signature(iterable=None, maxlen=None): ): # We are not changing the behavior of Dynamo as these function were # already ignored on trace_rules.py before #136033 landed - unimplemented_v2( + unimplemented( gb_type="unsupported contextlib.* API", context=f"{self.value}", explanation=f"{self.value} not supported. 
This may be due to its use of " @@ -651,7 +651,7 @@ def deque_signature(iterable=None, maxlen=None): args[0], (BaseUserFunctionVariable, TorchCtxManagerClassVariable) ): if not torch._dynamo.config.enable_trace_contextlib: - unimplemented_v2( + unimplemented( gb_type="attempted to trace contextlib.contextmanager", context=f"args={args}", explanation="Tracing contextlib.contextmanager is disabled.", @@ -1115,7 +1115,7 @@ def call_method( if torch._dynamo.config.enable_faithful_generator_behavior and isinstance( self.value, types.GeneratorType ): - unimplemented_v2( + unimplemented( gb_type="call_method on generator", context=f"object={self.value}, method={name}, args={args}, kwargs={kwargs}", explanation="Detected a method call to a user-defined generator object. " @@ -1154,7 +1154,7 @@ def method_setattr_standard( try: name = name.as_python_constant() except NotImplementedError: - unimplemented_v2( + unimplemented( gb_type="non-const setattr name on user-defined object", context=f"object={self}, name={name}, value={value}", explanation="Detected a call to `setattr` of a user-defined object with a non-constant name.", @@ -1280,7 +1280,7 @@ def call_function( ).call_function(tx, [var], kwargs) if self.source is None: - unimplemented_v2( + unimplemented( gb_type="attempted to call sourceless user-defined object as a method", context=f"object={self.value}, function={func}, args={args}, kwargs={kwargs}", explanation="Dynamo does not support this.", @@ -1410,7 +1410,7 @@ def get_source_by_walking_mro(self, name): ) return out_source - unimplemented_v2( + unimplemented( gb_type="could not find name in object's mro", context=f"name={name}, object type={type(self.value)}, mro={type(self.value).__mro__}", explanation=f"Could not find name `{name}` in mro {type(self.value).__mro__}", @@ -1506,7 +1506,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return out elif getattr_fn is not None: - unimplemented_v2( + unimplemented( gb_type="User-defined object with non-function __getattr__", context=f"object={self.value}, name={name}, getattr_fn={getattr_fn}", explanation=f"Found a non-function __getattr__ {getattr_fn} from a user-defined object {self.value} " @@ -1632,7 +1632,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): if isinstance(subobj, types.MethodType): if dynamic_subobj.__self__ is not self.value: if not isinstance(dynamic_subobj.__func__, types.FunctionType): - unimplemented_v2( + unimplemented( gb_type="User-defined object method with non-function __func__", context=f"object={self.value}, name={name}, method={dynamic_subobj}, " f"method.__self__={dynamic_subobj.__self__}, method.__func__={dynamic_subobj.__func__}", diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index ded569f70ef64..4ede1d7234066 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1413,10 +1413,10 @@ def tensor_visitor_fn( # TODO: Handle this better in Dynamo? # There are checks there now, but this can still be triggered by a dense # tensor graph input that is a view of a strided NT. 
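                # A rough sketch of how such an input can arise (the view op here is
                # chosen purely for illustration; the exact set of offending ops may
                # differ):
                #
                #   nt = torch.nested.nested_tensor(
                #       [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.strided
                #   )
                #   dense = nt.unbind()[0]   # dense tensor whose ._base is the strided NT
                #
                #   @torch.compile
                #   def f(x):
                #       return x.sin()
                #
                #   f(dense)                 # meta-converting the NT base hits the break below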
- from torch._dynamo.exc import unimplemented_v2 + from torch._dynamo.exc import unimplemented # NOTE this graph break will NOT be present in Dynamo's graph break registry - unimplemented_v2( + unimplemented( gb_type="attempted to apply meta conversion to strided nested tensor", context=str(t), explanation="This is not supported.", @@ -1454,9 +1454,9 @@ def tensor_visitor_fn( r = self._backward_error(r) elif t.is_functorch_wrapped: if t.is_view: - from torch._dynamo.exc import unimplemented_v2 + from torch._dynamo.exc import unimplemented - unimplemented_v2( + unimplemented( gb_type="attempted to apply meta conversion to view functorch tensor", context=str(t), explanation="This is not supported.", From c2924bbafa253edefab387aea7ba577de2f40c09 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 13:34:48 -0800 Subject: [PATCH 250/651] [dynamo] replace raise Unsupported(...) with unimplemented(...) (#167255) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167255 Approved by: https://github.com/Lucaskabela, https://github.com/mlazos, https://github.com/zou3519 ghstack dependencies: #167150 --- torch/_dynamo/eval_frame.py | 13 +++- torch/_dynamo/graph_break_registry.json | 82 +++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 9 ++- torch/_dynamo/variables/functions.py | 7 +- torch/_dynamo/variables/higher_order_ops.py | 19 +++-- torch/_dynamo/variables/nn_module.py | 18 ++--- torch/_dynamo/variables/optimizer.py | 9 ++- torch/_dynamo/variables/sdpa.py | 16 ++-- torch/_dynamo/variables/torch.py | 9 ++- 9 files changed, 153 insertions(+), 29 deletions(-) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 9c9076f5a99c0..9b9572620db14 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1785,16 +1785,21 @@ def check_optional_input_and_error(f_sig: inspect.Signature) -> None: # Check if function has optional input. for name, param in f_sig.parameters.items(): if param.default is not inspect.Parameter.empty: - from torch._dynamo.exc import Unsupported + import torch._dynamo.graph_break_hints as graph_break_hints + from torch._dynamo.exc import unimplemented log.error( "Parameter %s is optional with a default value of %s", name, param.default, ) - raise Unsupported( - "Tracing through optional input is not supported yet", - case_name="optional_input", + unimplemented( + gb_type="rewrite_signature: cannot trace optional function input", + context="", + explanation=f"Parameter {name} is optional with a default value of {param.default}. This is not supported yet.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], ) def produce_matching( diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 638487e417e63..5603fc166782d 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3575,5 +3575,87 @@ "Explanation": "Expected backward function to be a function or method.", "Hints": [] } + ], + "GB0353": [ + { + "Gb_type": "rewrite_signature: cannot trace optional function input", + "Context": "", + "Explanation": "Parameter {name} is optional with a default value of {param.default}. This is not supported yet.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." 
+ ] + } + ], + "GB0354": [ + { + "Gb_type": "failed to find name in frame builtins", + "Context": "", + "Explanation": "Failed to find name `{argval}` in frame's builtins.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0355": [ + { + "Gb_type": "non-single Tensor return unsupported", + "Context": "api: {api}, ret: {ret}", + "Explanation": "{api} over function that returns something other than one Tensor.", + "Hints": [] + } + ], + "GB0356": [ + { + "Gb_type": "failed to handle argument for FlexAttentionBackward HOP", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "Missing Dynamo support for FlexAttentionBackward HOP argument.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], + "GB0357": [ + { + "Gb_type": "UnspecializedNNModuleVariable wrapped around ScriptModules unsupported", + "Context": "str(value)", + "Explanation": "ScriptModules aren't supported in UnspecializedNNModuleVariable because their .forward function isn't a static member of their type.", + "Hints": [ + "This graph break may be difficult to debug. Please report an issue to PyTorch for assistance." + ] + } + ], + "GB0358": [ + { + "Gb_type": "optimizer: pending mutation on parameter", + "Context": "variable: {variable}, parameter: {p}", + "Explanation": "Pending mutations on a parameter (e.g. due to using closure) require a graph break.", + "Hints": [] + } + ], + "GB0359": [ + { + "Gb_type": "unsupported torch._C._SDPAParams attribute", + "Context": "name: {name}", + "Explanation": "Unable to fetch attribute {name} from torch._C._SDPAParams.", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." 
+ ] + } + ], + "GB0360": [ + { + "Gb_type": "torch.fx.experimental.symbolic_shapes.guard_scalar branch not supported", + "Context": "expr: {expr}", + "Explanation": "Expected `expr` to be a symbolic variable or constant.", + "Hints": [] + } + ], + "GB0361": [ + { + "Gb_type": "triton kernel unsupported feature", + "Context": "", + "Explanation": "Encountered triton kernel unsupported feature: {msg}", + "Hints": [] + } ] } diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 179f0ed067552..f5adc7fcfa379 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1971,7 +1971,14 @@ def IMPORT_FROM(self, inst: Instruction) -> None: @cache_method def load_builtin_from_argval(self, argval: Any) -> VariableTracker: if argval not in self.f_builtins: - raise Unsupported(f"name '{argval}' is not defined") + unimplemented( + gb_type="failed to find name in frame builtins", + context="", + explanation=f"Failed to find name `{argval}` in frame's builtins.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) val = self.f_builtins[argval] if callable(val): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 8411441724d3c..ed61c5bfa079b 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -2322,7 +2322,12 @@ def call_function( class DynamoTritonHOPifier(TritonHOPifier): def raise_unsupported(self, msg: str) -> Never: - raise Unsupported(msg) + unimplemented( + gb_type="triton kernel unsupported feature", + context="", + explanation=f"Encountered triton kernel unsupported feature: {msg}", + hints=[], + ) def is_callable(self, maybe_callable: VariableTracker) -> bool: return isinstance( diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index b713b02c4e41a..077132757c95d 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -2122,8 +2122,11 @@ def non_single_tensor_return_unsupported(api, ret): from . import TensorVariable if not isinstance(ret, TensorVariable): - raise Unsupported( - f"{api} over function that returns something other than one Tensor" + unimplemented( + gb_type="non-single Tensor return unsupported", + context=f"api: {api}, ret: {ret}", + explanation=f"{api} over function that returns something other than one Tensor.", + hints=[], ) @@ -3042,9 +3045,15 @@ def _call_function( p_args = tuple(self.to_proxy(tx, arg) for arg in args) p_kwargs = {key: self.to_proxy(tx, arg) for key, arg in kwargs.items()} except (NotImplementedError, Unsupported) as err: - raise Unsupported( - "Missing Dynamo support for FlexAttentionBackward HOP argument. Please file an issue." - ) from err + unimplemented( + gb_type="failed to handle argument for FlexAttentionBackward HOP", + context=f"args: {args}, kwargs: {kwargs}", + explanation="Missing Dynamo support for FlexAttentionBackward HOP argument.", + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + from_exc=err, + ) return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index b58580cd61240..e754699d862ad 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -34,12 +34,7 @@ import torch.nn from .. 
import graph_break_hints, trace_rules, variables -from ..exc import ( - raise_observed_exception, - unimplemented, - UnspecializeRestartAnalysis, - Unsupported, -) +from ..exc import raise_observed_exception, unimplemented, UnspecializeRestartAnalysis from ..guards import GuardBuilder, install_guard from ..mutation_guard import GenerationTracker from ..source import ( @@ -960,9 +955,14 @@ class UnspecializedNNModuleVariable(UserDefinedObjectVariable): def __init__(self, value, **kwargs) -> None: if type(value) is torch.jit._script.RecursiveScriptModule: - raise Unsupported( - "ScriptModules aren't supported in UnspecializedNNModuleVariable" - " because their .forward function isn't a static member of their type" + unimplemented( + gb_type="UnspecializedNNModuleVariable wrapped around ScriptModules unsupported", + context=str(value), + explanation="ScriptModules aren't supported in UnspecializedNNModuleVariable" + " because their .forward function isn't a static member of their type.", + hints=[ + *graph_break_hints.DIFFICULT, + ], ) if "value_type" in kwargs: lazy_value_to_become = getattr(kwargs["value_type"], "cls_to_become", None) diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index c09cc2163a5f4..fd7ccf9cc6e68 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -171,9 +171,14 @@ def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None: side_effects = tx.output.side_effects variable = side_effects.id_to_variable.get(id(p), None) if variable and side_effects.has_pending_mutation(variable): - from ..exc import Unsupported + from ..exc import unimplemented - raise Unsupported("Pending mutation on parameter") + unimplemented( + gb_type="optimizer: pending mutation on parameter", + context=f"variable: {variable}, parameter: {p}", + explanation="Pending mutations on a parameter (e.g. due to using closure) require a graph break.", + hints=[], + ) def _set_capturable(self, tx: "InstructionTranslator") -> None: from . 
import LazyVariableTracker diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 629bf094dc951..1a7006f5d56ab 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -7,7 +7,7 @@ from torch.fx.proxy import Proxy from ..bytecode_transformation import create_call_function -from ..exc import Unsupported +from ..exc import unimplemented from ..source import AttrSource from .base import VariableTracker @@ -71,10 +71,16 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker try: getattr_static(torch._C._SDPAParams, name) except AttributeError: - # Using raise from is too verbose here - raise Unsupported( - f"Unsupported torch._C._SDPAParams attribute {name}" - ) from None + import torch._dynamo.graph_break_hints as graph_break_hints + + unimplemented( + gb_type="unsupported torch._C._SDPAParams attribute", + context=f"name: {name}", + explanation=f"Unable to fetch attribute {name} from torch._C._SDPAParams.", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name) if self.source is not None: diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 895acfd56e80b..eb9580e0a05c8 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1072,9 +1072,14 @@ def guard_scalar(self, tx: "InstructionTranslator", expr): elif isinstance(expr, ConstantVariable): val = expr.value else: - raise torch._dynamo.exc.Unsupported("branch not supported") + unimplemented( + gb_type="torch.fx.experimental.symbolic_shapes.guard_scalar branch not supported", + context=f"expr: {expr}", + explanation="Expected `expr` to be a symbolic variable or constant.", + hints=[], + ) return variables.ConstantVariable.create( - # pyrefly: ignore [bad-argument-type] + # pyrefly: ignore [bad-argument-type, unbound-name] torch.fx.experimental.symbolic_shapes.guard_scalar(val) ) From 29d6bb79e1b7d08f1a6988f4f8f94c8059978d8b Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sat, 8 Nov 2025 03:09:11 +0000 Subject: [PATCH 251/651] Use context managers (SIM115) (#166928) This PR changes code to use context managers if possible. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166928 Approved by: https://github.com/Lucaskabela --- pyproject.toml | 4 +- torch/_dynamo/variables/higher_order_ops.py | 25 ++-- torch/_inductor/codecache.py | 12 +- .../examples/fsdp_checkpoint_example.py | 17 ++- torch/profiler/profiler.py | 19 ++- torch/sparse/_triton_ops_meta.py | 10 +- .../distributed/nn/api/remote_module_test.py | 10 +- .../distributed/rpc/dist_autograd_test.py | 13 +- .../_internal/distributed/rpc/jit/rpc_test.py | 15 +- .../_internal/distributed/rpc/rpc_test.py | 60 ++++---- torch/testing/_internal/jit_utils.py | 135 +++++++++--------- 11 files changed, 157 insertions(+), 163 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4cf3562886fd9..21a1f2ec1e3e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -172,9 +172,9 @@ ignore = [ "SIM102", "SIM103", "SIM112", # flake8-simplify code styles "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason "SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression - "SIM110", + "SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all. 
"SIM114", # Combine `if` branches using logical `or` operator - "SIM115", + "SIM115", # Checks for cases where files are opened without using a context manager. "SIM116", # Disable Use a dictionary instead of consecutive `if` statements "SIM117", "SIM118", diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 077132757c95d..89c2a7451a771 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -3598,19 +3598,18 @@ def unwrap_proxy(x): # The fwd outputs (tensor's example_value) need to be inferred from fake tensor prop to get the correct attributes # (e.g, tensor.requires_grad), which would be used by downstream Dynamo tracing. # Since there can be other ops like Triton kernels, which depends on python dispatcher, we have to enable it. - with enable_python_dispatcher(): - with tx.output.fake_mode: - fake_args = ( - tx.output.nn_modules[fwd_node.node.name], - tx.output.nn_modules[bwd_node.node.name], - *( - [ - _get_fake_value(arg) - for arg in filtered_args + list(fwd_freevars.keys()) - ] - ), - ) - example_value = autograd_function_apply(*fake_args, **kwargs) + with enable_python_dispatcher(), tx.output.fake_mode: + fake_args = ( + tx.output.nn_modules[fwd_node.node.name], + tx.output.nn_modules[bwd_node.node.name], + *( + [ + _get_fake_value(arg) + for arg in filtered_args + list(fwd_freevars.keys()) + ] + ), + ) + example_value = autograd_function_apply(*fake_args, **kwargs) return wrap_fx_proxy( tx=tx, diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 9e57c498abbf1..b0bea9d2d6bb9 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3798,11 +3798,13 @@ def cutlass_key() -> bytes: Note: OSS and fbcode will have different keys. 
""" if config.is_fbcode(): - with importlib.resources.path( - "cutlass_library", "src_hash.txt" - ) as resource_path: - with open(resource_path) as resource_file: - return resource_file.read().encode() + with ( + importlib.resources.path( + "cutlass_library", "src_hash.txt" + ) as resource_path, + open(resource_path) as resource_file, + ): + return resource_file.read().encode() combined_hash = hashlib.sha256() build_code_hash([config.cuda.cutlass_dir], "", combined_hash) diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py index 7d57b55c22fd6..a20ac912f8767 100644 --- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py +++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py @@ -37,15 +37,14 @@ def init_model(): def print_params(stage, model_1, model_2, optim_1, optim_2): - with FSDP.summon_full_params(model_1): - with FSDP.summon_full_params(model_2): - print( - f"{stage} --- rank: {dist.get_rank()}\n" - f"model.weight: {model_1.weight}\n" - f"model_2.weight:{model_2.weight}\n" - f"model.bias: {model_1.bias}\n" - f"model_2.bias: {model_2.bias}\n" - ) + with FSDP.summon_full_params(model_1), FSDP.summon_full_params(model_2): + print( + f"{stage} --- rank: {dist.get_rank()}\n" + f"model.weight: {model_1.weight}\n" + f"model_2.weight:{model_2.weight}\n" + f"model.bias: {model_1.bias}\n" + f"model_2.bias: {model_2.bias}\n" + ) print( f"{stage} --- rank: {dist.get_rank()}\n" diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 893b4078cb9ce..645667bb81bb6 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -271,13 +271,13 @@ def export_chrome_trace(self, path: str): "Profiler must be initialized before exporting chrome trace" ) if path.endswith(".gz"): - fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) - fp.close() - retvalue = self.profiler.export_chrome_trace(fp.name) - with open(fp.name, "rb") as fin: - with gzip.open(path, "wb") as fout: - fout.writelines(fin) - os.remove(fp.name) + with tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) as fp: + fp.close() + retvalue = self.profiler.export_chrome_trace(fp.name) + with open(fp.name, "rb") as fin: + with gzip.open(path, "wb") as fout: + fout.writelines(fin) + os.remove(fp.name) return retvalue else: return self.profiler.export_chrome_trace(path) @@ -454,9 +454,8 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non self.mem_tl.export_memory_timeline_raw(fp.name, device) else: self.mem_tl.export_memory_timeline(fp.name, device) - with open(fp.name) as fin: - with gzip.open(path, "wt") as fout: - fout.writelines(fin) + with open(fp.name) as fin, gzip.open(path, "wt") as fout: + fout.writelines(fin) os.remove(fp.name) else: self.mem_tl.export_memory_timeline(path, device) diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 38749d00f0eb4..ae8e5f4066e27 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -217,9 +217,8 @@ def update(op, device_name, version, key, value): def dump(): """Store the current runtime db state to the module file.""" current_file = inspect.getfile(dump) - f = open(current_file) - current_content = f.read() - f.close() + with open(current_file) as f: + current_content = f.read() begin_data_str = "# BEGIN GENERATED DATA\n" begin_data_index = current_content.find(begin_data_str) end_data_index = current_content.find(" # END 
GENERATED DATA\n") @@ -250,9 +249,8 @@ def sort_key(key): data_part.append(" },") new_content = part1 + "\n".join(data_part) + "\n" + part2 if current_content != new_content: - f = open(current_file, "w") - f.write(new_content) - f.close() + with open(current_file, "w") as f: + f.write(new_content) def minimize( diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index 79c55f5b8847b..af136fb8722d1 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -477,11 +477,13 @@ def test_remote_module_py_pickle_not_supported_script(self): for remote_module in self._create_remote_module_iter( dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ): - with TemporaryFileName() as fname: - with self.assertRaisesRegex( + with ( + TemporaryFileName() as fname, + self.assertRaisesRegex( torch.jit.Error, "can only be pickled when using RPC" - ): - torch.save(remote_module, fname) + ), + ): + torch.save(remote_module, fname) class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest): diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index 1b371d3ee6ea0..1abadd33309da 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -1283,13 +1283,14 @@ def test_autograd_context(self): @dist_init def test_nested_context(self): - with dist_autograd.context(): - # Nested contexts not supported. - with self.assertRaisesRegex( + with ( + dist_autograd.context(), + self.assertRaisesRegex( RuntimeError, "Already have an autograd context id for this thread" - ): - with dist_autograd.context(): - pass + ), + dist_autograd.context(), + ): + pass @dist_init def test_graph_for_builtin_call(self): diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py index 76c089f45800d..82a5d66e87f38 100644 --- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py @@ -1021,11 +1021,13 @@ def test_rref_jit_pickle_not_supported(self): n = self.rank + 1 dst_rank = n % self.world_size rref_var = rpc_return_rref(worker_name(dst_rank)) - with TemporaryFileName() as fname: - with self.assertRaisesRegex( + with ( + TemporaryFileName() as fname, + self.assertRaisesRegex( RuntimeError, "RRef jit pickling is only allowed inside RPC calls" - ): - save_rref(rref_var, fname) + ), + ): + save_rref(rref_var, fname) @dist_init def test_remote_script_throw(self): @@ -1294,9 +1296,8 @@ def test_record_function_jit_end_callbacks_with_fork(self): def test_call_fork_in_jit_with_profiling(self): # Ensures that we can call torch.ops.profiler._call_end_callbacks_on_jit_fut on a jit # future from within a script function with torch.jit.fork - with _profile() as prof: - with torch.autograd.profiler.record_function("foo") as rf: - call_fork_with_profiling(rf.record) + with _profile() as prof, torch.autograd.profiler.record_function("foo") as rf: + call_fork_with_profiling(rf.record) events = prof.function_events function_event = get_function_event(events, "foo") diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 14d16281c14e2..c50aadc058cbd 100644 --- 
a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -2093,17 +2093,15 @@ def _run_test_profiler_with_autograd_context(self): dst = (self.rank + 1) % self.world_size if self.rank == 1: # Cases where we can double wrap messages with profiling information and autograd info. - with dist_autograd.context(): - with _profile() as prof: - self.run_profiling_workload(dst) + with dist_autograd.context(), _profile() as prof: + self.run_profiling_workload(dst) self.validate_profiling_workload(dst, prof) # Ensure that flipped order of ctx managers results in events being # recorded as expected. - with _profile() as prof: - with dist_autograd.context(): - self.run_profiling_workload(dst) + with _profile() as prof, dist_autograd.context(): + self.run_profiling_workload(dst) self.validate_profiling_workload(dst, prof) @@ -3518,28 +3516,25 @@ def test_wait_all_multiple_call(self): @dist_init def test_wait_all_timeout(self): expected_error = self.get_timeout_error_regex() - with self.assertRaisesRegex(RuntimeError, expected_error): - with _wait_all(): - self.assertTrue(_thread_local_var.future_list == []) - dst = worker_name((self.rank + 1) % self.world_size) - timeout = 0.1 # 100 ms - rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) + with self.assertRaisesRegex(RuntimeError, expected_error), _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + timeout = 0.1 # 100 ms + rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) self.assertFalse(hasattr(_thread_local_var, "future_list")) @dist_init def test_wait_all_raise_in_user_func(self): - with self.assertRaises(ValueError): - with _wait_all(): - self.assertTrue(_thread_local_var.future_list == []) - dst = worker_name((self.rank + 1) % self.world_size) - rpc.rpc_async(dst, raise_func) + with self.assertRaises(ValueError), _wait_all(): + self.assertTrue(_thread_local_var.future_list == []) + dst = worker_name((self.rank + 1) % self.world_size) + rpc.rpc_async(dst, raise_func) self.assertFalse(hasattr(_thread_local_var, "future_list")) @dist_init def test_wait_all_raise_in_body(self): - with self.assertRaises(ValueError): - with _wait_all(): - raise_func() + with self.assertRaises(ValueError), _wait_all(): + raise_func() self.assertFalse(hasattr(_thread_local_var, "future_list")) @dist_init @@ -3739,11 +3734,13 @@ def test_user_rrefs_confirmed_remote(self): @dist_init def test_rref_py_pickle_not_supported(self): local_rref = RRef(35) - with TemporaryFileName() as fname: - with self.assertRaisesRegex( + with ( + TemporaryFileName() as fname, + self.assertRaisesRegex( RuntimeError, "Can not pickle rref in python pickler" - ): - torch.save(local_rref, fname) + ), + ): + torch.save(local_rref, fname) @dist_init def test_remote_throw(self): @@ -3959,17 +3956,14 @@ def test_pickle_future(self): errMsg = "Can not pickle torch.futures.Future" dst = worker_name((self.rank + 1) % self.world_size) - with TemporaryFileName(): - with self.assertRaisesRegex(RuntimeError, errMsg): - rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) + with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg): + rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) - with TemporaryFileName(): - with self.assertRaisesRegex(RuntimeError, errMsg): - rpc.rpc_async(dst, fail_on_fut, args=(fut,)) + with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg): + rpc.rpc_async(dst, fail_on_fut, args=(fut,)) - with TemporaryFileName(): - with 
self.assertRaisesRegex(RuntimeError, errMsg): - rpc.remote(dst, fail_on_fut, args=(fut,)) + with TemporaryFileName(), self.assertRaisesRegex(RuntimeError, errMsg): + rpc.remote(dst, fail_on_fut, args=(fut,)) @dist_init def test_future_done(self): diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index ce8e68ae1e2c5..7647a6595ec73 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -281,13 +281,13 @@ def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=No # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile # opens the file, and it cannot be opened multiple times in Windows. To support Windows, # close the file after creation and try to remove it manually - f = tempfile.NamedTemporaryFile(delete=False) - try: - f.close() - imported.save(f.name) - result = torch.jit.load(f.name, map_location=map_location) - finally: - os.unlink(f.name) + with tempfile.NamedTemporaryFile(delete=False) as f: + try: + f.close() + imported.save(f.name) + result = torch.jit.load(f.name, map_location=map_location) + finally: + os.unlink(f.name) result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) return result @@ -459,70 +459,69 @@ def checkScript(self, Checks that a given script generates the same output as the Python version using the given inputs. """ - with torch.jit.optimized_execution(optimize): - with enable_profiling_mode_for_profiling_tests(): - extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs) - if isinstance(script, str): - # Compile the string to a Script function - # with enable_profiling_mode(): - cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) - - # Execute the Python function so we can run it later and get its - # outputs - - frame = self.get_frame_vars(frames_up) - the_locals: dict[str, Any] = {} - execWrapper(script, glob=frame, loc=the_locals) - frame.update(the_locals) - - python_fn = frame[name] - scripted_fn = getattr(cu, name) - else: - - # Check the string frontend first - source = textwrap.dedent(inspect.getsource(script)) - self.checkScript( - source, - inputs, - script.__name__, - optimize=optimize, - inputs_requires_grad=inputs_requires_grad, - capture_output=capture_output, - profiling=profiling, - frames_up=2) - - # Continue checking the Python frontend - scripted_fn = torch.jit.script(script, _frames_up=1) - python_fn = script - - if inputs_requires_grad: - recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) - else: - recording_inputs = inputs - - if capture_output: - with self.capture_stdout() as script_stdout: - script_outputs = scripted_fn(*recording_inputs) - with self.capture_stdout(): - opt_script_outputs = scripted_fn(*recording_inputs) - with self.capture_stdout(): - python_outputs = python_fn(*inputs) - if not IS_WINDOWS: - self.assertExpected(script_stdout[0], subname='stdout') - self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol) - else: - # profiling run + with torch.jit.optimized_execution(optimize), enable_profiling_mode_for_profiling_tests(): + extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs) + if isinstance(script, str): + # Compile the string to a Script function + # with enable_profiling_mode(): + cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) + + # Execute the Python function so we can run it later and get its + # outputs + + frame = 
self.get_frame_vars(frames_up) + the_locals: dict[str, Any] = {} + execWrapper(script, glob=frame, loc=the_locals) + frame.update(the_locals) + + python_fn = frame[name] + scripted_fn = getattr(cu, name) + else: + + # Check the string frontend first + source = textwrap.dedent(inspect.getsource(script)) + self.checkScript( + source, + inputs, + script.__name__, + optimize=optimize, + inputs_requires_grad=inputs_requires_grad, + capture_output=capture_output, + profiling=profiling, + frames_up=2) + + # Continue checking the Python frontend + scripted_fn = torch.jit.script(script, _frames_up=1) + python_fn = script + + if inputs_requires_grad: + recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) + else: + recording_inputs = inputs + + if capture_output: + with self.capture_stdout() as script_stdout: script_outputs = scripted_fn(*recording_inputs) - if inputs_requires_grad or extra_profile_runs: - opt_script_outputs = scripted_fn(*recording_inputs) - # optimized run + with self.capture_stdout(): opt_script_outputs = scripted_fn(*recording_inputs) - if TEST_BAILOUTS: - self.checkBailouts(scripted_fn, inputs, opt_script_outputs) + with self.capture_stdout(): python_outputs = python_fn(*inputs) - self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol) - self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol) - return scripted_fn + if not IS_WINDOWS: + self.assertExpected(script_stdout[0], subname='stdout') + self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol) + else: + # profiling run + script_outputs = scripted_fn(*recording_inputs) + if inputs_requires_grad or extra_profile_runs: + opt_script_outputs = scripted_fn(*recording_inputs) + # optimized run + opt_script_outputs = scripted_fn(*recording_inputs) + if TEST_BAILOUTS: + self.checkBailouts(scripted_fn, inputs, opt_script_outputs) + python_outputs = python_fn(*inputs) + self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol) + self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol) + return scripted_fn def checkTrace(self, func, reference_tensors, input_tensors=None, drop=None, allow_unused=False, verbose=False, From 87646e5db4cd4ed2b6710f7b870e2a7e03b61e79 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Fri, 7 Nov 2025 16:18:25 -0800 Subject: [PATCH 252/651] [dynamo][ac] Return all intermediates as outputs for AC Hop (#167192) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167192 Approved by: https://github.com/zou3519 --- test/dynamo/test_activation_checkpointing.py | 108 +++++++++++++++++-- test/export/test_export.py | 22 ++-- torch/_dynamo/output_graph.py | 5 + torch/_dynamo/variables/builder.py | 19 +++- torch/_dynamo/variables/higher_order_ops.py | 72 ++++++++++++- torch/_dynamo/variables/tensor.py | 5 +- 6 files changed, 211 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index d6c0feac19ae1..6e1f45c166984 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -1,4 +1,6 @@ # Owner(s): ["module: dynamo"] +# flake8: noqa: B950 +# flake8: noqa: E731 import contextlib import copy import functools @@ -15,7 +17,11 @@ import torch.utils.checkpoint from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd -from torch._dynamo.testing import CompileCounterWithBackend +from torch._dynamo.testing import ( + 
AotEagerAndRecordGraphs, + CompileCounterWithBackend, + normalize_gm, +) from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu @@ -1649,6 +1655,43 @@ def fn(x): self.assertEqual(opt_fn(x), fn(x)) + def test_return_same_element_twice(self): + def gn(x): + y = torch.sin(x) + return y, y + + def fn(x): + return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True) + + x = torch.randn(4, 4, requires_grad=True) + ref = fn(x) + + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[4, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None + getitem: "f32[4, 4]" = tag_activation_checkpoint[0] + getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None + return (getitem, getitem_1) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[4, 4]"): + y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None + return (y, y) +""", + ) + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) def test_nonlocal_mutation(self): counter = 0 @@ -1699,34 +1742,87 @@ def fn(x): self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) - @unittest.expectedFailure @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) def test_nonlocal_list_mutation_hidden(self): def gn(x, z): + o = torch.matmul(x, x) @ x out = x.sin() z.append(out) - return torch.cos(torch.sin(torch.matmul(x, x) @ x)) + return torch.cos(torch.sin(o)), torch.sin(x) def fn(x): z = [] - out1 = torch.utils.checkpoint.checkpoint( + outs = torch.utils.checkpoint.checkpoint( gn, x, z, use_reentrant=False, ) + out1 = outs[0] + # Check that the extra output pytree handling is done properly + out2 = outs[-1] - return out1, z[0] + return out1 + out2, z[0] x = torch.randn(4, 4, requires_grad=True) ref = fn(x) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) res = opt_fn(x) self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[4, 4]"): + l_x_ = L_x_ + + wrap_body_0 = self.wrap_body_0 + tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None + out1: "f32[4, 4]" = tag_activation_checkpoint[0] + out2: "f32[4, 4]" = tag_activation_checkpoint[1] + getitem_4: "f32[4, 4]" = tag_activation_checkpoint[4]; tag_activation_checkpoint = None + + add: "f32[4, 4]" = out1 + out2; out1 = out2 = None + return (add, getitem_4) + + class wrap_body_0(torch.nn.Module): + def forward(self, l_x_: "f32[4, 4]"): + matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_) + o: "f32[4, 4]" = matmul @ l_x_ + + out: "f32[4, 4]" = l_x_.sin() + + sin_1: "f32[4, 4]" = torch.sin(o) + child: "f32[4, 4]" = 
torch.cos(sin_1) + child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None + return (child, child_1, matmul, o, out, sin_1) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[4, 4]"): + mm: "f32[4, 4]" = torch.ops.aten.mm.default(primals_1, primals_1) + mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(mm, primals_1); mm = None + + sin: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1) + + sin_1: "f32[4, 4]" = torch.ops.aten.sin.default(mm_1); mm_1 = None + cos: "f32[4, 4]" = torch.ops.aten.cos.default(sin_1); sin_1 = None + sin_2: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1) + + add: "f32[4, 4]" = torch.ops.aten.add.Tensor(cos, sin_2); cos = sin_2 = None + return (add, sin, primals_1) +""", + ) + devices = ["cuda", "hpu"] instantiate_device_type_tests( diff --git a/test/export/test_export.py b/test/export/test_export.py index c7848eb3d69de..a2fd76e0e0ccc 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -1224,8 +1224,14 @@ def forward(self, x): %p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias] %x : [num_users=1] = placeholder[target=x] %wrap_body0 : [num_users=1] = get_attr[target=wrap_body0] - %tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {}) + %tag_activation_checkpoint : [num_users=7] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {}) + %getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 1), kwargs = {}) + %getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 2), kwargs = {}) + %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 3), kwargs = {}) + %getitem_4 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 4), kwargs = {}) + %getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 5), kwargs = {}) + %getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 6), kwargs = {}) return (getitem,)""", ) @@ -1234,14 +1240,14 @@ def forward(self, x): """\ graph(): %arg0_1 : [num_users=1] = placeholder[target=arg0_1] - %arg1_1 : [num_users=1] = placeholder[target=arg1_1] - %arg2_1 : [num_users=1] = placeholder[target=arg2_1] - %arg3_1 : [num_users=1] = placeholder[target=arg3_1] - %arg4_1 : [num_users=1] = placeholder[target=arg4_1] - %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {}) - %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {}) + %arg1_1 : [num_users=2] = placeholder[target=arg1_1] + %arg2_1 : [num_users=2] = placeholder[target=arg2_1] + %arg3_1 : [num_users=2] = placeholder[target=arg3_1] + %arg4_1 : [num_users=2] = placeholder[target=arg4_1] + %linear : [num_users=2] = 
call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {}) + %relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {}) %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {}) - return (linear_1,)""", + return (linear_1, arg1_1, arg2_1, linear, relu, arg3_1, arg4_1)""", ) stack = contextlib.ExitStack() diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index e45fa5f25138d..b5a4a69e3dceb 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2995,6 +2995,11 @@ def __init__( "Inference mode is supposed to be disabled during compilation. Please open an issue." ) + self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet() + + def record_tensor_or_symint_vt(self, vt): + self.tracked_tensor_or_symint_vt.add(vt) + # preserve original meta if it is available def _maybe_preserve_original_meta( self, tx: "InstructionTranslatorBase", node: fx.Node diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index aa1d0f04d2040..5d213118b4d22 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2784,21 +2784,34 @@ def wrap_fx_proxy_cls( target_cls, tx, proxy, example_value=None, subclass_type=None, **options ): if example_value is None: - return _wrap_fx_proxy( + out = _wrap_fx_proxy( target_cls, tx, proxy, example_value, subclass_type, **options ) elif isinstance(example_value, torch.Tensor): - return _wrap_fx_preexisting_tensor( + out = _wrap_fx_preexisting_tensor( target_cls, tx, proxy, example_value, subclass_type, **options ) else: # This will skip tracing an op and recursively reinvoke wrap_fx_proxy_cls on supported # data structures. In essence this just handles tracing some other value which may # contain Fake Tensors or is otherwise proxyable. - return handle_traced_output( + out = handle_traced_output( example_value, tx, proxy, options, subclass_type, target_cls ) + if ( + isinstance( + out, + ( + torch._dynamo.variables.TensorVariable, + torch._dynamo.variables.SymNodeVariable, + ), + ) + and proxy.node.op != "placeholder" + ): + tx.output.current_tracer.record_tensor_or_symint_vt(out) + return out + # This is 1 above (wrapping a preexisting tensor) def _wrap_fx_preexisting_tensor( diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 89c2a7451a771..4f2e6a921ca31 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -87,6 +87,8 @@ class OutputSpec: # that this is the same length as the mask, we just look at the indices # where mask is True. const_values: Optional[list[Any]] = None + # Number of intermediate nodes that are also made subgraph outputs. + num_intermediate_nodes_as_outputs: int = 0 def __post_init__(self): if ( @@ -278,6 +280,13 @@ def _call_function_and_unflatten_output( ) orig_vt.proxy = subgraph_vt.proxy + if ret_spec.num_intermediate_nodes_as_outputs: + # The treespec was computed w/o any extra intermediate outputs. 
At this + # point, it is safe to just get rid of the extra outputs + flat_variable = TupleVariable( + flat_variable.items[: -ret_spec.num_intermediate_nodes_as_outputs] + ) + if ret_spec.masks_to_filter_const_values: from torch._dynamo.external_utils import insert_const_values_with_mask @@ -1084,6 +1093,59 @@ def speculate_subgraph( output, masks_to_filter_const_values ) + # NOTE - [Return subgraph intermediates as subgraph outputs] + # This helps HOPs which allow side effects. Consider the + # following example + # + # def gn(x, z): + # o = torch.matmul(x, x) @ x + # out = x.sin() + # z.append(out) + # return torch.cos(torch.sin(o)) + + # def fn(x): + # z = [] + # out1 = torch.utils.checkpoint.checkpoint( + # gn, + # x, + # z, + # use_reentrant=False, + # ) + # return out1, z[0] + # + # In this example, list `z` is in outer scope and gets appended + # in the subgraph with `out`. But `out` is not an output of the + # subgraph. This can cause issue because later on when the outer + # graph returns `z[0]` it needs to have access to the graph node + # `out`. To solve this problem, we just return all intermediates + # from the subgraph. + + # TODO - Today this is supported only for AC. AC HOP gets + # desugared in AOTDispatcher so even though subgraph has extra + # unused outputs in Dynamo, its ok even if we don't DCE them in + # Dynamo. As AOTDispatcher desugars/inlines the subgraph, the + # subgraph boundary disappears. And even for AC, today this only + # works when the skip_fwd_side_effects_in_bwd_under_checkpoint + # flag is True, i.e., only when we allow side-effects. But, we + # want this to be supported for other Hops as well, specifically + # nested_compile_region and autograd.Function. Today, its safe + # because we error out on seeing a side-effect. 
+ num_intermediate_nodes_as_outputs = 0 + if under_activation_checkpoint and should_flatten_outputs: + output_vts = { + vt + for vt in output.items + if isinstance( + vt, (variables.TensorVariable, variables.SymNodeVariable) + ) + } + extra_outputs = [] + for out in subtracer.tracked_tensor_or_symint_vt: + if out not in output_vts: + extra_outputs.append(out) + output = TupleVariable(output.items + extra_outputs) + num_intermediate_nodes_as_outputs = len(extra_outputs) + # Register output to graph # Modeled off of compile_and_call_fx_graph # TODO: support pytree output @@ -1095,7 +1157,10 @@ def speculate_subgraph( ( output, OutputSpec( - treespec, masks_to_filter_const_values, const_values + treespec, + masks_to_filter_const_values, + const_values, + num_intermediate_nodes_as_outputs, ), ), tx.output.graph, @@ -1219,7 +1284,10 @@ def move_lifted_freevars_phs_to_end( ( output, OutputSpec( - treespec, masks_to_filter_const_values, const_values + treespec, + masks_to_filter_const_values, + const_values, + num_intermediate_nodes_as_outputs, ), ), graph, diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index ca57d4e7e8783..326178ef00874 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1432,7 +1432,10 @@ def create(cls, tx, proxy, sym_num=None, **options): sym_num = int(sym_num) if isinstance(sym_num, sympy.Integer) else sym_num return ConstantVariable.create(sym_num) - return SymNodeVariable(proxy, sym_num, **options) + out = SymNodeVariable(proxy, sym_num, **options) + if proxy.node.op != "placeholder": + tx.output.current_tracer.record_tensor_or_symint_vt(out) + return out def __init__(self, proxy, sym_num, **kwargs) -> None: super().__init__(**kwargs) From 0b12e49795c647f87cac4091143c91643d9a5018 Mon Sep 17 00:00:00 2001 From: Benji Beck Date: Sat, 8 Nov 2025 05:13:38 +0000 Subject: [PATCH 253/651] [Inductor] Decouple flags for optimization and debug symbols (#167385) Summary: What: Decouple flags for optimization and debug symbols Why: The current flag for debug symbols only compiles the .so binary in unoptimized mode Differential Revision: D86363355 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167385 Approved by: https://github.com/hl475, https://github.com/jansel --- torch/_inductor/config.py | 1 + torch/_inductor/cpp_builder.py | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index fb43a9b859ffb..5ec1cecb251bf 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1603,6 +1603,7 @@ class aot_inductor: output_path = "" debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" + debug_symbols = os.environ.get("AOT_INDUCTOR_DEBUG_SYMBOLS", "0") == "1" # Annotate generated main wrapper function, i.e. 
AOTInductorModel::run_impl, # to use which cpp compiler optimization level, default to O1 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index 9b2444fb5ef19..e87600f27974d 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -881,23 +881,33 @@ def _get_optimization_cflags( cflags: list[str] = [] ldflags: list[str] = [] - b_debug_build = ( + should_use_optimized_flags = not ( config.aot_inductor.debug_compile or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" ) - wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level - - if b_debug_build: - cflags, ldflags = _get_inductor_debug_symbol_cflags() + should_add_debug_symbol_flags = ( + config.aot_inductor.debug_compile + or config.aot_inductor.debug_symbols + or os.environ.get("TORCHINDUCTOR_DEBUG_SYMBOL", "0") == "1" + ) + if should_use_optimized_flags: if _IS_WINDOWS: - cflags += ["Od", "Ob0", "Oy-"] + cflags += ["O1" if min_optimize else "O2"] else: - cflags.append("O0") + cflags += [ + config.aot_inductor.compile_wrapper_opt_level if min_optimize else "O3", + "DNDEBUG", + ] else: if _IS_WINDOWS: - cflags = ["O1" if min_optimize else "O2"] + cflags += ["Od", "Ob0", "Oy-"] else: - cflags = [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"] + cflags += ["O0"] + + if should_add_debug_symbol_flags: + debug_cflags, debug_ldflags = _get_inductor_debug_symbol_cflags() + cflags += debug_cflags + ldflags += debug_ldflags cflags += _get_ffast_math_flags() From eeb6c96a89b7934bfd2f18b9c44b00486165a2fb Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Sat, 8 Nov 2025 05:58:08 +0000 Subject: [PATCH 254/651] [vision hash update] update the pinned vision hash (#167391) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167391 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 1c6bf359618d5..92f446d1f4b07 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -ca2212438fdd8ce29b66999ed70ed54b0f9372d1 +ccb801b88af136454798b945175c4c87e636ac33 From 957570e4a331794e54177899872ceaba674df9fc Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 8 Nov 2025 07:35:38 -0800 Subject: [PATCH 255/651] [dynamo][guards] 1/N Guard selectively for DTensor (#165824) A few internal jobs are observing very high guard overhead for DTensor. Since we own DTensor, we can make those guards way faster. 
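Concretely, instead of running the generic (Python-implemented) tensor-subclass metadata guards, the compiled frame now only guards on the DTensor's `_spec` and `requires_grad`. A minimal sketch of what the check amounts to is below; the helper name is invented for illustration, while the `deepcopy` of the spec and `_check_equals(..., skip_shapes=True)` are taken from the guards.py hunk in this patch (the real check is installed through Dynamo's guard manager rather than called directly):

```python
# Sketch only -- not the actual guard machinery.
from copy import deepcopy


def make_dtensor_guard(compile_time_spec, compile_time_requires_grad):
    # Snapshot the spec seen at compile time, mirroring
    # `value = deepcopy(self.get(guard.name))` in DTENSOR_SPEC_MATCH.
    cached_spec = deepcopy(compile_time_spec)

    def guard_fn(current_spec, current_requires_grad):
        # Spec equality (shapes skipped, as in the diff) plus a plain
        # equality check on requires_grad.
        return (
            current_spec._check_equals(cached_spec, skip_shapes=True)
            and current_requires_grad == compile_time_requires_grad
        )

    return guard_fn
```

The new `test_dtensor_requires_grad_recompile` test below checks the user-visible consequence: flipping `requires_grad` on the input DTensor still triggers exactly one recompile (`frame_count == 2`).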
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165824 Approved by: https://github.com/Lucaskabela, https://github.com/bdhirsh --- .../tensor/test_dtensor_compile.py | 19 +++++ torch/_dynamo/guards.py | 13 ++++ torch/_dynamo/variables/builder.py | 73 +++++++++++++++---- torch/distributed/tensor/_api.py | 2 + 4 files changed, 93 insertions(+), 14 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index b82e9c97b57a8..ddba3150b05fb 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -464,6 +464,25 @@ def g(x): run(g, 64, 8) self.assertEqual(cnt.frame_count, 2) + def test_dtensor_requires_grad_recompile(self): + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + @torch.compile(backend=cnt, fullgraph=True) + def f(x): + y = x * x + return y.to_local() + + full_x = torch.randn(8, 8, requires_grad=False) + x = distribute_tensor(full_x, mesh, [Shard(0)]) + f(x) + + full_x = torch.randn(8, 8, requires_grad=True) + x = distribute_tensor(full_x, mesh, [Shard(0)]) + f(x) + + self.assertEqual(cnt.frame_count, 2) + def test_dtensor_attribute_access_on_intermediate(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 4e7d83357d88d..0f4d0d897b469 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2158,6 +2158,19 @@ def metadata_checker(x: Any) -> bool: metadata_checker, get_verbose_code_parts(global_name, guard) ) + def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None: + # Copied from DTensor __metadata_guard__ + # TODO - Consider moving this to C++ if stable + value = deepcopy(self.get(guard.name)) + + def guard_fn(x: Any) -> bool: + return x._check_equals(value, skip_shapes=True) + + code = f"__dtensor_spec_{id(guard_fn)}" + self.get_guard_manager(guard).add_lambda_guard( + guard_fn, get_verbose_code_parts(code, guard) + ) + def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None: ref = self.arg_ref(guard) val = self.get(guard.name) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 5d213118b4d22..7e00fce306393 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -2227,25 +2227,70 @@ def wrap_tensor(self, value: torch.Tensor): if isinstance(source, GradSource) and is_from_optimizer_source(source): guard_type = GuardBuilder.NOT_NONE_MATCH - self.install_guards( - functools.partial( - guard_type, - value=( - value - if isinstance(source, NumpyTensorSource) - else TensorWeakRef(value) - ), - ) + is_dtensor = torch.distributed.is_available() and isinstance( + value, torch.distributed.tensor.DTensor ) + if not is_dtensor: + # We guard on the _local_tensor and the _spec, and therefore we dont + # have to guard on the outer DTensor. + self.install_guards( + functools.partial( + guard_type, + value=( + value + if isinstance(source, NumpyTensorSource) + else TensorWeakRef(value) + ), + ) + ) # We install TYPE_MATCH guards for traceable wrapper subclass object, # and recursively install corresponding guard for each inner attribute. 
if is_traceable_wrapper_subclass(value): - self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) - self.install_guards(GuardBuilder.TYPE_MATCH) - install_guard( - SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) - ) + # Tensor subclass guards are very expensive because they are + # implemented in Python. Since DTensor is PyTorch-maintained class, + # we can skip a lot of these guards. + if is_dtensor: + self.install_guards(GuardBuilder.TYPE_MATCH) + + # The inner tensor name is always _local_tensor. If its not, we + # raise assertion to update the check accordingly. + inner_tensor_name = value.__tensor_flatten__()[0][0] + if inner_tensor_name != "_local_tensor": + raise RuntimeError( + "Expecting Dtensor inner tensor name to be _local_tensor" + ) + + # Now selectively guard on the flattening context + flattening_ctx = value.__tensor_flatten__()[1] + # This is supposed to be (self._spec, self.requires_grad) + if not ( + len(flattening_ctx) == 2 + and flattening_ctx[0] == value._spec + and flattening_ctx[1] == value.requires_grad + ): + # If not, raise an assertion to update to the new guards + raise RuntimeError( + "Expecting Dtensor flattening ctx to be _spec, requires_grad" + ) + # Guard on the dtensor spec + install_guard( + AttrSource(self.source, "_spec").make_guard( + GuardBuilder.DTENSOR_SPEC_MATCH + ) + ) + # Move this to C++ + install_guard( + AttrSource(self.source, "requires_grad").make_guard( + GuardBuilder.EQUALS_MATCH + ) + ) + else: + self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH) + self.install_guards(GuardBuilder.TYPE_MATCH) + install_guard( + SubclassAttrListSource(source).make_guard(GuardBuilder.EQUALS_MATCH) + ) attrs, _ = value.__tensor_flatten__() for attr in attrs: diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index de86d7923ae65..9324f42ed0f70 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -671,6 +671,8 @@ def __get_tensor_shard__(self, index): def __metadata_guard__( cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool] ) -> bool: + # TODO - delete this - This is now unused after the PR - + # https://github.com/pytorch/pytorch/pull/165824 orig_spec, orig_requires_grad = orig other_spec, other_requires_grad = other return ( From 406719c3daf84b4ecec98134ef3ad6ca953b86c4 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sat, 8 Nov 2025 20:03:49 +0000 Subject: [PATCH 256/651] [MPS] SparseMps mv op (#166708) Should be merged after #166561 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166708 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/native_functions.yaml | 2 +- test/test_sparse.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 491521bdc9601..633d66f669b65 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4383,7 +4383,7 @@ variants: function, method dispatch: CompositeExplicitAutograd: mv - SparseCPU, SparseCUDA: mv_sparse + SparseCPU, SparseCUDA, SparseMPS: mv_sparse - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) 
dispatch: diff --git a/test/test_sparse.py b/test/test_sparse.py index f1ed24667e133..11e1629e374ba 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -2674,7 +2674,6 @@ def test_asin_arcsin(self, device, dtype, coalesced): self._test_asin_arcsin(input_uncoalesced, coalesced) @coalescedonoff - @expectedFailureMPS @dtypes(torch.double) @dtypesIfMPS(torch.float32) def test_mv(self, device, dtype, coalesced): From 27ac58bd707c7d04cb90238ef2c9aa7a7e6fb281 Mon Sep 17 00:00:00 2001 From: Amin Sedaghat Date: Sat, 8 Nov 2025 20:59:41 +0000 Subject: [PATCH 257/651] Optimize global save-plan validation (#166820) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - Fixes #163548 by replacing the quadratic chunk-overlap scan in `_validate_global_plan` with a sweep-line pass that sorts chunk intervals and keeps an active set via `bisect_right`, giving O(n log n) behavior for metadata validation. - Add focused tests in `TestValidateGlobalPlan` covering overlapping and non-overlapping shard layouts to lock in the faster path. ## Testing - python test/distributed/checkpoint/test_planner.py -k ValidateGlobalPlan ## Benchmarks | chunks | old runtime | new runtime | |--------|-------------|-------------| | 1 024 | 0.121 s | 0.0014 s | | 2 048 | 0.486 s | 0.0027 s | | 4 096 | 2.474 s | 0.0058 s | | 8 192 | 8.014 s | 0.0126 s | | 16 384 | 32.740 s | 0.026 s | @ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/166820 Approved by: https://github.com/LucasLLC, https://github.com/Skylion007 --- test/distributed/checkpoint/test_planner.py | 28 ++++++++ .../distributed/checkpoint/default_planner.py | 67 ++++++++++++++----- 2 files changed, 77 insertions(+), 18 deletions(-) diff --git a/test/distributed/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py index 16f7089206e34..a8620c383f2f9 100644 --- a/test/distributed/checkpoint/test_planner.py +++ b/test/distributed/checkpoint/test_planner.py @@ -18,6 +18,7 @@ from torch.distributed.checkpoint.api import CheckpointException from torch.distributed.checkpoint.default_planner import ( _create_default_local_metadata, + _validate_global_plan, create_default_global_save_plan, create_default_local_load_plan, create_default_local_save_plan, @@ -28,6 +29,7 @@ from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, ChunkStorageMetadata, + Metadata, MetadataIndex, TensorProperties, TensorStorageMetadata, @@ -560,6 +562,32 @@ def create_data(rank): self.assertTrue(_compare_save_plans(plan2, plan2)) +class TestValidateGlobalPlan(TestCase): + def _make_metadata(self, chunks, size): + storage = TensorStorageMetadata( + properties=TensorProperties(dtype=torch.float32), + size=torch.Size(size), + chunks=chunks, + ) + return Metadata(state_dict_metadata={"param": storage}) + + def test_non_overlapping_chunks(self): + chunks = [ + ChunkStorageMetadata(offsets=torch.Size([i]), sizes=torch.Size([1])) + for i in range(4) + ] + metadata = self._make_metadata(chunks, [4]) + self.assertTrue(_validate_global_plan([SavePlan([])], metadata)) + + def test_detect_overlapping_chunks(self): + chunks = [ + ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([2])), + ChunkStorageMetadata(offsets=torch.Size([1]), sizes=torch.Size([2])), + ] + metadata = self._make_metadata(chunks, [4]) + self.assertFalse(_validate_global_plan([SavePlan([])], metadata)) + + class TestLoadPlanner(TestCase): @with_temp_dir def test_strict(self): diff --git 
a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py index 2f68e7f842264..716cb90a99653 100644 --- a/torch/distributed/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -4,8 +4,10 @@ import dataclasses import io import logging -import operator -from functools import reduce +import math +import sys +from bisect import bisect_right, insort +from collections import ChainMap from typing import Any, cast, Optional, Union import torch @@ -136,9 +138,12 @@ def _create_global_plan( global_plan, metadata = create_default_global_save_plan(deduped_plans) if self.flatten_state_dict: - merged_mappings = reduce( - lambda x, y: x | y, (p.planner_data for p in global_plan) - ) + # | does not work for Python 3.8 or older version. + # merged_mappings = reduce( + # lambda x, y: x | y, (p.planner_data for p in global_plan) + # ) + planner_data_dict = [p.planner_data for p in global_plan] + merged_mappings = dict(ChainMap(*planner_data_dict)) metadata = dataclasses.replace(metadata, planner_data=merged_mappings) if not _validate_global_plan(global_plan, metadata): @@ -630,10 +635,11 @@ def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bo continue if len(value.size) == 0: continue + chunks = value.chunks chunks_volume = 0 - for chunk_idx, chunk0 in enumerate(value.chunks): + for chunk in chunks: # Compute the volume - if not _check_box_bounds(value.size, chunk0): + if not _check_box_bounds(value.size, chunk): logger.warning( """ key:%s has out of bounds chunk: @@ -641,21 +647,46 @@ def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bo """, key, value.size, - chunk0, + chunk, ) all_good = False - chunks_volume += reduce(operator.mul, chunk0.sizes, 1) - - # Check for overlap - for chunk1 in value.chunks[chunk_idx + 1 :]: - if _check_box_overlap(chunk0, chunk1): - logger.warning( - "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1 - ) - all_good = False + chunks_volume += math.prod(chunk.sizes) + + if len(chunks) > 1: + dims = len(value.size) + sweep_dim = max(range(dims), default=0, key=lambda d: value.size[d]) + sorted_indices = sorted( + range(len(chunks)), + key=lambda idx: ( + chunks[idx].offsets[sweep_dim], + *(chunks[idx].offsets[d] for d in range(dims)), + ), + ) + active: list[tuple[int, int]] = [] + for idx in sorted_indices: + current = chunks[idx] + start = current.offsets[sweep_dim] + end = start + current.sizes[sweep_dim] + + cutoff = bisect_right(active, (start, sys.maxsize)) + if cutoff: + del active[:cutoff] + + for _, other_idx in active: + other = chunks[other_idx] + if _check_box_overlap(current, other): + logger.warning( + "key:%s has overlapping chunks: %s %s", + key, + current, + other, + ) + all_good = False + + insort(active, (end, idx)) # Check whether combined chunk cover the whole tensor - tensor_volume = reduce(operator.mul, value.size, 1) + tensor_volume = math.prod(value.size) if len(global_plan) > 1 and chunks_volume != tensor_volume: logger.warning( """ From e342a7509a2bb76dcdcd817c3eaec9a0a4fca7e2 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Fri, 7 Nov 2025 22:37:36 -0800 Subject: [PATCH 258/651] [pallas backend] add cpu backend and parametrize the tests (#167388) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167388 Approved by: https://github.com/jansel --- test/inductor/test_pallas.py | 134 ++++++++++++++---------------- torch/_inductor/codegen/common.py | 1 + torch/_inductor/codegen/pallas.py | 8 ++ 
torch/_inductor/config.py | 4 +- 4 files changed, 72 insertions(+), 75 deletions(-) diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index 3ba84e8cd2b8c..8571321e872e6 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -52,9 +52,12 @@ def make_pallas(cls): return test_class -@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") -class PallasTests(TestCase): - """Basic tests for Pallas backend functionality.""" +class PallasTestsMixin: + """Basic tests for Pallas backend functionality (parameterized by DEVICE). Mixin only, not collected.""" + + def _compile(self, fn): + key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend" + return torch.compile(fn, backend="inductor", options={key: "pallas"}) def test_simple_add(self): """Test basic element-wise addition.""" @@ -62,12 +65,10 @@ def test_simple_add(self): def fn(a, b): return a + b - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - a = torch.randn(1024, device="cuda") - b = torch.randn(1024, device="cuda") + a = torch.randn(1024, device=self.DEVICE) + b = torch.randn(1024, device=self.DEVICE) result = compiled(a, b) expected = fn(a, b) self.assertEqual(result, expected) @@ -78,12 +79,10 @@ def test_simple_mul(self): def fn(a, b): return a * b - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - a = torch.randn(1024, device="cuda") - b = torch.randn(1024, device="cuda") + a = torch.randn(1024, device=self.DEVICE) + b = torch.randn(1024, device=self.DEVICE) result = compiled(a, b) expected = fn(a, b) self.assertEqual(result, expected) @@ -94,11 +93,9 @@ def test_sin(self): def fn(x): return torch.sin(x) - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(1024, device="cuda") + x = torch.randn(1024, device=self.DEVICE) result = compiled(x) expected = fn(x) self.assertEqual(result, expected) @@ -109,12 +106,10 @@ def test_fused_ops(self): def fn(x, y): return x.sin() + y - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(1024, device="cuda") - y = torch.randn(1024, device="cuda") + x = torch.randn(1024, device=self.DEVICE) + y = torch.randn(1024, device=self.DEVICE) result = compiled(x, y) expected = fn(x, y) self.assertEqual(result, expected) @@ -125,11 +120,9 @@ def test_exp_log(self): def fn(x): return torch.log(torch.exp(x)) - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(1024, device="cuda") + x = torch.randn(1024, device=self.DEVICE) result = compiled(x) expected = fn(x) self.assertEqual(result, expected) @@ -140,11 +133,9 @@ def test_sqrt(self): def fn(x): return torch.sqrt(x) - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt + x = torch.randn(1024, device=self.DEVICE).abs() # Ensure positive for sqrt result = compiled(x) expected = fn(x) self.assertEqual(result, expected) @@ -155,11 +146,9 @@ def test_tanh(self): def fn(x): return torch.tanh(x) - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(1024, device="cuda") + x = 
torch.randn(1024, device=self.DEVICE) result = compiled(x) expected = fn(x) self.assertEqual(result, expected) @@ -170,11 +159,9 @@ def test_abs_neg(self): def fn(x): return torch.abs(-x) - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(1024, device="cuda") + x = torch.randn(1024, device=self.DEVICE) result = compiled(x) expected = fn(x) self.assertEqual(result, expected) @@ -185,12 +172,10 @@ def test_maximum_minimum(self): def fn(a, b): return torch.maximum(a, b) + torch.minimum(a, b) - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - a = torch.randn(1024, device="cuda") - b = torch.randn(1024, device="cuda") + a = torch.randn(1024, device=self.DEVICE) + b = torch.randn(1024, device=self.DEVICE) result = compiled(a, b) expected = fn(a, b) self.assertEqual(result, expected) @@ -228,15 +213,17 @@ def test_compile_options(self): @torch.compile( backend="inductor", - options={"cuda_backend": "pallas"}, + options={ + ("cuda_backend" if self.DEVICE == "cuda" else "cpu_backend"): "pallas" + }, ) def pallas_fn(a, b): return a.sin() + b.cos() _, (code,) = run_and_get_code( pallas_fn, - torch.randn(64, device="cuda"), - torch.randn(64, device="cuda"), + torch.randn(64, device=self.DEVICE), + torch.randn(64, device=self.DEVICE), ) # Verify Pallas-specific code generation self.assertIn("import jax", code) @@ -249,12 +236,10 @@ def test_2d_tensor(self): def fn(x, y): return x + y - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) - x = torch.randn(32, 32, device="cuda") - y = torch.randn(32, 32, device="cuda") + x = torch.randn(32, 32, device=self.DEVICE) + y = torch.randn(32, 32, device=self.DEVICE) result = compiled(x, y) expected = fn(x, y) self.assertEqual(result, expected) @@ -265,12 +250,10 @@ def test_different_shapes(self): def fn(x): return x * 2.0 - compiled = torch.compile( - fn, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(fn) for shape in [(64,), (128,), (256,), (1024,)]: - x = torch.randn(shape, device="cuda") + x = torch.randn(shape, device=self.DEVICE) result = compiled(x) expected = fn(x) self.assertEqual(result, expected) @@ -282,12 +265,10 @@ def test_contiguous_index_validation(self): def contiguous_add(a, b): return a + b - compiled = torch.compile( - contiguous_add, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(contiguous_add) - a = torch.randn(1024, device="cuda") - b = torch.randn(1024, device="cuda") + a = torch.randn(1024, device=self.DEVICE) + b = torch.randn(1024, device=self.DEVICE) result = compiled(a, b) expected = contiguous_add(a, b) self.assertEqual(result, expected) @@ -296,11 +277,9 @@ def contiguous_add(a, b): def contiguous_mul(x): return x * 2.0 - compiled = torch.compile( - contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(contiguous_mul) - x = torch.randn(128, 8, device="cuda") + x = torch.randn(128, 8, device=self.DEVICE) result = compiled(x) expected = contiguous_mul(x) self.assertEqual(result, expected) @@ -310,12 +289,10 @@ def contiguous_mul(x): def operate_on_tensor(x): return x.sin() - compiled = torch.compile( - operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"} - ) + compiled = self._compile(operate_on_tensor) # Create a transposed (non-contiguous) view - x = 
torch.randn(64, 32, device="cuda") + x = torch.randn(64, 32, device=self.DEVICE) x_t = x.t() # Non-contiguous view self.assertFalse(x_t.is_contiguous()) @@ -332,13 +309,24 @@ def operate_on_tensor(x): self.assertEqual(result, expected) +@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +class PallasTestsCUDA(PallasTestsMixin, TestCase): + DEVICE = "cuda" + + +@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +class PallasTestsCPU(PallasTestsMixin, TestCase): + DEVICE = "cpu" + + # Create test variants using the main test suite # Note: Only enable GPU tests since Pallas primarily targets GPU -if test_torchinductor.HAS_GPU and HAS_PALLAS: - # Uncomment these to run full test suite with Pallas backend - # make_pallas(test_torchinductor.SweepInputsGPUTest) - # make_pallas(test_torchinductor.GPUTests) - pass +if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS: + if getattr(test_torchinductor, "HAS_GPU", False): + # Uncomment these to run full test suite with Pallas backend + # make_pallas(test_torchinductor.SweepInputsGPUTest) + # make_pallas(test_torchinductor.GPUTests) + pass if __name__ == "__main__": if HAS_PALLAS: diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 3e9f174c810c5..8b5e68780cb28 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -521,6 +521,7 @@ def init_backend_registration() -> None: "cpp": CppScheduling, "halide": HalideScheduling, "triton": TritonScheduling, + "pallas": PallasScheduling, } register_backend_for_device( "cpu", diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 6ee901d19b14f..bce91dc31dc7c 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -311,6 +311,12 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove main_name = f"{kernel_name}_main" code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):") with code.indent(): + # Determine interpret statically based on codegen device + interpret_literal = ( + "True" + if V.graph.get_current_device_or_throw().type == "cpu" + else "False" + ) # Identify inputs (in_ptr*) and output (out_ptr*) input_params = [ p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr")) @@ -346,9 +352,11 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove ) # Call pallas + # Pass interpret=True on CPU, False otherwise (single call, no duplication) code.writeline("compiled = pl.pallas_call(") code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),") code.writeline(" out_shape=out_spec,") + code.writeline(f" interpret={interpret_literal},") code.writeline(" grid=(1,),") code.writeline(")") diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5ec1cecb251bf..bd0ff91616b37 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1973,8 +1973,8 @@ class rocm: contiguous_threshold: int = 16 -# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) -cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" +# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental) +cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp" # Backend to use for CUDA codegen either # "triton", "halide" (experimental) or "pallas" (experimental) From 71606b289ceadf9dec3a96e74ecadbce47816c77 Mon Sep 17 00:00:00 2001 From: linhaifeng 
<1371675203@qq.com> Date: Sat, 8 Nov 2025 23:57:12 +0000 Subject: [PATCH 259/651] [BugFix] Fix compute_error in coo_mean_time and csr_mean_time (#166795) The csr timing loop is nested inside the coo loop. duplicated and inconsistent measurements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166795 Approved by: https://github.com/cyyever, https://github.com/ezyang --- benchmarks/sparse/spmm.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/benchmarks/sparse/spmm.py b/benchmarks/sparse/spmm.py index b707556dd7a15..194d302a913b1 100644 --- a/benchmarks/sparse/spmm.py +++ b/benchmarks/sparse/spmm.py @@ -52,19 +52,18 @@ def test_sparse_coo_and_csr(m, n, k, nnz, test_count): start.record() coo.matmul(mat) stop.record() - times.append(start.elapsed_time(stop)) - coo_mean_time = sum(times) / len(times) + coo_mean_time = sum(times) / len(times) - times = [] - for _ in range(test_count): - start.record() - csr.matmul(mat) - stop.record() - times.append(start.elapsed_time(stop)) + times = [] + for _ in range(test_count): + start.record() + csr.matmul(mat) + stop.record() + times.append(start.elapsed_time(stop)) - csr_mean_time = sum(times) / len(times) + csr_mean_time = sum(times) / len(times) return coo_mean_time, csr_mean_time From 47acdea74a9490adb6a70cb723d9c4109b48f184 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Sun, 9 Nov 2025 00:20:54 +0000 Subject: [PATCH 260/651] another version of fixing CachingHostAllocatorImpl destructor (#167408) Another version of #167347 that won't break xpu and should correctly handle runtime changes of `pinned_use_background_threads()` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167408 Approved by: https://github.com/yingufan, https://github.com/Skylion007 --- aten/src/ATen/core/CachingHostAllocator.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 603e7e73bc1ea..71af40c5fd20a 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -226,8 +226,8 @@ template < typename B = HostBlock> struct CachingHostAllocatorImpl { virtual ~CachingHostAllocatorImpl() { - active_ = false; - if (pinned_use_background_threads()) { + if (active_) { + active_ = false; getBackgroundThreadPool()->waitWorkComplete(); } } @@ -260,6 +260,7 @@ struct CachingHostAllocatorImpl { if (pinned_use_background_threads()) { // Launch the background thread and process events in a loop. static bool background_thread_flag [[maybe_unused]] = [this] { + active_ = true; getBackgroundThreadPool()->run([&]() { while (active_) { process_events(); @@ -683,9 +684,9 @@ struct CachingHostAllocatorImpl { alignas(hardware_destructive_interference_size) std::mutex events_mutex_; std::deque> events_; // event queue paired with block - // Indicates whether the object is active. + // Indicates whether the event-processing thread pool is active. // Set to false in the destructor to signal background threads to stop. - std::atomic active_{true}; + std::atomic active_{false}; protected: alignas(hardware_destructive_interference_size) HostStatsStaged stats_; }; From 325ec9800916f590e8f71e0acb91e7722ab31842 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sun, 9 Nov 2025 01:47:38 +0000 Subject: [PATCH 261/651] [13/N] Apply ruff UP035 rule (#167048) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. 
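For context, ruff's UP035 rule flags imports of names that are deprecated aliases in `typing` (or that now live in `typing` rather than `typing_extensions`), most commonly `Callable` and `Sequence`, which should come from `collections.abc`. A minimal before/after sketch of the kind of rewrite applied in the hunks below (the function is invented for illustration):

```python
# Before (flagged by UP035):
#   from typing import Callable, Sequence
# After:
from collections.abc import Callable, Sequence


def apply_all(fns: Sequence[Callable[[int], int]], x: int) -> int:
    for fn in fns:
        x = fn(x)
    return x
```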
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048 Approved by: https://github.com/Skylion007 --- test/dynamo/test_install_free_tensors.py | 4 ++-- test/dynamo/test_python_autograd.py | 6 +++++- test/typing/pass/arithmetic_ops.py | 4 ++-- torch/_C/_distributed_c10d.pyi | 3 ++- torch/_dynamo/variables/ctx_manager.py | 4 ++-- torch/_inductor/codegen/pallas.py | 4 +++- torch/_inductor/runtime/caching/config.py | 2 +- torch/distributed/_local_tensor/_c10d.py | 3 +-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_install_free_tensors.py b/test/dynamo/test_install_free_tensors.py index 3858b827bd598..fd9e14c4c3f76 100644 --- a/test/dynamo/test_install_free_tensors.py +++ b/test/dynamo/test_install_free_tensors.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] import unittest -from collections.abc import Sequence -from typing import Any, Callable, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import torch import torch._dynamo diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index a615c653f56c3..a6117bb4093a7 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,5 +1,5 @@ # Owner(s): ["module: dynamo"] -from typing import Callable, NamedTuple, Optional +from typing import NamedTuple, Optional, TYPE_CHECKING import torch import torch._dynamo @@ -7,6 +7,10 @@ from torch._dynamo.testing import CompileCounter, same +if TYPE_CHECKING: + from collections.abc import Callable + + """ This is an example of a pure-python version of autograd implemented by @zdevito. It represents a rather challenging test case for TorchDynamo diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index f0d6cc6fd9f97..14dda1cf39772 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,5 +1,5 @@ -from typing import Union -from typing_extensions import assert_type, TypeAlias +from typing import TypeAlias, Union +from typing_extensions import assert_type from torch import randn, Tensor diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f3d96860f5584..b659be9ee119e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" +from collections.abc import Callable from datetime import timedelta from enum import Enum -from typing import Any, Callable, Optional, overload, Union +from typing import Any, Optional, overload, Union import torch from torch import Tensor diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 59c4bd99e25b8..81bb0777b5555 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -21,9 +21,9 @@ import inspect import sys import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable, Sequence, Sized from contextlib import ExitStack -from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union import torch._C from torch._guards import Guard diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index bce91dc31dc7c..e5bf1fa17cdca 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -2,7 +2,7 @@ from __future__ import annotations import hashlib -from typing import Any, 
Callable, Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import sympy # noqa: TC002 @@ -17,6 +17,8 @@ if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from ..ir import IRNode from ..scheduler import BaseSchedulerNode diff --git a/torch/_inductor/runtime/caching/config.py b/torch/_inductor/runtime/caching/config.py index 748715d1631ad..14e13f937dbb7 100644 --- a/torch/_inductor/runtime/caching/config.py +++ b/torch/_inductor/runtime/caching/config.py @@ -1,6 +1,6 @@ import os +from collections.abc import Callable from functools import cache, partial -from typing import Callable import torch from torch._environment import is_fbcode diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index c9256543e8977..0b63330dfafce 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -1,9 +1,8 @@ import functools import math import operator -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import timedelta -from typing import Callable import torch from torch._C import ScriptObject From 0384104e23bbd0b4cc2a9b92cffb733137c2f882 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Sun, 9 Nov 2025 02:14:30 +0000 Subject: [PATCH 262/651] Update pythoncapi_compat.h to 11cb80f2652cb2fe5231bf60b9dd98c83a4e25f4 (#167413) Second attempt for https://github.com/pytorch/pytorch/pull/167138 with fixes for name conflicts in downstream packages. Should slightly simplify https://github.com/pytorch/pytorch/pull/166342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167413 Approved by: https://github.com/Skylion007 --- torch/csrc/utils/python_compat.h | 8 - torch/csrc/utils/pythoncapi_compat.h | 1417 +++++++++++++++++++++++++- 2 files changed, 1405 insertions(+), 20 deletions(-) diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index 16308dad4421d..8488d5d0917b5 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -33,14 +33,6 @@ static inline int PyCode_GetNFreevars(PyCodeObject* code) { #endif } -// Provided by CPython but getting the header for them is very hard -#if IS_PYTHON_3_11_PLUS -// NOLINTNEXTLINE(readability-redundant-declaration) -PyAPI_FUNC(void) _PyWeakref_ClearRef(PyWeakReference* self); -#else -extern void _PyWeakref_ClearRef(PyWeakReference* self); -#endif - #ifdef __cplusplus } #endif diff --git a/torch/csrc/utils/pythoncapi_compat.h b/torch/csrc/utils/pythoncapi_compat.h index 05e80b5ee8607..cdfdafa84eb72 100644 --- a/torch/csrc/utils/pythoncapi_compat.h +++ b/torch/csrc/utils/pythoncapi_compat.h @@ -7,7 +7,7 @@ // https://github.com/python/pythoncapi_compat // // Latest version: -// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h +// https://raw.githubusercontent.com/python/pythoncapi-compat/main/pythoncapi_compat.h // // SPDX-License-Identifier: 0BSD @@ -19,6 +19,7 @@ extern "C" { #endif #include +#include // offsetof() // Python 3.11.0b4 added PyFrame_Back() to Python.h #if PY_VERSION_HEX < 0x030b00B4 && !defined(PYPY_VERSION) @@ -33,11 +34,13 @@ extern "C" { // Static inline functions should use _Py_NULL rather than using directly NULL // to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, // _Py_NULL is defined as nullptr. 
-#if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ - || (defined(__cplusplus) && __cplusplus >= 201103) -# define _Py_NULL nullptr -#else -# define _Py_NULL NULL +#ifndef _Py_NULL +# if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ + || (defined(__cplusplus) && __cplusplus >= 201103) +# define _Py_NULL nullptr +# else +# define _Py_NULL NULL +# endif #endif // Cast argument to PyObject* type. @@ -45,6 +48,13 @@ extern "C" { # define _PyObject_CAST(op) _Py_CAST(PyObject*, op) #endif +#ifndef Py_BUILD_ASSERT +# define Py_BUILD_ASSERT(cond) \ + do { \ + (void)sizeof(char [1 - 2 * !(cond)]); \ + } while(0) +#endif + // bpo-42262 added Py_NewRef() to Python 3.10.0a3 #if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) @@ -68,6 +78,16 @@ static inline PyObject* _Py_XNewRef(PyObject *obj) #endif +// bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT) +static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) +{ + ob->ob_refcnt = refcnt; +} +#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt) +#endif + + // Py_SETREF() and Py_XSETREF() were added to Python 3.5.2. // It is excluded from the limited C API. #if (PY_VERSION_HEX < 0x03050200 && !defined(Py_SETREF)) && !defined(Py_LIMITED_API) @@ -104,6 +124,37 @@ static inline PyObject* _Py_XNewRef(PyObject *obj) # define Py_IsFalse(x) Py_Is(x, Py_False) #endif + +// bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) +static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type) +{ + ob->ob_type = type; +} +#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type) +#endif + + +// bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE) +static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) +{ + ob->ob_size = size; +} +#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size) +#endif + + +// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 || defined(PYPY_VERSION) +static inline PyCodeObject* PyFrame_GetCode(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + assert(frame->f_code != _Py_NULL); + return _Py_CAST(PyCodeObject*, Py_NewRef(frame->f_code)); +} +#endif + static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) { PyCodeObject *code = PyFrame_GetCode(frame); @@ -112,6 +163,15 @@ static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) } +// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyFrame_GetBack(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + return _Py_CAST(PyFrameObject*, Py_XNewRef(frame->f_back)); +} +#endif + #if !defined(PYPY_VERSION) static inline PyFrameObject* _PyFrame_GetBackBorrow(PyFrameObject *frame) { @@ -229,6 +289,26 @@ PyFrame_GetVarString(PyFrameObject *frame, const char *name) #endif +// bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || (defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030B0000) +static inline PyInterpreterState * +PyThreadState_GetInterpreter(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->interp; +} +#endif + + +// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyThreadState_GetFrame(PyThreadState 
*tstate) +{ + assert(tstate != _Py_NULL); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} +#endif + #if !defined(PYPY_VERSION) static inline PyFrameObject* _PyThreadState_GetFrameBorrow(PyThreadState *tstate) @@ -240,6 +320,35 @@ _PyThreadState_GetFrameBorrow(PyThreadState *tstate) #endif +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION) +static inline PyInterpreterState* PyInterpreterState_Get(void) +{ + PyThreadState *tstate; + PyInterpreterState *interp; + + tstate = PyThreadState_GET(); + if (tstate == _Py_NULL) { + Py_FatalError("GIL released (tstate is NULL)"); + } + interp = tstate->interp; + if (interp == _Py_NULL) { + Py_FatalError("no current interpreter"); + } + return interp; +} +#endif + + +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a6 +#if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline uint64_t PyThreadState_GetID(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->id; +} +#endif + // bpo-43760 added PyThreadState_EnterTracing() to Python 3.11.0a2 #if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) static inline void PyThreadState_EnterTracing(PyThreadState *tstate) @@ -269,6 +378,27 @@ static inline void PyThreadState_LeaveTracing(PyThreadState *tstate) #endif +// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1 +// PyObject_CallNoArgs() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallNoArgs) && PY_VERSION_HEX < 0x030900A1 +static inline PyObject* PyObject_CallNoArgs(PyObject *func) +{ + return PyObject_CallFunctionObjArgs(func, NULL); +} +#endif + + +// bpo-39245 made PyObject_CallOneArg() public (previously called +// _PyObject_CallOneArg) in Python 3.9.0a4 +// PyObject_CallOneArg() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallOneArg) && PY_VERSION_HEX < 0x030900A4 +static inline PyObject* PyObject_CallOneArg(PyObject *func, PyObject *arg) +{ + return PyObject_CallFunctionObjArgs(func, arg, NULL); +} +#endif + + // bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3 #if PY_VERSION_HEX < 0x030A00A3 static inline int @@ -294,6 +424,58 @@ PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value) #endif +// bpo-40024 added PyModule_AddType() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 +static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) +{ + const char *name, *dot; + + if (PyType_Ready(type) < 0) { + return -1; + } + + // inline _PyType_Name() + name = type->tp_name; + assert(name != _Py_NULL); + dot = strrchr(name, '.'); + if (dot != _Py_NULL) { + name = dot + 1; + } + + return PyModule_AddObjectRef(module, name, _PyObject_CAST(type)); +} +#endif + + +// bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6. +// bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2. +#if PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsTracked(PyObject* obj) +{ + return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj)); +} +#endif + +// bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6. +// bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final. 
+#if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsFinalized(PyObject *obj) +{ + PyGC_Head *gc = _Py_CAST(PyGC_Head*, obj) - 1; + return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED(gc)); +} +#endif + + +// bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) +static inline int _Py_IS_TYPE(PyObject *ob, PyTypeObject *type) { + return Py_TYPE(ob) == type; +} +#define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST(ob), type) +#endif + + // bpo-46906 added PyFloat_Pack2() and PyFloat_Unpack2() to Python 3.11a7. // bpo-11734 added _PyFloat_Pack2() and _PyFloat_Unpack2() to Python 3.6.0b1. // Python 3.11a2 moved _PyFloat_Pack2() and _PyFloat_Unpack2() to the internal @@ -401,7 +583,7 @@ static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj) return 0; } *pobj = Py_NewRef(obj); - return (*pobj != NULL); + return 1; } #endif @@ -420,6 +602,81 @@ static inline Py_ssize_t PyVectorcall_NARGS(size_t n) #endif +// gh-105922 added PyObject_Vectorcall() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 +static inline PyObject* +PyObject_Vectorcall(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames) +{ +#if PY_VERSION_HEX >= 0x030800B1 && !defined(PYPY_VERSION) + // bpo-36974 added _PyObject_Vectorcall() to Python 3.8.0b1 + return _PyObject_Vectorcall(callable, args, nargsf, kwnames); +#else + PyObject *posargs = NULL, *kwargs = NULL; + PyObject *res; + Py_ssize_t nposargs, nkwargs, i; + + if (nargsf != 0 && args == NULL) { + PyErr_BadInternalCall(); + goto error; + } + if (kwnames != NULL && !PyTuple_Check(kwnames)) { + PyErr_BadInternalCall(); + goto error; + } + + nposargs = (Py_ssize_t)PyVectorcall_NARGS(nargsf); + if (kwnames) { + nkwargs = PyTuple_GET_SIZE(kwnames); + } + else { + nkwargs = 0; + } + + posargs = PyTuple_New(nposargs); + if (posargs == NULL) { + goto error; + } + if (nposargs) { + for (i=0; i < nposargs; i++) { + PyTuple_SET_ITEM(posargs, i, Py_NewRef(*args)); + args++; + } + } + + if (nkwargs) { + kwargs = PyDict_New(); + if (kwargs == NULL) { + goto error; + } + + for (i = 0; i < nkwargs; i++) { + PyObject *key = PyTuple_GET_ITEM(kwnames, i); + PyObject *value = *args; + args++; + if (PyDict_SetItem(kwargs, key, value) < 0) { + goto error; + } + } + } + else { + kwargs = NULL; + } + + res = PyObject_Call(callable, posargs, kwargs); + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return res; + +error: + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return NULL; +#endif +} +#endif + + // gh-106521 added PyObject_GetOptionalAttr() and // PyObject_GetOptionalAttrString() to Python 3.13.0a1 #if PY_VERSION_HEX < 0x030D00A1 @@ -664,7 +921,7 @@ static inline int PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg) { PyObject **dict = _PyObject_GetDictPtr(obj); - if (*dict == NULL) { + if (dict == NULL || *dict == NULL) { return -1; } Py_VISIT(*dict); @@ -675,7 +932,7 @@ static inline void PyObject_ClearManagedDict(PyObject *obj) { PyObject **dict = _PyObject_GetDictPtr(obj); - if (*dict == NULL) { + if (dict == NULL || *dict == NULL) { return; } Py_CLEAR(*dict); @@ -950,11 +1207,11 @@ static inline int PyTime_PerfCounter(PyTime_t *result) #endif // gh-111389 added hash constants to Python 3.13.0a5. These constants were -// added first as private macros to Python 3.4.0b1 and PyPy 7.3.9. +// added first as private macros to Python 3.4.0b1 and PyPy 7.3.8. 
#if (!defined(PyHASH_BITS) \ && ((!defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x030400B1) \ || (defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03070000 \ - && PYPY_VERSION_NUM >= 0x07090000))) + && PYPY_VERSION_NUM >= 0x07030800))) # define PyHASH_BITS _PyHASH_BITS # define PyHASH_MODULUS _PyHASH_MODULUS # define PyHASH_INF _PyHASH_INF @@ -1196,6 +1453,18 @@ PyUnicodeWriter_WriteUTF8(PyUnicodeWriter *writer, return res; } +static inline int +PyUnicodeWriter_WriteASCII(PyUnicodeWriter *writer, + const char *str, Py_ssize_t size) +{ + if (size < 0) { + size = (Py_ssize_t)strlen(str); + } + + return _PyUnicodeWriter_WriteASCIIString((_PyUnicodeWriter*)writer, + str, size); +} + static inline int PyUnicodeWriter_WriteWideChar(PyUnicodeWriter *writer, const wchar_t *str, Py_ssize_t size) @@ -1219,7 +1488,8 @@ PyUnicodeWriter_WriteSubstring(PyUnicodeWriter *writer, PyObject *str, Py_ssize_t start, Py_ssize_t end) { if (!PyUnicode_Check(str)) { - PyErr_Format(PyExc_TypeError, "expect str, not %T", str); + PyErr_Format(PyExc_TypeError, "expect str, not %s", + Py_TYPE(str)->tp_name); return -1; } if (start < 0 || start > end) { @@ -1266,6 +1536,1129 @@ static inline int PyLong_GetSign(PyObject *obj, int *sign) } #endif +// gh-126061 added PyLong_IsPositive/Negative/Zero() to Python in 3.14.0a2 +#if PY_VERSION_HEX < 0x030E00A2 +static inline int PyLong_IsPositive(PyObject *obj) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + return _PyLong_Sign(obj) == 1; +} + +static inline int PyLong_IsNegative(PyObject *obj) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + return _PyLong_Sign(obj) == -1; +} + +static inline int PyLong_IsZero(PyObject *obj) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expected int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + return _PyLong_Sign(obj) == 0; +} +#endif + + +// gh-124502 added PyUnicode_Equal() to Python 3.14.0a0 +#if PY_VERSION_HEX < 0x030E00A0 +static inline int PyUnicode_Equal(PyObject *str1, PyObject *str2) +{ + if (!PyUnicode_Check(str1)) { + PyErr_Format(PyExc_TypeError, "first argument must be str, not %s", + Py_TYPE(str1)->tp_name); + return -1; + } + if (!PyUnicode_Check(str2)) { + PyErr_Format(PyExc_TypeError, "second argument must be str, not %s", + Py_TYPE(str2)->tp_name); + return -1; + } + +#if PY_VERSION_HEX >= 0x030d0000 && !defined(PYPY_VERSION) + PyAPI_FUNC(int) _PyUnicode_Equal(PyObject *str1, PyObject *str2); + + return _PyUnicode_Equal(str1, str2); +#elif PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) + return _PyUnicode_EQ(str1, str2); +#elif PY_VERSION_HEX >= 0x03090000 && defined(PYPY_VERSION) + return _PyUnicode_EQ(str1, str2); +#else + return (PyUnicode_Compare(str1, str2) == 0); +#endif +} +#endif + + +// gh-121645 added PyBytes_Join() to Python 3.14.0a0 +#if PY_VERSION_HEX < 0x030E00A0 +static inline PyObject* PyBytes_Join(PyObject *sep, PyObject *iterable) +{ + return _PyBytes_Join(sep, iterable); +} +#endif + + +#if PY_VERSION_HEX < 0x030E00A0 +static inline Py_hash_t Py_HashBuffer(const void *ptr, Py_ssize_t len) +{ +#if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) + PyAPI_FUNC(Py_hash_t) _Py_HashBytes(const void *src, Py_ssize_t len); + + return _Py_HashBytes(ptr, len); +#else + Py_hash_t hash; + PyObject *bytes = PyBytes_FromStringAndSize((const char*)ptr, len); + if (bytes == NULL) { + return -1; + } + hash = PyObject_Hash(bytes); + 
Py_DECREF(bytes); + return hash; +#endif +} +#endif + + +#if PY_VERSION_HEX < 0x030E00A0 +static inline int PyIter_NextItem(PyObject *iter, PyObject **item) +{ + iternextfunc tp_iternext; + + assert(iter != NULL); + assert(item != NULL); + + tp_iternext = Py_TYPE(iter)->tp_iternext; + if (tp_iternext == NULL) { + *item = NULL; + PyErr_Format(PyExc_TypeError, "expected an iterator, got '%s'", + Py_TYPE(iter)->tp_name); + return -1; + } + + if ((*item = tp_iternext(iter))) { + return 1; + } + if (!PyErr_Occurred()) { + return 0; + } + if (PyErr_ExceptionMatches(PyExc_StopIteration)) { + PyErr_Clear(); + return 0; + } + return -1; +} +#endif + + +#if PY_VERSION_HEX < 0x030E00A0 +static inline PyObject* PyLong_FromInt32(int32_t value) +{ + Py_BUILD_ASSERT(sizeof(long) >= 4); + return PyLong_FromLong(value); +} + +static inline PyObject* PyLong_FromInt64(int64_t value) +{ + Py_BUILD_ASSERT(sizeof(long long) >= 8); + return PyLong_FromLongLong(value); +} + +static inline PyObject* PyLong_FromUInt32(uint32_t value) +{ + Py_BUILD_ASSERT(sizeof(unsigned long) >= 4); + return PyLong_FromUnsignedLong(value); +} + +static inline PyObject* PyLong_FromUInt64(uint64_t value) +{ + Py_BUILD_ASSERT(sizeof(unsigned long long) >= 8); + return PyLong_FromUnsignedLongLong(value); +} + +static inline int PyLong_AsInt32(PyObject *obj, int32_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(int) == 4); + int value = PyLong_AsInt(obj); + if (value == -1 && PyErr_Occurred()) { + return -1; + } + *pvalue = (int32_t)value; + return 0; +} + +static inline int PyLong_AsInt64(PyObject *obj, int64_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(long long) == 8); + long long value = PyLong_AsLongLong(obj); + if (value == -1 && PyErr_Occurred()) { + return -1; + } + *pvalue = (int64_t)value; + return 0; +} + +static inline int PyLong_AsUInt32(PyObject *obj, uint32_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(long) >= 4); + unsigned long value = PyLong_AsUnsignedLong(obj); + if (value == (unsigned long)-1 && PyErr_Occurred()) { + return -1; + } +#if SIZEOF_LONG > 4 + if ((unsigned long)UINT32_MAX < value) { + PyErr_SetString(PyExc_OverflowError, + "Python int too large to convert to C uint32_t"); + return -1; + } +#endif + *pvalue = (uint32_t)value; + return 0; +} + +static inline int PyLong_AsUInt64(PyObject *obj, uint64_t *pvalue) +{ + Py_BUILD_ASSERT(sizeof(long long) == 8); + unsigned long long value = PyLong_AsUnsignedLongLong(obj); + if (value == (unsigned long long)-1 && PyErr_Occurred()) { + return -1; + } + *pvalue = (uint64_t)value; + return 0; +} +#endif + + +// gh-102471 added import and export API for integers to 3.14.0a2. +#if PY_VERSION_HEX < 0x030E00A2 && PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) +// Helpers to access PyLongObject internals. +static inline void +_PyLong_SetSignAndDigitCount(PyLongObject *op, int sign, Py_ssize_t size) +{ +#if PY_VERSION_HEX >= 0x030C0000 + op->long_value.lv_tag = (uintptr_t)(1 - sign) | ((uintptr_t)(size) << 3); +#elif PY_VERSION_HEX >= 0x030900A4 + Py_SET_SIZE(op, sign * size); +#else + Py_SIZE(op) = sign * size; +#endif +} + +static inline Py_ssize_t +_PyLong_DigitCount(const PyLongObject *op) +{ +#if PY_VERSION_HEX >= 0x030C0000 + return (Py_ssize_t)(op->long_value.lv_tag >> 3); +#else + return _PyLong_Sign((PyObject*)op) < 0 ? 
-Py_SIZE(op) : Py_SIZE(op); +#endif +} + +static inline digit* +_PyLong_GetDigits(const PyLongObject *op) +{ +#if PY_VERSION_HEX >= 0x030C0000 + return (digit*)(op->long_value.ob_digit); +#else + return (digit*)(op->ob_digit); +#endif +} + +typedef struct PyLongLayout { + uint8_t bits_per_digit; + uint8_t digit_size; + int8_t digits_order; + int8_t digit_endianness; +} PyLongLayout; + +typedef struct PyLongExport { + int64_t value; + uint8_t negative; + Py_ssize_t ndigits; + const void *digits; + Py_uintptr_t _reserved; +} PyLongExport; + +typedef struct PyLongWriter PyLongWriter; + +static inline const PyLongLayout* +PyLong_GetNativeLayout(void) +{ + static const PyLongLayout PyLong_LAYOUT = { + PyLong_SHIFT, + sizeof(digit), + -1, // least significant first + PY_LITTLE_ENDIAN ? -1 : 1, + }; + + return &PyLong_LAYOUT; +} + +static inline int +PyLong_Export(PyObject *obj, PyLongExport *export_long) +{ + if (!PyLong_Check(obj)) { + memset(export_long, 0, sizeof(*export_long)); + PyErr_Format(PyExc_TypeError, "expected int, got %s", + Py_TYPE(obj)->tp_name); + return -1; + } + + // Fast-path: try to convert to a int64_t + PyLongObject *self = (PyLongObject*)obj; + int overflow; +#if SIZEOF_LONG == 8 + long value = PyLong_AsLongAndOverflow(obj, &overflow); +#else + // Windows has 32-bit long, so use 64-bit long long instead + long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); +#endif + Py_BUILD_ASSERT(sizeof(value) == sizeof(int64_t)); + // the function cannot fail since obj is a PyLongObject + assert(!(value == -1 && PyErr_Occurred())); + + if (!overflow) { + export_long->value = value; + export_long->negative = 0; + export_long->ndigits = 0; + export_long->digits = 0; + export_long->_reserved = 0; + } + else { + export_long->value = 0; + export_long->negative = _PyLong_Sign(obj) < 0; + export_long->ndigits = _PyLong_DigitCount(self); + if (export_long->ndigits == 0) { + export_long->ndigits = 1; + } + export_long->digits = _PyLong_GetDigits(self); + export_long->_reserved = (Py_uintptr_t)Py_NewRef(obj); + } + return 0; +} + +static inline void +PyLong_FreeExport(PyLongExport *export_long) +{ + PyObject *obj = (PyObject*)export_long->_reserved; + + if (obj) { + export_long->_reserved = 0; + Py_DECREF(obj); + } +} + +static inline PyLongWriter* +PyLongWriter_Create(int negative, Py_ssize_t ndigits, void **digits) +{ + if (ndigits <= 0) { + PyErr_SetString(PyExc_ValueError, "ndigits must be positive"); + return NULL; + } + assert(digits != NULL); + + PyLongObject *obj = _PyLong_New(ndigits); + if (obj == NULL) { + return NULL; + } + _PyLong_SetSignAndDigitCount(obj, negative?-1:1, ndigits); + + *digits = _PyLong_GetDigits(obj); + return (PyLongWriter*)obj; +} + +static inline void +PyLongWriter_Discard(PyLongWriter *writer) +{ + PyLongObject *obj = (PyLongObject *)writer; + + assert(Py_REFCNT(obj) == 1); + Py_DECREF(obj); +} + +static inline PyObject* +PyLongWriter_Finish(PyLongWriter *writer) +{ + PyObject *obj = (PyObject *)writer; + PyLongObject *self = (PyLongObject*)obj; + Py_ssize_t j = _PyLong_DigitCount(self); + Py_ssize_t i = j; + int sign = _PyLong_Sign(obj); + + assert(Py_REFCNT(obj) == 1); + + // Normalize and get singleton if possible + while (i > 0 && _PyLong_GetDigits(self)[i-1] == 0) { + --i; + } + if (i != j) { + if (i == 0) { + sign = 0; + } + _PyLong_SetSignAndDigitCount(self, sign, i); + } + if (i <= 1) { + long val = sign * (long)(_PyLong_GetDigits(self)[0]); + Py_DECREF(obj); + return PyLong_FromLong(val); + } + + return obj; +} +#endif + + +#if 
PY_VERSION_HEX < 0x030C00A3 +# define Py_T_SHORT 0 +# define Py_T_INT 1 +# define Py_T_LONG 2 +# define Py_T_FLOAT 3 +# define Py_T_DOUBLE 4 +# define Py_T_STRING 5 +# define _Py_T_OBJECT 6 +# define Py_T_CHAR 7 +# define Py_T_BYTE 8 +# define Py_T_UBYTE 9 +# define Py_T_USHORT 10 +# define Py_T_UINT 11 +# define Py_T_ULONG 12 +# define Py_T_STRING_INPLACE 13 +# define Py_T_BOOL 14 +# define Py_T_OBJECT_EX 16 +# define Py_T_LONGLONG 17 +# define Py_T_ULONGLONG 18 +# define Py_T_PYSSIZET 19 + +# if PY_VERSION_HEX >= 0x03000000 && !defined(PYPY_VERSION) +# define _Py_T_NONE 20 +# endif + +# define Py_READONLY 1 +# define Py_AUDIT_READ 2 +# define _Py_WRITE_RESTRICTED 4 +#endif + + +// gh-127350 added Py_fopen() and Py_fclose() to Python 3.14a4 +#if PY_VERSION_HEX < 0x030E00A4 +static inline FILE* Py_fopen(PyObject *path, const char *mode) +{ +#if 0x030400A2 <= PY_VERSION_HEX && !defined(PYPY_VERSION) + PyAPI_FUNC(FILE*) _Py_fopen_obj(PyObject *path, const char *mode); + + return _Py_fopen_obj(path, mode); +#else + FILE *f; + PyObject *bytes; +#if PY_VERSION_HEX >= 0x03000000 + if (!PyUnicode_FSConverter(path, &bytes)) { + return NULL; + } +#else + if (!PyString_Check(path)) { + PyErr_SetString(PyExc_TypeError, "except str"); + return NULL; + } + bytes = Py_NewRef(path); +#endif + const char *path_bytes = PyBytes_AS_STRING(bytes); + + f = fopen(path_bytes, mode); + Py_DECREF(bytes); + + if (f == NULL) { + PyErr_SetFromErrnoWithFilenameObject(PyExc_OSError, path); + return NULL; + } + return f; +#endif +} + +static inline int Py_fclose(FILE *file) +{ + return fclose(file); +} +#endif + + +#if 0x03080000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030E0000 && !defined(PYPY_VERSION) +PyAPI_FUNC(const PyConfig*) _Py_GetConfig(void); + +static inline PyObject* +PyConfig_Get(const char *name) +{ + typedef enum { + _PyConfig_MEMBER_INT, + _PyConfig_MEMBER_UINT, + _PyConfig_MEMBER_ULONG, + _PyConfig_MEMBER_BOOL, + _PyConfig_MEMBER_WSTR, + _PyConfig_MEMBER_WSTR_OPT, + _PyConfig_MEMBER_WSTR_LIST, + } PyConfigMemberType; + + typedef struct { + const char *name; + size_t offset; + PyConfigMemberType type; + const char *sys_attr; + } PyConfigSpec; + +#define PYTHONCAPI_COMPAT_SPEC(MEMBER, TYPE, sys_attr) \ + {#MEMBER, offsetof(PyConfig, MEMBER), \ + _PyConfig_MEMBER_##TYPE, sys_attr} + + static const PyConfigSpec config_spec[] = { + PYTHONCAPI_COMPAT_SPEC(argv, WSTR_LIST, "argv"), + PYTHONCAPI_COMPAT_SPEC(base_exec_prefix, WSTR_OPT, "base_exec_prefix"), + PYTHONCAPI_COMPAT_SPEC(base_executable, WSTR_OPT, "_base_executable"), + PYTHONCAPI_COMPAT_SPEC(base_prefix, WSTR_OPT, "base_prefix"), + PYTHONCAPI_COMPAT_SPEC(bytes_warning, UINT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(exec_prefix, WSTR_OPT, "exec_prefix"), + PYTHONCAPI_COMPAT_SPEC(executable, WSTR_OPT, "executable"), + PYTHONCAPI_COMPAT_SPEC(inspect, BOOL, _Py_NULL), +#if 0x030C0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(int_max_str_digits, UINT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(interactive, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(module_search_paths, WSTR_LIST, "path"), + PYTHONCAPI_COMPAT_SPEC(optimization_level, UINT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(parser_debug, BOOL, _Py_NULL), +#if 0x03090000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(platlibdir, WSTR, "platlibdir"), +#endif + PYTHONCAPI_COMPAT_SPEC(prefix, WSTR_OPT, "prefix"), + PYTHONCAPI_COMPAT_SPEC(pycache_prefix, WSTR_OPT, "pycache_prefix"), + PYTHONCAPI_COMPAT_SPEC(quiet, BOOL, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(stdlib_dir, WSTR_OPT, 
"_stdlib_dir"), +#endif + PYTHONCAPI_COMPAT_SPEC(use_environment, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(verbose, UINT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(warnoptions, WSTR_LIST, "warnoptions"), + PYTHONCAPI_COMPAT_SPEC(write_bytecode, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(xoptions, WSTR_LIST, "_xoptions"), + PYTHONCAPI_COMPAT_SPEC(buffered_stdio, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(check_hash_pycs_mode, WSTR, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(code_debug_ranges, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(configure_c_stdio, BOOL, _Py_NULL), +#if 0x030D0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(cpu_count, INT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(dev_mode, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(dump_refs, BOOL, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(dump_refs_file, WSTR_OPT, _Py_NULL), +#endif +#ifdef Py_GIL_DISABLED + PYTHONCAPI_COMPAT_SPEC(enable_gil, INT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(faulthandler, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(filesystem_encoding, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(filesystem_errors, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(hash_seed, ULONG, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(home, WSTR_OPT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(import_time, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(install_signal_handlers, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(isolated, BOOL, _Py_NULL), +#ifdef MS_WINDOWS + PYTHONCAPI_COMPAT_SPEC(legacy_windows_stdio, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(malloc_stats, BOOL, _Py_NULL), +#if 0x030A0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(orig_argv, WSTR_LIST, "orig_argv"), +#endif + PYTHONCAPI_COMPAT_SPEC(parse_argv, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(pathconfig_warnings, BOOL, _Py_NULL), +#if 0x030C0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(perf_profiling, UINT, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(program_name, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(run_command, WSTR_OPT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(run_filename, WSTR_OPT, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(run_module, WSTR_OPT, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(safe_path, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(show_ref_count, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(site_import, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(skip_source_first_line, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(stdio_encoding, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(stdio_errors, WSTR, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(tracemalloc, UINT, _Py_NULL), +#if 0x030B0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(use_frozen_modules, BOOL, _Py_NULL), +#endif + PYTHONCAPI_COMPAT_SPEC(use_hash_seed, BOOL, _Py_NULL), + PYTHONCAPI_COMPAT_SPEC(user_site_directory, BOOL, _Py_NULL), +#if 0x030A0000 <= PY_VERSION_HEX + PYTHONCAPI_COMPAT_SPEC(warn_default_encoding, BOOL, _Py_NULL), +#endif + }; + +#undef PYTHONCAPI_COMPAT_SPEC + + const PyConfigSpec *spec; + int found = 0; + for (size_t i=0; i < sizeof(config_spec) / sizeof(config_spec[0]); i++) { + spec = &config_spec[i]; + if (strcmp(spec->name, name) == 0) { + found = 1; + break; + } + } + if (found) { + if (spec->sys_attr != NULL) { + PyObject *value = PySys_GetObject(spec->sys_attr); + if (value == NULL) { + PyErr_Format(PyExc_RuntimeError, "lost sys.%s", spec->sys_attr); + return NULL; + } + return Py_NewRef(value); + } + + const PyConfig *config = _Py_GetConfig(); + void *member = (char *)config + spec->offset; + switch (spec->type) { + case 
_PyConfig_MEMBER_INT: + case _PyConfig_MEMBER_UINT: + { + int value = *(int *)member; + return PyLong_FromLong(value); + } + case _PyConfig_MEMBER_BOOL: + { + int value = *(int *)member; + return PyBool_FromLong(value != 0); + } + case _PyConfig_MEMBER_ULONG: + { + unsigned long value = *(unsigned long *)member; + return PyLong_FromUnsignedLong(value); + } + case _PyConfig_MEMBER_WSTR: + case _PyConfig_MEMBER_WSTR_OPT: + { + wchar_t *wstr = *(wchar_t **)member; + if (wstr != NULL) { + return PyUnicode_FromWideChar(wstr, -1); + } + else { + return Py_NewRef(Py_None); + } + } + case _PyConfig_MEMBER_WSTR_LIST: + { + const PyWideStringList *list = (const PyWideStringList *)member; + PyObject *tuple = PyTuple_New(list->length); + if (tuple == NULL) { + return NULL; + } + + for (Py_ssize_t i = 0; i < list->length; i++) { + PyObject *item = PyUnicode_FromWideChar(list->items[i], -1); + if (item == NULL) { + Py_DECREF(tuple); + return NULL; + } + PyTuple_SET_ITEM(tuple, i, item); + } + return tuple; + } + default: + Py_UNREACHABLE(); + } + } + + PyErr_Format(PyExc_ValueError, "unknown config option name: %s", name); + return NULL; +} + +static inline int +PyConfig_GetInt(const char *name, int *value) +{ + PyObject *obj = PyConfig_Get(name); + if (obj == NULL) { + return -1; + } + + if (!PyLong_Check(obj)) { + Py_DECREF(obj); + PyErr_Format(PyExc_TypeError, "config option %s is not an int", name); + return -1; + } + + int as_int = PyLong_AsInt(obj); + Py_DECREF(obj); + if (as_int == -1 && PyErr_Occurred()) { + PyErr_Format(PyExc_OverflowError, + "config option %s value does not fit into a C int", name); + return -1; + } + + *value = as_int; + return 0; +} +#endif // PY_VERSION_HEX > 0x03090000 && !defined(PYPY_VERSION) + +// gh-133144 added PyUnstable_Object_IsUniquelyReferenced() to Python 3.14.0b1. +// Adapted from _PyObject_IsUniquelyReferenced() implementation. +#if PY_VERSION_HEX < 0x030E00B0 +static inline int PyUnstable_Object_IsUniquelyReferenced(PyObject *obj) +{ +#if !defined(Py_GIL_DISABLED) + return Py_REFCNT(obj) == 1; +#else + // NOTE: the entire ob_ref_shared field must be zero, including flags, to + // ensure that other threads cannot concurrently create new references to + // this object. + return (_Py_IsOwnedByCurrentThread(obj) && + _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local) == 1 && + _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared) == 0); +#endif +} +#endif + +// gh-128926 added PyUnstable_TryIncRef() and PyUnstable_EnableTryIncRef() to +// Python 3.14.0a5. Adapted from _Py_TryIncref() and _PyObject_SetMaybeWeakref(). +#if PY_VERSION_HEX < 0x030E00A5 +static inline int PyUnstable_TryIncRef(PyObject *op) +{ +#ifndef Py_GIL_DISABLED + if (Py_REFCNT(op) > 0) { + Py_INCREF(op); + return 1; + } + return 0; +#else + // _Py_TryIncrefFast() + uint32_t local = _Py_atomic_load_uint32_relaxed(&op->ob_ref_local); + local += 1; + if (local == 0) { + // immortal + return 1; + } + if (_Py_IsOwnedByCurrentThread(op)) { + _Py_INCREF_STAT_INC(); + _Py_atomic_store_uint32_relaxed(&op->ob_ref_local, local); +#ifdef Py_REF_DEBUG + _Py_INCREF_IncRefTotal(); +#endif + return 1; + } + + // _Py_TryIncRefShared() + Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared); + for (;;) { + // If the shared refcount is zero and the object is either merged + // or may not have weak references, then we cannot incref it. 
+ if (shared == 0 || shared == _Py_REF_MERGED) { + return 0; + } + + if (_Py_atomic_compare_exchange_ssize( + &op->ob_ref_shared, + &shared, + shared + (1 << _Py_REF_SHARED_SHIFT))) { +#ifdef Py_REF_DEBUG + _Py_INCREF_IncRefTotal(); +#endif + _Py_INCREF_STAT_INC(); + return 1; + } + } +#endif +} + +static inline void PyUnstable_EnableTryIncRef(PyObject *op) +{ +#ifdef Py_GIL_DISABLED + // _PyObject_SetMaybeWeakref() + if (_Py_IsImmortal(op)) { + return; + } + for (;;) { + Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared); + if ((shared & _Py_REF_SHARED_FLAG_MASK) != 0) { + // Nothing to do if it's in WEAKREFS, QUEUED, or MERGED states. + return; + } + if (_Py_atomic_compare_exchange_ssize( + &op->ob_ref_shared, &shared, shared | _Py_REF_MAYBE_WEAKREF)) { + return; + } + } +#else + (void)op; // unused argument +#endif +} +#endif + + +#if PY_VERSION_HEX < 0x030F0000 +static inline PyObject* +PySys_GetAttrString(const char *name) +{ +#if PY_VERSION_HEX >= 0x03000000 + PyObject *value = Py_XNewRef(PySys_GetObject(name)); +#else + PyObject *value = Py_XNewRef(PySys_GetObject((char*)name)); +#endif + if (value != NULL) { + return value; + } + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, "lost sys.%s", name); + } + return NULL; +} + +static inline PyObject* +PySys_GetAttr(PyObject *name) +{ +#if PY_VERSION_HEX >= 0x03000000 + const char *name_str = PyUnicode_AsUTF8(name); +#else + const char *name_str = PyString_AsString(name); +#endif + if (name_str == NULL) { + return NULL; + } + + return PySys_GetAttrString(name_str); +} + +static inline int +PySys_GetOptionalAttrString(const char *name, PyObject **value) +{ +#if PY_VERSION_HEX >= 0x03000000 + *value = Py_XNewRef(PySys_GetObject(name)); +#else + *value = Py_XNewRef(PySys_GetObject((char*)name)); +#endif + if (*value != NULL) { + return 1; + } + return 0; +} + +static inline int +PySys_GetOptionalAttr(PyObject *name, PyObject **value) +{ +#if PY_VERSION_HEX >= 0x03000000 + const char *name_str = PyUnicode_AsUTF8(name); +#else + const char *name_str = PyString_AsString(name); +#endif + if (name_str == NULL) { + *value = NULL; + return -1; + } + + return PySys_GetOptionalAttrString(name_str, value); +} +#endif // PY_VERSION_HEX < 0x030F00A1 + + +#if PY_VERSION_HEX < 0x030F00A1 +typedef struct PyBytesWriter { + char small_buffer[256]; + PyObject *obj; + Py_ssize_t size; +} PyBytesWriter; + +static inline Py_ssize_t +_PyBytesWriter_GetAllocated(PyBytesWriter *writer) +{ + if (writer->obj == NULL) { + return sizeof(writer->small_buffer); + } + else { + return PyBytes_GET_SIZE(writer->obj); + } +} + + +static inline int +_PyBytesWriter_Resize_impl(PyBytesWriter *writer, Py_ssize_t size, + int resize) +{ + int overallocate = resize; + assert(size >= 0); + + if (size <= _PyBytesWriter_GetAllocated(writer)) { + return 0; + } + + if (overallocate) { +#ifdef MS_WINDOWS + /* On Windows, overallocate by 50% is the best factor */ + if (size <= (PY_SSIZE_T_MAX - size / 2)) { + size += size / 2; + } +#else + /* On Linux, overallocate by 25% is the best factor */ + if (size <= (PY_SSIZE_T_MAX - size / 4)) { + size += size / 4; + } +#endif + } + + if (writer->obj != NULL) { + if (_PyBytes_Resize(&writer->obj, size)) { + return -1; + } + assert(writer->obj != NULL); + } + else { + writer->obj = PyBytes_FromStringAndSize(NULL, size); + if (writer->obj == NULL) { + return -1; + } + + if (resize) { + assert((size_t)size > sizeof(writer->small_buffer)); + memcpy(PyBytes_AS_STRING(writer->obj), + writer->small_buffer, + 
sizeof(writer->small_buffer)); + } + } + return 0; +} + +static inline void* +PyBytesWriter_GetData(PyBytesWriter *writer) +{ + if (writer->obj == NULL) { + return writer->small_buffer; + } + else { + return PyBytes_AS_STRING(writer->obj); + } +} + +static inline Py_ssize_t +PyBytesWriter_GetSize(PyBytesWriter *writer) +{ + return writer->size; +} + +static inline void +PyBytesWriter_Discard(PyBytesWriter *writer) +{ + if (writer == NULL) { + return; + } + + Py_XDECREF(writer->obj); + PyMem_Free(writer); +} + +static inline PyBytesWriter* +PyBytesWriter_Create(Py_ssize_t size) +{ + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "size must be >= 0"); + return NULL; + } + + PyBytesWriter *writer = (PyBytesWriter*)PyMem_Malloc(sizeof(PyBytesWriter)); + if (writer == NULL) { + PyErr_NoMemory(); + return NULL; + } + + writer->obj = NULL; + writer->size = 0; + + if (size >= 1) { + if (_PyBytesWriter_Resize_impl(writer, size, 0) < 0) { + PyBytesWriter_Discard(writer); + return NULL; + } + writer->size = size; + } + return writer; +} + +static inline PyObject* +PyBytesWriter_FinishWithSize(PyBytesWriter *writer, Py_ssize_t size) +{ + PyObject *result; + if (size == 0) { + result = PyBytes_FromStringAndSize("", 0); + } + else if (writer->obj != NULL) { + if (size != PyBytes_GET_SIZE(writer->obj)) { + if (_PyBytes_Resize(&writer->obj, size)) { + goto error; + } + } + result = writer->obj; + writer->obj = NULL; + } + else { + result = PyBytes_FromStringAndSize(writer->small_buffer, size); + } + PyBytesWriter_Discard(writer); + return result; + +error: + PyBytesWriter_Discard(writer); + return NULL; +} + +static inline PyObject* +PyBytesWriter_Finish(PyBytesWriter *writer) +{ + return PyBytesWriter_FinishWithSize(writer, writer->size); +} + +static inline PyObject* +PyBytesWriter_FinishWithPointer(PyBytesWriter *writer, void *buf) +{ + Py_ssize_t size = (char*)buf - (char*)PyBytesWriter_GetData(writer); + if (size < 0 || size > _PyBytesWriter_GetAllocated(writer)) { + PyBytesWriter_Discard(writer); + PyErr_SetString(PyExc_ValueError, "invalid end pointer"); + return NULL; + } + + return PyBytesWriter_FinishWithSize(writer, size); +} + +static inline int +PyBytesWriter_Resize(PyBytesWriter *writer, Py_ssize_t size) +{ + if (size < 0) { + PyErr_SetString(PyExc_ValueError, "size must be >= 0"); + return -1; + } + if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { + return -1; + } + writer->size = size; + return 0; +} + +static inline int +PyBytesWriter_Grow(PyBytesWriter *writer, Py_ssize_t size) +{ + if (size < 0 && writer->size + size < 0) { + PyErr_SetString(PyExc_ValueError, "invalid size"); + return -1; + } + if (size > PY_SSIZE_T_MAX - writer->size) { + PyErr_NoMemory(); + return -1; + } + size = writer->size + size; + + if (_PyBytesWriter_Resize_impl(writer, size, 1) < 0) { + return -1; + } + writer->size = size; + return 0; +} + +static inline void* +PyBytesWriter_GrowAndUpdatePointer(PyBytesWriter *writer, + Py_ssize_t size, void *buf) +{ + Py_ssize_t pos = (char*)buf - (char*)PyBytesWriter_GetData(writer); + if (PyBytesWriter_Grow(writer, size) < 0) { + return NULL; + } + return (char*)PyBytesWriter_GetData(writer) + pos; +} + +static inline int +PyBytesWriter_WriteBytes(PyBytesWriter *writer, + const void *bytes, Py_ssize_t size) +{ + if (size < 0) { + size_t len = strlen((const char*)bytes); + if (len > (size_t)PY_SSIZE_T_MAX) { + PyErr_NoMemory(); + return -1; + } + size = (Py_ssize_t)len; + } + + Py_ssize_t pos = writer->size; + if (PyBytesWriter_Grow(writer, size) < 0) { + return 
-1; + } + char *buf = (char*)PyBytesWriter_GetData(writer); + memcpy(buf + pos, bytes, (size_t)size); + return 0; +} + +static inline int +PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) + Py_GCC_ATTRIBUTE((format(printf, 2, 3))); + +static inline int +PyBytesWriter_Format(PyBytesWriter *writer, const char *format, ...) +{ + va_list vargs; + va_start(vargs, format); + PyObject *str = PyBytes_FromFormatV(format, vargs); + va_end(vargs); + + if (str == NULL) { + return -1; + } + int res = PyBytesWriter_WriteBytes(writer, + PyBytes_AS_STRING(str), + PyBytes_GET_SIZE(str)); + Py_DECREF(str); + return res; +} +#endif // PY_VERSION_HEX < 0x030F00A1 + + +#if PY_VERSION_HEX < 0x030F00A1 +static inline PyObject* +PyTuple_FromArray(PyObject *const *array, Py_ssize_t size) +{ + PyObject *tuple = PyTuple_New(size); + if (tuple == NULL) { + return NULL; + } + for (Py_ssize_t i=0; i < size; i++) { + PyObject *item = array[i]; + PyTuple_SET_ITEM(tuple, i, Py_NewRef(item)); + } + return tuple; +} +#endif + + +#if PY_VERSION_HEX < 0x030F00A1 +static inline Py_hash_t +PyUnstable_Unicode_GET_CACHED_HASH(PyObject *op) +{ +#ifdef PYPY_VERSION + (void)op; // unused argument + return -1; +#elif PY_VERSION_HEX >= 0x03000000 + return ((PyASCIIObject*)op)->hash; +#else + return ((PyUnicodeObject*)op)->hash; +#endif +} +#endif + #ifdef __cplusplus } From 06aa3ef3d3cbd6bc680c357d09f4b4a6afaa2bed Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sun, 9 Nov 2025 02:50:18 +0000 Subject: [PATCH 263/651] Move types from typing_extensions to typing (#167185) This PR moves some implemented types from typing_extensions to typing due to the recent update to Python 3.10. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167185 Approved by: https://github.com/janeyx99 --- torch/_C/__init__.pyi.in | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 16d71cd0abb2e..559230350bcc9 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -3,23 +3,25 @@ # mypy: allow-untyped-defs # ruff: noqa: F401 -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from enum import Enum, IntEnum from pathlib import Path from types import EllipsisType from typing import ( Any, AnyStr, - Callable, Generic, IO, Literal, NamedTuple, overload, + Protocol, + runtime_checkable, SupportsIndex, + TypeAlias, TypeVar, ) -from typing_extensions import ParamSpec, Protocol, runtime_checkable, Self, TypeAlias +from typing_extensions import ParamSpec, Self import numpy From 9cf623a209c0a7bc3b88cb654c0d074952244545 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Sun, 9 Nov 2025 05:08:00 +0000 Subject: [PATCH 264/651] Update inductor-unittest.yml (#167417) i see failures like https://github.com/pytorch/pytorch/actions/runs/19189378182/job/54865171317?pr=167389 maybe this will fix it Pull Request resolved: https://github.com/pytorch/pytorch/pull/167417 Approved by: https://github.com/yf225 --- .github/workflows/inductor-unittest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index f55267caba93f..af9829c96f506 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -93,7 +93,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor-pallas", shard: 1, 
num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" }, ]} secrets: inherit From e7c1905837663d358e6299c09090455f4060b0ed Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Sun, 9 Nov 2025 05:38:09 +0000 Subject: [PATCH 265/651] Fix test_fsdp_logging (#167312) - The logger name in test_fully_shard_logging.py was wrong so the logs didn't happen. - The `device` variable in test_fully_shard_logging is expected to be a string, so quote it - `unittest.skipIf` is used so importing `unittest` instead of `unittest.mock` is required Pull Request resolved: https://github.com/pytorch/pytorch/pull/167312 Approved by: https://github.com/Skylion007, https://github.com/cyyever --- .../_composable/fsdp/test_fully_shard_logging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_logging.py b/test/distributed/_composable/fsdp/test_fully_shard_logging.py index c9450a2b8f475..9b666eb55ba08 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_logging.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_logging.py @@ -1,7 +1,7 @@ # Owner(s): ["module: fsdp"] import functools import os -import unittest.mock +import unittest import torch.distributed as dist from torch._dynamo.test_case import run_tests @@ -37,9 +37,9 @@ def test_fsdp_logging(self): import torch.distributed as dist import torch.nn as nn from torch.distributed.fsdp import fully_shard -logger = logging.getLogger("torch.distributed._composable.fsdp") +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") logger.setLevel(logging.DEBUG) -device = {device_type.type} +device = '{device_type.type}' torch.manual_seed(0) model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]) for layer in model: From 5135ace3a3ca836338201a08404ac36c96b58b7c Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sun, 9 Nov 2025 06:40:00 +0000 Subject: [PATCH 266/651] Enable ruff UP035 rule (#167307) This PR enables `UP035` rule of ruff. 
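For context, `UP035` reports deprecated `typing` imports (e.g. `typing.Callable`, `typing.List`) and suggests `collections.abc` or the builtin generics instead. An illustrative before/after, not taken from this diff:

    # Before (reported by UP035):
    #   from typing import Callable, List
    # After:
    from collections.abc import Callable

    def apply(fn: Callable[[int], int], xs: list[int]) -> list[int]:
        # list[int] replaces typing.List[int]; Callable now comes from collections.abc
        return [fn(x) for x in xs]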
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167307 Approved by: https://github.com/Lucaskabela --- .ci/lumen_cli/cli/lib/common/cli_helper.py | 6 ++++-- .github/scripts/delete_old_branches.py | 3 ++- .github/scripts/filter_test_configs.py | 3 ++- .github/scripts/get_workflow_job_id.py | 3 ++- .github/scripts/github_utils.py | 3 ++- .github/scripts/gitutils.py | 4 ++-- .github/scripts/trymerge.py | 4 ++-- pyproject.toml | 1 - test/inductor/test_mem_estimation.py | 5 +++-- torch/_C/_VariableFunctions.pyi.in | 4 ++-- torch/_dynamo/variables/ctx_manager.py | 6 +++--- torch/_dynamo/variables/lists.py | 3 ++- torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py | 2 +- torch/jit/_shape_functions.py | 3 ++- torch/nn/functional.pyi.in | 5 ++--- torch/onnx/_internal/exporter/_torchlib/ops/nn.py | 2 +- torch/testing/_internal/fake_config_module3.py | 2 +- torch/utils/data/datapipes/datapipe.pyi.in | 4 ++-- 18 files changed, 35 insertions(+), 28 deletions(-) diff --git a/.ci/lumen_cli/cli/lib/common/cli_helper.py b/.ci/lumen_cli/cli/lib/common/cli_helper.py index 927ca09fe7230..4086eb7d46e81 100644 --- a/.ci/lumen_cli/cli/lib/common/cli_helper.py +++ b/.ci/lumen_cli/cli/lib/common/cli_helper.py @@ -8,9 +8,11 @@ try: - from typing import Any, Callable, Required, TypedDict # Python 3.11+ + from collections.abc import Callable # Python 3.11+ + from typing import Any, Required, TypedDict except ImportError: - from typing import Any, Callable, TypedDict + from collections.abc import Callable + from typing import Any, TypedDict from typing_extensions import Required # Fallback for Python <3.11 diff --git a/.github/scripts/delete_old_branches.py b/.github/scripts/delete_old_branches.py index 8032008edf122..42cd851f8e338 100644 --- a/.github/scripts/delete_old_branches.py +++ b/.github/scripts/delete_old_branches.py @@ -1,10 +1,11 @@ # Delete old branches import os import re +from collections.abc import Callable from datetime import datetime from functools import lru_cache from pathlib import Path -from typing import Any, Callable +from typing import Any from github_utils import gh_fetch_json_dict, gh_graphql from gitutils import GitRepo diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index 592c7aab6d933..ee102d3f560f9 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -8,10 +8,11 @@ import subprocess import sys import warnings +from collections.abc import Callable from enum import Enum from functools import cache from logging import info -from typing import Any, Callable, Optional +from typing import Any, Optional from urllib.request import Request, urlopen import yaml diff --git a/.github/scripts/get_workflow_job_id.py b/.github/scripts/get_workflow_job_id.py index b04cbed76e955..54e66621c9fd0 100644 --- a/.github/scripts/get_workflow_job_id.py +++ b/.github/scripts/get_workflow_job_id.py @@ -11,7 +11,8 @@ import time import urllib import urllib.parse -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional from urllib.request import Request, urlopen diff --git a/.github/scripts/github_utils.py b/.github/scripts/github_utils.py index 110015988a5c3..6479fb64ddbaf 100644 --- a/.github/scripts/github_utils.py +++ b/.github/scripts/github_utils.py @@ -3,8 +3,9 @@ import json import os import warnings +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, cast, Optional, Union +from typing 
import Any, cast, Optional, Union from urllib.error import HTTPError from urllib.parse import quote from urllib.request import Request, urlopen diff --git a/.github/scripts/gitutils.py b/.github/scripts/gitutils.py index 3a90ddb5f4c6b..6e3bb3f209177 100644 --- a/.github/scripts/gitutils.py +++ b/.github/scripts/gitutils.py @@ -4,10 +4,10 @@ import re import tempfile from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Callable, Iterator from datetime import datetime from functools import wraps -from typing import Any, Callable, cast, Optional, TypeVar, Union +from typing import Any, cast, Optional, TypeVar, Union T = TypeVar("T") diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index c258284a00d83..697ab6992793d 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -17,12 +17,12 @@ import time import urllib.parse from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass from functools import cache from pathlib import Path from re import Pattern -from typing import Any, Callable, cast, NamedTuple, Optional +from typing import Any, cast, NamedTuple, Optional from warnings import warn import yaml diff --git a/pyproject.toml b/pyproject.toml index 21a1f2ec1e3e8..b01ba623cc814 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -184,7 +184,6 @@ ignore = [ "TC006", # TODO: Remove Python-3.10 specific suppressions "B905", - "UP035", ] select = [ "B", diff --git a/test/inductor/test_mem_estimation.py b/test/inductor/test_mem_estimation.py index 4b49982c6377d..2f0ccfe6b284d 100644 --- a/test/inductor/test_mem_estimation.py +++ b/test/inductor/test_mem_estimation.py @@ -3,7 +3,8 @@ import functools import weakref from collections import Counter -from typing import Callable, Optional +from collections.abc import Callable +from typing import Optional import torch from torch._inductor.fx_passes.memory_estimator import ( @@ -28,7 +29,7 @@ def device_filter(device): class FakeTensorMemoryProfilerMode(TorchDispatchMode): - def __init__(self, device_filter: Optional[Callable[torch.device, bool]] = None): + def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None): # counter of storage ids to live references self.storage_count: dict[int, int] = Counter() # live fake tensors diff --git a/torch/_C/_VariableFunctions.pyi.in b/torch/_C/_VariableFunctions.pyi.in index 374f5661060e0..2a633b401ca79 100644 --- a/torch/_C/_VariableFunctions.pyi.in +++ b/torch/_C/_VariableFunctions.pyi.in @@ -3,9 +3,9 @@ # mypy: allow-untyped-defs # ruff: noqa: F401,PYI054 -from collections.abc import Sequence +from collections.abc import Callable, Sequence from types import EllipsisType -from typing import Any, Callable, Literal, overload, TypeVar +from typing import Any, Literal, overload, TypeVar import torch from torch import ( diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 81bb0777b5555..c79f19216f68b 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -22,8 +22,8 @@ import sys import warnings from collections.abc import Callable, Sequence, Sized -from contextlib import ExitStack -from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union +from contextlib import AbstractContextManager, ExitStack +from typing import Any, Optional, TYPE_CHECKING, Union import torch._C from torch._guards import Guard @@ 
-163,7 +163,7 @@ def cleanup_assert(self) -> None: class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are # python constants. Which might not always be the case here. - def __init__(self, cm_obj: ContextManager[Any], **kwargs: Any) -> None: + def __init__(self, cm_obj: AbstractContextManager[Any], **kwargs: Any) -> None: assert cm_obj is not None super().__init__( value=cm_obj, diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 3c525312198c8..2ac355bd53417 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -18,7 +18,8 @@ class that handles its unique behaviors while integrating with Dynamo's import inspect import operator import sys -from typing import Any, Optional, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, Optional, TYPE_CHECKING import torch import torch.fx diff --git a/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py b/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py index c215790770420..8fe46cf75d8c6 100644 --- a/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py +++ b/torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py @@ -2,7 +2,7 @@ # fmt: off # This file was generated by AutoHeuristic. Do not modify it manually! # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/mixed_mm/ -from typing import List, Optional, Tuple +from typing import Optional from torch._inductor.autoheuristic.autoheuristic_utils import ( AHContext, diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index f2a6f4a841763..1f95de46f6f2a 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from collections.abc import Callable +from typing import Any, Optional, Union number = Union[int, float] diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in index d0b64447e900b..f902268925138 100644 --- a/torch/nn/functional.pyi.in +++ b/torch/nn/functional.pyi.in @@ -1,9 +1,8 @@ # ${generated_comment} # mypy: allow-untyped-defs -from collections.abc import Sequence -from typing import Any, Callable, Literal, overload -from typing_extensions import TypeAlias +from collections.abc import Callable, Sequence +from typing import Any, Literal, overload, TypeAlias from torch import Tensor from torch.types import _dtype, _int, _size diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py index 31f87046315b6..700ec6ae42543 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Optional, Sequence, TYPE_CHECKING +from typing import Optional, Sequence, TYPE_CHECKING # noqa: UP035 from onnxscript.onnx_opset import ( # type: ignore[attr-defined] opset20 as op20, diff --git a/torch/testing/_internal/fake_config_module3.py b/torch/testing/_internal/fake_config_module3.py index 1d3d7f15d901a..ff4118438e74c 100644 --- a/torch/testing/_internal/fake_config_module3.py +++ b/torch/testing/_internal/fake_config_module3.py @@ -1,5 +1,5 @@ import sys -from typing import Callable, Optional +from typing import Callable, Optional # noqa: UP035 from torch.utils._config_module import install_config_module diff 
--git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index 73cfa120e4944..a7b7bac21f50d 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -5,8 +5,8 @@ # Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other # classes/objects here, even though we are not injecting extra code into them at the moment. -from collections.abc import Iterable, Iterator -from typing import Any, Callable, Literal, Optional, TypeVar, Union +from collections.abc import Callable, Iterable, Iterator +from typing import Any, Literal, Optional, TypeVar, Union from torch.utils.data import Dataset, default_collate, IterableDataset from torch.utils.data.datapipes._hook_iterator import _SnapshotState From 14a845a4ec67cde1ae5467653ad3b8db3b8bde60 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sun, 9 Nov 2025 12:11:45 +0000 Subject: [PATCH 267/651] [2/N] Use Python 3.10 typing (#167167) This PR applies new `Union` and `Optional` typing syntax to some files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167167 Approved by: https://github.com/XuehaiPan, https://github.com/mlazos --- torch/utils/_config_module.py | 70 ++++----- torch/utils/_content_store.py | 9 +- torch/utils/_cpp_embed_headers.py | 5 +- torch/utils/_cxx_pytree.py | 106 ++++++------- torch/utils/_debug_mode.py | 40 +++-- torch/utils/_device.py | 3 +- torch/utils/_filelock.py | 7 +- torch/utils/_foreach_utils.py | 4 +- torch/utils/_import_utils.py | 3 +- torch/utils/_ordered_set.py | 4 +- torch/utils/_python_dispatch.py | 30 ++-- torch/utils/_pytree.py | 139 +++++++++--------- .../_strobelight/cli_function_profiler.py | 22 +-- torch/utils/_sympy/functions.py | 40 +++-- torch/utils/_sympy/interp.py | 4 +- torch/utils/_sympy/printers.py | 7 +- torch/utils/_sympy/reference.py | 4 +- torch/utils/_sympy/solve.py | 5 +- torch/utils/_sympy/symbol.py | 5 +- torch/utils/_sympy/value_ranges.py | 33 ++--- torch/utils/_thunk.py | 6 +- torch/utils/_traceback.py | 3 +- torch/utils/_typing_utils.py | 4 +- torch/utils/backend_registration.py | 11 +- .../benchmark/op_fuzzers/sparse_unary.py | 4 +- torch/utils/benchmark/utils/common.py | 14 +- torch/utils/benchmark/utils/compare.py | 13 +- torch/utils/benchmark/utils/compile.py | 32 ++-- torch/utils/benchmark/utils/cpp_jit.py | 8 +- torch/utils/benchmark/utils/fuzzer.py | 36 ++--- torch/utils/benchmark/utils/sparse_fuzzer.py | 17 +-- torch/utils/benchmark/utils/timer.py | 28 ++-- .../utils/valgrind_wrapper/timer_interface.py | 30 ++-- torch/utils/bundled_inputs.py | 22 +-- torch/utils/cpp_extension.py | 37 +++-- torch/utils/data/_utils/collate.py | 17 +-- torch/utils/data/_utils/worker.py | 8 +- torch/utils/data/dataloader.py | 24 +-- torch/utils/data/datapipes/_decorator.py | 6 +- .../datapipes/dataframe/dataframe_wrapper.py | 4 +- .../data/datapipes/dataframe/dataframes.py | 4 +- torch/utils/data/datapipes/datapipe.py | 24 +-- torch/utils/data/datapipes/datapipe.pyi.in | 12 +- torch/utils/data/datapipes/gen_pyi.py | 4 +- torch/utils/data/datapipes/iter/callable.py | 10 +- .../data/datapipes/iter/combinatorics.py | 8 +- torch/utils/data/datapipes/iter/combining.py | 18 +-- torch/utils/data/datapipes/iter/filelister.py | 7 +- torch/utils/data/datapipes/iter/fileopener.py | 5 +- torch/utils/data/datapipes/iter/grouping.py | 6 +- .../utils/data/datapipes/iter/streamreader.py | 3 +- .../utils/data/datapipes/map/combinatorics.py | 6 +- 
torch/utils/data/datapipes/map/utils.py | 6 +- torch/utils/data/datapipes/utils/common.py | 12 +- torch/utils/data/dataset.py | 8 +- torch/utils/data/distributed.py | 6 +- torch/utils/data/graph.py | 4 +- torch/utils/data/graph_settings.py | 6 +- torch/utils/data/sampler.py | 6 +- torch/utils/data/typing.ipynb | 4 +- torch/utils/dlpack.py | 6 +- torch/utils/flop_counter.py | 12 +- torch/utils/hipify/hipify_python.py | 3 +- torch/utils/mobile_optimizer.py | 6 +- torch/utils/serialization/config.py | 2 +- torch/utils/tensorboard/_proto_graph.py | 11 +- torch/utils/tensorboard/summary.py | 4 +- torch/utils/tensorboard/writer.py | 6 +- torch/utils/viz/_cycles.py | 4 +- 69 files changed, 523 insertions(+), 554 deletions(-) diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index ca298219560e8..16fbad73a3097 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import FunctionType, ModuleType -from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar from typing_extensions import deprecated from unittest import mock @@ -23,7 +23,7 @@ # Duplicated, because mypy needs these types statically -T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict]) +T = TypeVar("T", bound=int | float | bool | None | str | list | set | tuple | dict) _UNSET_SENTINEL = object() @@ -69,12 +69,12 @@ class _Config(Generic[T]): default behaviour. I.e. user overrides take preference. """ - default: Union[T, object] - justknob: Optional[str] = None - env_name_default: Optional[list[str]] = None - env_name_force: Optional[list[str]] = None - value_type: Optional[type] = None - alias: Optional[str] = None + default: T | object + justknob: str | None = None + env_name_default: list[str] | None = None + env_name_force: list[str] | None = None + value_type: type | None = None + alias: str | None = None def __post_init__(self) -> None: self.env_name_default = _Config.string_or_list_of_string_to_list( @@ -98,8 +98,8 @@ def __post_init__(self) -> None: @staticmethod def string_or_list_of_string_to_list( - val: Optional[Union[str, list[str]]], - ) -> Optional[list[str]]: + val: str | list[str] | None, + ) -> list[str] | None: if val is None: return None if isinstance(val, str): @@ -116,23 +116,23 @@ def string_or_list_of_string_to_list( if TYPE_CHECKING: def Config( - default: Union[T, object] = _UNSET_SENTINEL, - justknob: Optional[str] = None, - env_name_default: Optional[Union[str, list[str]]] = None, - env_name_force: Optional[Union[str, list[str]]] = None, - value_type: Optional[type] = None, - alias: Optional[str] = None, + default: T | object = _UNSET_SENTINEL, + justknob: str | None = None, + env_name_default: str | list[str] | None = None, + env_name_force: str | list[str] | None = None, + value_type: type | None = None, + alias: str | None = None, ) -> T: ... 
else: def Config( - default: Union[T, object] = _UNSET_SENTINEL, - justknob: Optional[str] = None, - env_name_default: Optional[Union[str, list[str]]] = None, - env_name_force: Optional[Union[str, list[str]]] = None, - value_type: Optional[type] = None, - alias: Optional[str] = None, + default: T | object = _UNSET_SENTINEL, + justknob: str | None = None, + env_name_default: str | list[str] | None = None, + env_name_force: str | list[str] | None = None, + value_type: type | None = None, + alias: str | None = None, ) -> _Config[T]: return _Config( default=default, @@ -144,7 +144,7 @@ def Config( ) -def _read_env_variable(name: str) -> Optional[Union[bool, str]]: +def _read_env_variable(name: str) -> bool | str | None: value = os.environ.get(name) if value == "1": return True @@ -165,8 +165,8 @@ class ConfigModuleInstance(ConfigModule): _bypass_keys = set({"_is_dirty", "_hash_digest", "__annotations__"}) def visit( - source: Union[ModuleType, type], - dest: Union[ModuleType, SubConfigProxy], + source: ModuleType | type, + dest: ModuleType | SubConfigProxy, prefix: str, ) -> None: """Walk the module structure and move everything to module._config""" @@ -281,7 +281,7 @@ class _ConfigEntry: # _UNSET_SENTINEL indicates the value is not set. user_override: Any = _UNSET_SENTINEL # The justknob to check for this config - justknob: Optional[str] = None + justknob: str | None = None # environment variables are read at install time env_value_force: Any = _UNSET_SENTINEL env_value_default: Any = _UNSET_SENTINEL @@ -297,7 +297,7 @@ class _ConfigEntry: # call so the final state is correct. It's just very unintuitive. # upstream bug - python/cpython#126886 hide: bool = False - alias: Optional[str] = None + alias: str | None = None def __init__(self, config: _Config) -> None: self.default = config.default @@ -347,7 +347,7 @@ class ConfigModule(ModuleType): _bypass_keys: set[str] _compile_ignored_keys: set[str] _is_dirty: bool - _hash_digest: Optional[bytes] + _hash_digest: bytes | None def __init__(self) -> None: raise NotImplementedError( @@ -411,7 +411,7 @@ def __delattr__(self, name: str) -> None: def _get_alias_module_and_name( self, entry: _ConfigEntry - ) -> Optional[tuple[ModuleType, str]]: + ) -> tuple[ModuleType, str] | None: alias = entry.alias if alias is None: return None @@ -465,8 +465,8 @@ def _is_default(self, name: str) -> bool: def _get_dict( self, - ignored_keys: Optional[list[str]] = None, - ignored_prefixes: Optional[list[str]] = None, + ignored_keys: list[str] | None = None, + ignored_prefixes: list[str] | None = None, skip_default: bool = False, ) -> dict[str, Any]: """Export a dictionary of current configuration keys and values. 
@@ -542,7 +542,7 @@ def add_import(func: Callable) -> None: if module_name: imports.add(module_name) - def list_of_callables_to_string(v: Union[list, set]) -> list[str]: + def list_of_callables_to_string(v: list | set) -> list[str]: return [f"{get_module_name(item, True)}{item.__name__}" for item in v] def importable_callable(v: Any) -> bool: @@ -615,7 +615,7 @@ def to_dict(self) -> dict[str, Any]: def shallow_copy_dict(self) -> dict[str, Any]: return self.get_config_copy() - def load_config(self, maybe_pickled_config: Union[bytes, dict[str, Any]]) -> None: + def load_config(self, maybe_pickled_config: bytes | dict[str, Any]) -> None: """Restore from a prior call to save_config() or shallow_copy_dict()""" if not isinstance(maybe_pickled_config, dict): config = pickle.loads(maybe_pickled_config) @@ -637,7 +637,7 @@ def get_serializable_config_copy(self) -> dict[str, Any]: def patch( self, - arg1: Optional[Union[str, dict[str, Any]]] = None, + arg1: str | dict[str, Any] | None = None, arg2: Any = None, **kwargs: dict[str, Any], ) -> "ContextDecorator": @@ -816,7 +816,7 @@ def patch_object(obj: object, name: str, value: object) -> object: return mock.patch.object(obj, name, value) -def get_tristate_env(name: str, default: Any = None) -> Optional[bool]: +def get_tristate_env(name: str, default: Any = None) -> bool | None: value = os.environ.get(name) if value == "1": return True diff --git a/torch/utils/_content_store.py b/torch/utils/_content_store.py index 0086a1e874ddf..234355210057a 100644 --- a/torch/utils/_content_store.py +++ b/torch/utils/_content_store.py @@ -34,7 +34,6 @@ import os.path import struct from collections import defaultdict -from typing import Optional import torch import torch._prims as prims @@ -193,9 +192,9 @@ def write_tensor(self, name: str, t: torch.Tensor) -> None: class ContentStoreReader: def __init__(self, loc: str, *, cache=True) -> None: self.loc = loc - self.storage_cache: Optional[ - dict[Optional[torch.device], dict[str, StorageWeakRef]] - ] = None + self.storage_cache: ( + dict[torch.device | None, dict[str, StorageWeakRef]] | None + ) = None if cache: self.storage_cache = defaultdict(dict) @@ -207,7 +206,7 @@ def read_storage(self, h: str, *, device=None) -> torch.UntypedStorage: if self.storage_cache is not None else None ) - s: Optional[torch.UntypedStorage] + s: torch.UntypedStorage | None if ws is not None: s = torch.UntypedStorage._new_with_weak_ptr(ws.cdata) if s is not None: diff --git a/torch/utils/_cpp_embed_headers.py b/torch/utils/_cpp_embed_headers.py index 1d1577b0d8cb5..88ab41aadffe7 100644 --- a/torch/utils/_cpp_embed_headers.py +++ b/torch/utils/_cpp_embed_headers.py @@ -1,10 +1,9 @@ from collections.abc import Sequence from pathlib import Path from re import match as _match -from typing import Optional, Union -def read_file(fname: Union[Path, str]) -> list[str]: +def read_file(fname: Path | str) -> list[str]: with open(fname, encoding="utf-8") as f: return f.readlines() @@ -36,7 +35,7 @@ def _embed_headers( def embed_headers( - fname: str, include_dirs: Optional[Union[Sequence[str], Sequence[Path], str]] = None + fname: str, include_dirs: Sequence[str] | Sequence[Path] | str | None = None ) -> str: if include_dirs is None: base_dir = Path(__file__).parent.parent.parent diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 897279bd39b1e..f9350124d135a 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,7 +15,7 @@ import functools import types from collections.abc import Callable, Iterable, 
Mapping -from typing import Any, Optional, overload, TypeAlias, TypeVar, Union +from typing import Any, overload, TypeAlias, TypeVar, Union from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree @@ -128,10 +128,10 @@ def register_pytree_node( flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, - flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, + serialized_type_name: str | None = None, + to_dumpable_context: ToDumpableContextFn | None = None, + from_dumpable_context: FromDumpableContextFn | None = None, + flatten_with_keys_fn: FlattenWithKeysFunc | None = None, ) -> None: """Register a container-like type as pytree node. @@ -196,9 +196,9 @@ def _register_pytree_node( flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + serialized_type_name: str | None = None, + to_dumpable_context: ToDumpableContextFn | None = None, + from_dumpable_context: FromDumpableContextFn | None = None, ) -> None: """Register a container-like type as pytree node for the C++ pytree only. @@ -247,9 +247,9 @@ def _private_register_pytree_node( flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, + serialized_type_name: str | None = None, + to_dumpable_context: ToDumpableContextFn | None = None, + from_dumpable_context: FromDumpableContextFn | None = None, ) -> None: """This is an internal function that is used to register a pytree node type for the C++ pytree only. End-users should use :func:`register_pytree_node` @@ -281,7 +281,7 @@ def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec: def treespec_dict( - mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (), + mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (), /, **kwargs: TreeSpec, ) -> TreeSpec: @@ -296,7 +296,7 @@ def treespec_dict( def tree_is_leaf( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: """Check if a pytree is a leaf. @@ -334,7 +334,7 @@ def tree_is_leaf( def tree_flatten( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> tuple[list[Any], TreeSpec]: """Flatten a pytree. @@ -399,7 +399,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: def tree_iter( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> Iterable[Any]: """Get an iterator over the leaves of a pytree. @@ -434,7 +434,7 @@ def tree_iter( def tree_leaves( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[Any]: """Get the leaves of a pytree. @@ -469,7 +469,7 @@ def tree_leaves( def tree_structure( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> TreeSpec: """Get the treespec for a pytree. 
@@ -506,7 +506,7 @@ def tree_map( func: Callable[..., Any], tree: PyTree, *rests: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: """Map a multi-input function over pytree args to produce a new pytree. @@ -555,7 +555,7 @@ def tree_map_( func: Callable[..., Any], tree: PyTree, *rests: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. @@ -593,8 +593,8 @@ def tree_map_( Type3 = tuple[type[T], type[S], type[U]] TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] -Fn2 = Callable[[Union[T, S]], R] -Fn3 = Callable[[Union[T, S, U]], R] +Fn2 = Callable[[T | S], R] +Fn3 = Callable[[T | S | U], R] Fn = Callable[[T], R] FnAny = Callable[[Any], R] @@ -629,7 +629,7 @@ def map_only( def map_only( - type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], / + type_or_types_or_pred: TypeAny | Callable[[Any], bool], / ) -> MapOnlyFn[FnAny[Any]]: """ Suppose you are writing a tree_map over tensors, leaving everything @@ -677,7 +677,7 @@ def tree_map_only( /, func: Fn[T, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -687,7 +687,7 @@ def tree_map_only( /, func: Fn2[T, S, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -697,7 +697,7 @@ def tree_map_only( /, func: Fn3[T, S, U, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -707,7 +707,7 @@ def tree_map_only( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -717,16 +717,16 @@ def tree_map_only( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... def tree_map_only( - type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + type_or_types_or_pred: TypeAny | Callable[[Any], bool], /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -737,7 +737,7 @@ def tree_map_only_( /, func: Fn[T, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -747,7 +747,7 @@ def tree_map_only_( /, func: Fn2[T, S, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -757,7 +757,7 @@ def tree_map_only_( /, func: Fn3[T, S, U, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -767,7 +767,7 @@ def tree_map_only_( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -777,16 +777,16 @@ def tree_map_only_( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... 
def tree_map_only_( - type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + type_or_types_or_pred: TypeAny | Callable[[Any], bool], /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -794,7 +794,7 @@ def tree_map_only_( def tree_all( pred: Callable[[Any], bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return all(map(pred, flat_args)) @@ -803,7 +803,7 @@ def tree_all( def tree_any( pred: Callable[[Any], bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return any(map(pred, flat_args)) @@ -815,7 +815,7 @@ def tree_all_only( /, pred: Fn[T, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -825,7 +825,7 @@ def tree_all_only( /, pred: Fn2[T, S, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -835,7 +835,7 @@ def tree_all_only( /, pred: Fn3[T, S, U, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -844,7 +844,7 @@ def tree_all_only( /, pred: FnAny[bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return all(pred(x) for x in flat_args if isinstance(x, type_or_types)) @@ -856,7 +856,7 @@ def tree_any_only( /, pred: Fn[T, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -866,7 +866,7 @@ def tree_any_only( /, pred: Fn2[T, S, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -876,7 +876,7 @@ def tree_any_only( /, pred: Fn3[T, S, U, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -885,7 +885,7 @@ def tree_any_only( /, pred: FnAny[bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return any(pred(x) for x in flat_args if isinstance(x, type_or_types)) @@ -894,7 +894,7 @@ def tree_any_only( def broadcast_prefix( prefix_tree: PyTree, full_tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[Any]: """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``. 
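For reference, a rough usage sketch of the `tree_map_only`/`tree_all_only` signatures retyped above (illustrative only; the Python `torch.utils._pytree` module is used here, and the sample tree is made up):

    import torch
    import torch.utils._pytree as pytree

    tree = {"w": torch.ones(2), "tag": "keep", "b": torch.zeros(3)}
    # Only Tensor leaves are transformed; the string leaf passes through untouched.
    doubled = pytree.tree_map_only(torch.Tensor, lambda t: t * 2, tree)
    assert doubled["tag"] == "keep"
    assert pytree.tree_all_only(torch.Tensor, lambda t: bool((t >= 0).all()), doubled)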
@@ -956,8 +956,8 @@ def add_leaves(x: Any, subtree: PyTree) -> None: def _broadcast_to_and_flatten( tree: PyTree, treespec: TreeSpec, - is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> Optional[list[Any]]: + is_leaf: Callable[[PyTree], bool] | None = None, +) -> list[Any] | None: if not _is_pytreespec_instance(treespec): raise AssertionError( f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}" @@ -969,7 +969,7 @@ def _broadcast_to_and_flatten( return None -def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: +def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str: """Serialize a treespec to a JSON string.""" if not _is_pytreespec_instance(treespec): raise TypeError( @@ -1024,7 +1024,7 @@ def __new__(cls) -> Self: def tree_flatten_with_path( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]: """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. @@ -1047,7 +1047,7 @@ def tree_flatten_with_path( def tree_leaves_with_path( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[tuple[KeyPath, Any]]: """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. @@ -1070,7 +1070,7 @@ def tree_map_with_path( func: Callable[..., Any], tree: PyTree, *rests: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: """Like :func:`tree_map`, but the provided callable takes an additional key path argument. diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5c8bc9221a957..276f1e5631d23 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -4,7 +4,7 @@ import traceback import weakref from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -140,7 +140,7 @@ def _get_stack_trace() -> str: return "".join(summary.format()) -def _maybe_get_autograd_trace() -> Optional[str]: +def _maybe_get_autograd_trace() -> str | None: if torch._C._current_autograd_node() is not None: tb = torch._C._current_autograd_node().metadata.get("traceback_") # type: ignore[attr-defined] if tb: @@ -154,8 +154,8 @@ class _DebugCall: def __init__( self, call_depth: int, - record: Optional[dict[str, Any]] = None, - log: Optional[dict[str, Any]] = None, + record: dict[str, Any] | None = None, + log: dict[str, Any] | None = None, stack: bool = False, ) -> None: self.call_depth = call_depth @@ -166,10 +166,10 @@ def __init__( # results from dispatch hooks self.record = record self.log = log - self.output_str: Optional[str] = None + self.output_str: str | None = None def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + self, attributes: list[str], tensor_memo: TensorIdTracker | None = None ) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. 
@@ -182,7 +182,7 @@ def stringify_output( self, output: Any, attributes: list[str], - tensor_memo: Optional[TensorIdTracker] = None, + tensor_memo: TensorIdTracker | None = None, ) -> None: """Store stringified version of call output in self.output_str""" if tree_all(lambda x: x is None, output): @@ -213,11 +213,11 @@ def __init__( self.args = args self.kwargs = kwargs - self.args_str: Optional[str] = None - self.kwargs_str: Optional[str] = None + self.args_str: str | None = None + self.kwargs_str: str | None = None def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + self, attributes: list[str], tensor_memo: TensorIdTracker | None = None ) -> None: self.args_str = ", ".join( _arg_to_str(arg, attributes, tensor_memo) for arg in self.args @@ -289,10 +289,10 @@ def __init__( self.dst_placement = dst_placement self.transform_info_str = transform_info_str - self.arg_str: Optional[str] = None + self.arg_str: str | None = None def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + self, attributes: list[str], tensor_memo: TensorIdTracker | None = None ) -> None: self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" del self.arg @@ -339,7 +339,7 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False) -> No self.module_name = module_name def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + self, attributes: list[str], tensor_memo: TensorIdTracker | None = None ) -> None: pass # nothing to stringify @@ -418,7 +418,7 @@ def __init__( # This flag currently has no effect on torch.compiled-regions. self.record_nn_module = record_nn_module - self.module_tracker: Optional[ModTracker] = None + self.module_tracker: ModTracker | None = None if self.record_nn_module: self.module_tracker_setup() @@ -585,7 +585,7 @@ def record_redistribute_calls( arg, src_placement, dst_placement, - transform_info_str: Optional[str] = None, + transform_info_str: str | None = None, ): try: self._record_call( @@ -615,8 +615,8 @@ def debug_string(self) -> str: @staticmethod @contextlib.contextmanager def dispatch_hooks( - record_hook: Optional[Callable] = None, - log_hook: Optional[Callable] = None, + record_hook: Callable | None = None, + log_hook: Callable | None = None, ): """ Allows installing post-hooks on arguments to intercepted __torch_dispatch__ calls; @@ -660,9 +660,7 @@ def dispatch_hook(func, types, args, kwargs, result): @staticmethod @contextlib.contextmanager - def log_tensor_hashes( - hash_fn: Optional[Callable] = None, hash_inputs: bool = False - ): + def log_tensor_hashes(hash_fn: Callable | None = None, hash_inputs: bool = False): """ Installs hook for tensor hash logging. 
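A rough sketch of how the DebugMode pieces above fit together (illustrative only; DebugMode is an internal utility, so constructor defaults and output format may differ):

    import torch
    from torch.utils._debug_mode import DebugMode

    with DebugMode() as debug_mode:
        x = torch.randn(4)
        y = torch.nn.functional.relu(x) + 1

    # Each intercepted __torch_dispatch__ call was recorded as a _DebugCall;
    # debug_string() renders the stringified args/outputs collected above.
    print(debug_mode.debug_string())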
@@ -696,7 +694,7 @@ def _dispatch_hash_hook(func, types, args, kwargs, result): yield -def get_active_debug_mode() -> Optional[DebugMode]: +def get_active_debug_mode() -> DebugMode | None: debug_mode = None for mode in _get_current_dispatch_mode_stack(): if isinstance(mode, DebugMode): diff --git a/torch/utils/_device.py b/torch/utils/_device.py index e7e44719e0c57..aafa336415ec6 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-defs import functools -from typing import Optional import torch from torch._C import _len_torch_function_stack @@ -8,7 +7,7 @@ from torch.utils._contextlib import context_decorator -CURRENT_DEVICE: Optional[torch.device] = None +CURRENT_DEVICE: torch.device | None = None @functools.lru_cache(1) diff --git a/torch/utils/_filelock.py b/torch/utils/_filelock.py index dabf3bdc5fed8..a291f59b4ba7f 100644 --- a/torch/utils/_filelock.py +++ b/torch/utils/_filelock.py @@ -1,5 +1,4 @@ from types import TracebackType -from typing import Optional from typing_extensions import Self from filelock import FileLock as base_FileLock @@ -28,9 +27,9 @@ def __enter__(self) -> Self: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: self.region_counter.__exit__() with _WaitCounter("pytorch.filelock.exit").guard(): diff --git a/torch/utils/_foreach_utils.py b/torch/utils/_foreach_utils.py index 8b682d96c1918..e88720a93ce3f 100644 --- a/torch/utils/_foreach_utils.py +++ b/torch/utils/_foreach_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, TypeAlias +from typing import TypeAlias import torch from torch import Tensor @@ -23,7 +23,7 @@ def _get_fused_kernels_supported_devices() -> list[str]: ] -TensorListList: TypeAlias = list[list[Optional[Tensor]]] +TensorListList: TypeAlias = list[list[Tensor | None]] Indices: TypeAlias = list[int] _foreach_supported_types = [torch.Tensor] diff --git a/torch/utils/_import_utils.py b/torch/utils/_import_utils.py index 240f92acacb9d..47e48fb7144e5 100644 --- a/torch/utils/_import_utils.py +++ b/torch/utils/_import_utils.py @@ -1,7 +1,6 @@ import functools import importlib.util from types import ModuleType -from typing import Optional def _check_module_exists(name: str) -> bool: @@ -24,7 +23,7 @@ def dill_available() -> bool: @functools.lru_cache -def import_dill() -> Optional[ModuleType]: +def import_dill() -> ModuleType | None: if not dill_available(): return None diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index fdb9a914bf64e..f00b4ac31ef74 100644 --- a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -8,7 +8,7 @@ Reversible, Set as AbstractSet, ) -from typing import Any, cast, Optional, TypeVar +from typing import Any, cast, TypeVar T = TypeVar("T", bound=Hashable) @@ -24,7 +24,7 @@ class OrderedSet(MutableSet[T], Reversible[T]): __slots__ = ("_dict",) - def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: + def __init__(self, iterable: Iterable[T] | None = None) -> None: self._dict = dict.fromkeys(iterable, None) if iterable is not None else {} @staticmethod diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index b07b7a4ec6fe5..e90db26ef7952 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -6,7 +6,7 @@ import warnings from collections import deque from dataclasses 
import dataclass -from typing import cast, Optional, overload, Protocol, TYPE_CHECKING, Union +from typing import cast, overload, Protocol, TYPE_CHECKING from typing_extensions import TypeIs import torch @@ -207,7 +207,7 @@ def f(x): return False -def _get_current_dispatch_mode() -> Optional[TorchDispatchMode]: +def _get_current_dispatch_mode() -> TorchDispatchMode | None: """ Return the top user mode on the stack (the next one that would be executed) if there are any. @@ -308,7 +308,7 @@ def _push_mode(mode: TorchDispatchMode) -> None: _set_mode_pre_dispatch(mode) -def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None): +def _pop_mode(k: DispatchKey | torch._C._TorchDispatchModeKey | None = None): if k == torch._C.DispatchKey.PreDispatch: # type: ignore[attr-defined] from torch._ops import _pop_mode_from_pre_dispatch @@ -319,7 +319,7 @@ def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = @contextlib.contextmanager -def _pop_mode_temporarily(k: Optional[DispatchKey] = None): +def _pop_mode_temporarily(k: DispatchKey | None = None): old = _pop_mode(k) try: yield old @@ -429,18 +429,18 @@ def to( non_blocking: bool = False, copy: bool = False, *, - memory_format: Optional[torch.memory_format] = None, + memory_format: torch.memory_format | None = None, ) -> torch.Tensor: ... @overload def to( self, - device: Optional[torch._prims_common.DeviceLikeType] = None, - dtype: Optional[torch.types._dtype] = None, + device: torch._prims_common.DeviceLikeType | None = None, + dtype: torch.types._dtype | None = None, non_blocking: bool = False, copy: bool = False, *, - memory_format: Optional[torch.memory_format] = None, + memory_format: torch.memory_format | None = None, ) -> torch.Tensor: ... @overload @@ -450,7 +450,7 @@ def to( non_blocking: bool = False, copy: bool = False, *, - memory_format: Optional[torch.memory_format] = None, + memory_format: torch.memory_format | None = None, ) -> torch.Tensor: ... @@ -610,7 +610,7 @@ def alias_non_inplace_storage(arg, ret) -> None: alias_non_inplace_storage(args[arg_idx], outs[return_idx]) -def _get_write_alias(x) -> Optional[str]: +def _get_write_alias(x) -> str | None: alias_set = x.alias_set if not alias_set or not x.is_write: return None @@ -629,7 +629,7 @@ def _get_write_alias(x) -> Optional[str]: class AliasInfo: alias_set: set[str] is_write: bool - name: Optional[str] + name: str | None @dataclass @@ -642,7 +642,7 @@ class SchemaInfo: # [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce # all-Nones result to empty list instead, and we don't support # some-but-not-all-Nones. - outs_write_aliases: Optional[list[str]] + outs_write_aliases: list[str] | None # List of (arg_idx, return_idx) where args[arg_idx].alias_set & # outs[out_idx].alias_set is not empty, and not args[arg_idx].is_write. 
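For context, the alias/is_write bookkeeping above is derived from the operator schema's alias annotations; a small illustrative probe (not part of the patch, and the printed schema text is approximate):

    import torch

    # "(a!)" in a schema marks an argument that is written in place and aliased
    # to an output, which is what populates AliasInfo.is_write above.
    schema = torch.ops.aten.add_.Tensor._schema
    print(schema)
    for arg in schema.arguments:
        info = arg.alias_info
        print(arg.name, None if info is None else info.is_write)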
@@ -726,12 +726,12 @@ def get_alias_info(func) -> SchemaInfo: if is_read_only_alias_match: read_only_alias_match_indexes.append((arg_idx, return_idx)) - outs_write_aliases_list: list[Optional[str]] = [ + outs_write_aliases_list: list[str | None] = [ _get_write_alias(r) for r in out_schemas ] non_nones = sum(x is not None for x in outs_write_aliases_list) if non_nones == 0: - outs_write_aliases: Optional[list[str]] = None + outs_write_aliases: list[str] | None = None elif non_nones != len(outs_write_aliases_list): # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" raise RuntimeError("Unsupported schema: " + str(func._schema)) @@ -751,7 +751,7 @@ def get_alias_info(func) -> SchemaInfo: def autograd_would_have_decomposed( - func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]] + func: torch._ops.OpOverload, flat_args: Sequence[torch.Tensor | object] ) -> bool: """ Suppose that an operator has CompositeImplicitAutograd decomp registered. diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 147340f58d66e..3d2e4d110b6b2 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -33,7 +33,6 @@ Final, Generic, NoReturn, - Optional, overload, Protocol, TypeAlias, @@ -109,7 +108,7 @@ def get(self, parent: Any) -> Any: ... class EnumEncoder(json.JSONEncoder): - def default(self, obj: object) -> Union[str, dict[str, Any]]: + def default(self, obj: object) -> str | dict[str, Any]: if isinstance(obj, Enum): return { "__enum__": True, @@ -127,7 +126,7 @@ def default(self, obj: object) -> Union[str, dict[str, Any]]: ToDumpableContextFn = Callable[[Context], DumpableContext] FromDumpableContextFn = Callable[[DumpableContext], Context] ToStrFunc = Callable[["TreeSpec", list[str]], str] -MaybeFromStrFunc = Callable[[str], Optional[tuple[Any, Context, str]]] +MaybeFromStrFunc = Callable[[str], tuple[Any, Context, str] | None] KeyPath = tuple[KeyEntry, ...] FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]] @@ -145,7 +144,7 @@ class NodeDef(NamedTuple): type: type[Any] flatten_fn: FlattenFunc unflatten_fn: UnflattenFunc - flatten_with_keys_fn: Optional[FlattenWithKeysFunc] + flatten_with_keys_fn: FlattenWithKeysFunc | None _NODE_REGISTRY_LOCK = threading.RLock() @@ -162,8 +161,8 @@ class NodeDef(NamedTuple): class _SerializeNodeDef(NamedTuple): typ: type[Any] serialized_type_name: str - to_dumpable_context: Optional[ToDumpableContextFn] - from_dumpable_context: Optional[FromDumpableContextFn] + to_dumpable_context: ToDumpableContextFn | None + from_dumpable_context: FromDumpableContextFn | None SUPPORTED_SERIALIZED_TYPES: dict[type[Any], _SerializeNodeDef] = {} @@ -199,10 +198,10 @@ def register_pytree_node( flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, - flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, + serialized_type_name: str | None = None, + to_dumpable_context: ToDumpableContextFn | None = None, + from_dumpable_context: FromDumpableContextFn | None = None, + flatten_with_keys_fn: FlattenWithKeysFunc | None = None, ) -> None: """Register a container-like type as pytree node. 
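As a reminder of how the keyword-only parameters retyped here are used, a minimal registration sketch (illustrative only; `Point` and the serialized name are made up):

    import torch.utils._pytree as pytree

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    pytree.register_pytree_node(
        Point,
        lambda p: ([p.x, p.y], None),            # flatten_fn -> (children, context)
        lambda children, ctx: Point(*children),  # unflatten_fn(children, context)
        serialized_type_name="example.Point",    # hypothetical name
    )

    leaves, spec = pytree.tree_flatten(Point(1, 2))
    assert leaves == [1, 2]
    assert isinstance(pytree.tree_unflatten(leaves, spec), Point)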
@@ -273,9 +272,9 @@ def register_pytree_node( def register_dataclass( cls: type[Any], *, - field_names: Optional[list[str]] = None, - drop_field_names: Optional[list[str]] = None, - serialized_type_name: Optional[str] = None, + field_names: list[str] | None = None, + drop_field_names: list[str] | None = None, + serialized_type_name: str | None = None, ) -> None: """ Registers a type that has the semantics of a ``dataclasses.dataclass`` type @@ -524,13 +523,13 @@ def _register_pytree_node( cls: type[Any], flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, - to_str_fn: Optional[ToStrFunc] = None, # deprecated - maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + to_str_fn: ToStrFunc | None = None, # deprecated + maybe_from_str_fn: MaybeFromStrFunc | None = None, # deprecated *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, - flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, + serialized_type_name: str | None = None, + to_dumpable_context: ToDumpableContextFn | None = None, + from_dumpable_context: FromDumpableContextFn | None = None, + flatten_with_keys_fn: FlattenWithKeysFunc | None = None, ) -> None: """Register a container-like type as pytree node for the Python pytree only. @@ -594,10 +593,10 @@ def _private_register_pytree_node( flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc, *, - serialized_type_name: Optional[str] = None, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, - flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, + serialized_type_name: str | None = None, + to_dumpable_context: ToDumpableContextFn | None = None, + from_dumpable_context: FromDumpableContextFn | None = None, + flatten_with_keys_fn: FlattenWithKeysFunc | None = None, ) -> None: """This is an internal function that is used to register a pytree node type for the Python pytree only. End-users should use :func:`register_pytree_node` @@ -671,7 +670,7 @@ def get(self, obj: Any) -> Any: # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py -def is_namedtuple(obj: Union[object, type]) -> bool: +def is_namedtuple(obj: object | type) -> bool: """Return whether the object is an instance of namedtuple or a subclass of namedtuple.""" cls = obj if isinstance(obj, type) else type(obj) return is_namedtuple_class(cls) @@ -723,7 +722,7 @@ def __new__( # Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py -def is_structseq(obj: Union[object, type]) -> bool: +def is_structseq(obj: object | type) -> bool: """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.""" cls = obj if isinstance(obj, type) else type(obj) return is_structseq_class(cls) @@ -1046,7 +1045,7 @@ def _get_node_type(tree: Any) -> Any: # A leaf is defined as anything that is not a Node. def tree_is_leaf( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: """Check if a pytree is a leaf. 
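Similarly, a short sketch of `register_dataclass`, whose keyword arguments are retyped above (illustrative only; `TrainConfig` is a made-up type):

    from dataclasses import dataclass

    import torch.utils._pytree as pytree

    @dataclass
    class TrainConfig:
        lr: float
        steps: int

    pytree.register_dataclass(TrainConfig)
    leaves, spec = pytree.tree_flatten(TrainConfig(lr=0.1, steps=10))
    assert leaves == [0.1, 10]
    assert pytree.tree_unflatten(leaves, spec) == TrainConfig(lr=0.1, steps=10)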
@@ -1073,7 +1072,7 @@ def tree_is_leaf( "Please use torch.utils._pytree.tree_is_leaf instead.", category=FutureWarning, ) -def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: +def _is_leaf(tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None) -> bool: return tree_is_leaf(tree, is_leaf=is_leaf) @@ -1353,7 +1352,7 @@ def treespec_tuple(iterable: Iterable[TreeSpec] = (), /) -> TreeSpec: def treespec_dict( - mapping: Union[Mapping[Any, TreeSpec], Iterable[tuple[Any, TreeSpec]]] = (), + mapping: Mapping[Any, TreeSpec] | Iterable[tuple[Any, TreeSpec]] = (), /, **kwargs: TreeSpec, ) -> TreeSpec: @@ -1366,7 +1365,7 @@ def treespec_dict( def tree_flatten( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> tuple[list[Any], TreeSpec]: """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree. @@ -1404,7 +1403,7 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: def tree_iter( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> Iterable[Any]: """Get an iterator over the leaves of a pytree.""" if tree_is_leaf(tree, is_leaf=is_leaf): @@ -1421,7 +1420,7 @@ def tree_iter( def tree_leaves( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[Any]: """Get a list of leaves of a pytree.""" return list(tree_iter(tree, is_leaf=is_leaf)) @@ -1429,7 +1428,7 @@ def tree_leaves( def tree_structure( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> TreeSpec: """Get the TreeSpec for a pytree.""" return tree_flatten(tree, is_leaf=is_leaf)[1] @@ -1439,7 +1438,7 @@ def tree_map( func: Callable[..., Any], tree: PyTree, *rests: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: """Map a multi-input function over pytree args to produce a new pytree. @@ -1483,7 +1482,7 @@ def tree_map_( func: Callable[..., Any], tree: PyTree, *rests: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. @@ -1517,8 +1516,8 @@ def tree_map_( Type3 = tuple[type[T], type[S], type[U]] TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType] -Fn2 = Callable[[Union[T, S]], R] -Fn3 = Callable[[Union[T, S, U]], R] +Fn2 = Callable[[T | S], R] +Fn3 = Callable[[T | S | U], R] Fn = Callable[[T], R] FnAny = Callable[[Any], R] @@ -1553,7 +1552,7 @@ def map_only( def map_only( - type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], / + type_or_types_or_pred: TypeAny | Callable[[Any], bool], / ) -> MapOnlyFn[FnAny[Any]]: """ Suppose you are writing a tree_map over tensors, leaving everything @@ -1601,7 +1600,7 @@ def tree_map_only( /, func: Fn[T, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1611,7 +1610,7 @@ def tree_map_only( /, func: Fn2[T, S, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... 
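The `map_only`/`tree_map` combination retyped above is typically used as a decorator; a rough sketch (illustrative only, not part of the patch):

    import torch
    import torch.utils._pytree as pytree
    from torch.utils._pytree import map_only

    @map_only(torch.Tensor)
    def detach_leaf(t):
        # Applied only to Tensor leaves; every other leaf is returned unchanged.
        return t.detach()

    tree = {"w": torch.ones(2, requires_grad=True), "tag": "x"}
    out = pytree.tree_map(detach_leaf, tree)
    assert out["tag"] == "x" and not out["w"].requires_grad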
@@ -1621,7 +1620,7 @@ def tree_map_only( /, func: Fn3[T, S, U, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1631,7 +1630,7 @@ def tree_map_only( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1641,16 +1640,16 @@ def tree_map_only( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... def tree_map_only( - type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + type_or_types_or_pred: TypeAny | Callable[[Any], bool], /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -1661,7 +1660,7 @@ def tree_map_only_( /, func: Fn[T, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1671,7 +1670,7 @@ def tree_map_only_( /, func: Fn2[T, S, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1681,7 +1680,7 @@ def tree_map_only_( /, func: Fn3[T, S, U, Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1691,7 +1690,7 @@ def tree_map_only_( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... @@ -1701,16 +1700,16 @@ def tree_map_only_( /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: ... def tree_map_only_( - type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + type_or_types_or_pred: TypeAny | Callable[[Any], bool], /, func: FnAny[Any], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) @@ -1718,7 +1717,7 @@ def tree_map_only_( def tree_all( pred: Callable[[Any], bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return all(map(pred, flat_args)) @@ -1727,7 +1726,7 @@ def tree_all( def tree_any( pred: Callable[[Any], bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return any(map(pred, flat_args)) @@ -1739,7 +1738,7 @@ def tree_all_only( /, pred: Fn[T, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -1749,7 +1748,7 @@ def tree_all_only( /, pred: Fn2[T, S, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -1759,7 +1758,7 @@ def tree_all_only( /, pred: Fn3[T, S, U, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... 
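And a quick sketch of the `tree_all`/`tree_any_only` predicates retyped above (illustrative only; the sample batch is made up):

    import torch
    import torch.utils._pytree as pytree

    batch = {"x": torch.randn(3), "mask": torch.tensor([True, False, True]), "note": "dbg"}
    # The *_only variants apply the predicate to leaves of the given type only.
    assert pytree.tree_any_only(torch.Tensor, lambda t: t.dtype == torch.bool, batch)
    assert pytree.tree_all(lambda leaf: leaf is not None, batch)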
@@ -1768,7 +1767,7 @@ def tree_all_only( /, pred: FnAny[bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return all(pred(x) for x in flat_args if isinstance(x, type_or_types)) @@ -1780,7 +1779,7 @@ def tree_any_only( /, pred: Fn[T, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -1790,7 +1789,7 @@ def tree_any_only( /, pred: Fn2[T, S, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -1800,7 +1799,7 @@ def tree_any_only( /, pred: Fn3[T, S, U, bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: ... @@ -1809,7 +1808,7 @@ def tree_any_only( /, pred: FnAny[bool], tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> bool: flat_args = tree_iter(tree, is_leaf=is_leaf) return any(pred(x) for x in flat_args if isinstance(x, type_or_types)) @@ -1826,8 +1825,8 @@ def tree_any_only( def _broadcast_to_and_flatten( tree: PyTree, treespec: TreeSpec, - is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> Optional[list[Any]]: + is_leaf: Callable[[PyTree], bool] | None = None, +) -> list[Any] | None: if not isinstance(treespec, TreeSpec): raise AssertionError("treespec must be a TreeSpec") @@ -1868,7 +1867,7 @@ class _TreeSpecSchema: - children_spec: A list of children serialized specs. """ - type: Optional[str] + type: str | None context: DumpableContext children_spec: list["_TreeSpecSchema"] @@ -1917,7 +1916,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) -def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]: +def enum_object_hook(obj: dict[str, Any]) -> Enum | dict[str, Any]: if "__enum__" in obj: modname, _, classname = obj["fqn"].partition(":") mod = importlib.import_module(modname) @@ -1968,7 +1967,7 @@ def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: _SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) -def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: +def treespec_dumps(treespec: TreeSpec, protocol: int | None = None) -> str: if not isinstance(treespec, TreeSpec): raise TypeError( f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " @@ -2048,7 +2047,7 @@ def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> list[Any]: def tree_flatten_with_path( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]: """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. @@ -2072,7 +2071,7 @@ def tree_flatten_with_path( def tree_leaves_with_path( tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> list[tuple[KeyPath, Any]]: """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. 
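For the key-path and serialization helpers retyped above, a short illustrative sketch (the rendered key-path strings are approximate):

    import torch.utils._pytree as pytree

    flat_with_paths, spec = pytree.tree_flatten_with_path({"a": [1, 2], "b": 3})
    for key_path, leaf in flat_with_paths:
        # keystr renders a path such as "['a'][0]" for nested dict/list access.
        print(pytree.keystr(key_path), leaf)

    serialized = pytree.treespec_dumps(spec)     # JSON string
    assert pytree.treespec_loads(serialized) == spec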
@@ -2094,7 +2093,7 @@ def tree_leaves_with_path( def _generate_key_paths( key_path: KeyPath, tree: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> Iterable[tuple[KeyPath, Any]]: if is_leaf and is_leaf(tree): yield key_path, tree @@ -2124,7 +2123,7 @@ def tree_map_with_path( func: Callable[..., Any], tree: PyTree, *rests: PyTree, - is_leaf: Optional[Callable[[PyTree], bool]] = None, + is_leaf: Callable[[PyTree], bool] | None = None, ) -> PyTree: """Like :func:`tree_map`, but the provided callable takes an additional key path argument. diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 47cf07552b2cf..d2e1595bf2a14 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -8,7 +8,7 @@ import time from collections.abc import Callable, Sequence from threading import Lock -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from typing_extensions import ParamSpec @@ -34,14 +34,14 @@ class StrobelightCLIProfilerError(Exception): """ -def _pid_namespace_link(pid: Optional[int] = None) -> str: +def _pid_namespace_link(pid: int | None = None) -> str: """Returns the link to the process's namespace, example: pid:[4026531836]""" PID_NAMESPACE_PATH = "/proc/{}/ns/pid" pid = pid or os.getpid() return os.readlink(PID_NAMESPACE_PATH.format(pid)) -def _pid_namespace(pid: Optional[int] = None) -> int: +def _pid_namespace(pid: int | None = None) -> int: """Returns the process's namespace id""" pid = pid or os.getpid() link = _pid_namespace_link(pid) @@ -77,8 +77,8 @@ def __init__( run_user_name: str = "pytorch-strobelight-ondemand", timeout_wait_for_running_sec: int = 60, timeout_wait_for_finished_sec: int = 60, - recorded_env_variables: Optional[list[str]] = None, - sample_tags: Optional[list[str]] = None, + recorded_env_variables: list[str] | None = None, + sample_tags: list[str] | None = None, stack_max_len: int = 127, async_stack_max_len: int = 127, ) -> None: @@ -90,7 +90,7 @@ def __init__( self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec # Results of the most recent run. # Tracks the strobelight run id of the most recent run - self.current_run_id: Optional[int] = None + self.current_run_id: int | None = None self.sample_tags = sample_tags def _run_async(self) -> None: @@ -253,7 +253,7 @@ def _start_strobelight(self) -> bool: def profile( self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs - ) -> Optional[_R]: + ) -> _R | None: self.current_run_id = None if locked := StrobelightCLIFunctionProfiler._lock.acquire(False): @@ -295,16 +295,16 @@ def profile( # @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..)) # @strobelight(stop_at_error=True,...) 
def strobelight( - profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any -) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]: + profiler: StrobelightCLIFunctionProfiler | None = None, **kwargs: Any +) -> Callable[[Callable[_P, _R]], Callable[_P, _R | None]]: if not profiler: profiler = StrobelightCLIFunctionProfiler(**kwargs) def strobelight_inner( work_function: Callable[_P, _R], - ) -> Callable[_P, Optional[_R]]: + ) -> Callable[_P, _R | None]: @functools.wraps(work_function) - def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _R | None: # pyrefly: ignore [bad-argument-type] return profiler.profile(work_function, *args, **kwargs) diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 425344bda17ef..fee27cfb0cbe7 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -4,7 +4,7 @@ import operator import sys from collections.abc import Callable -from typing import Optional, SupportsFloat, TYPE_CHECKING, TypeVar, Union +from typing import SupportsFloat, TYPE_CHECKING, TypeVar from typing_extensions import TypeVarTuple, Unpack import sympy @@ -102,11 +102,11 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool: def _keep_float( f: Callable[[Unpack[_Ts]], _T], -) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: +) -> Callable[[Unpack[_Ts]], _T | sympy.Float]: @functools.wraps(f) - def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: + def inner(*args: Unpack[_Ts]) -> _T | sympy.Float: # pyrefly: ignore [bad-argument-type] - r: Union[_T, sympy.Float] = f(*args) + r: _T | sympy.Float = f(*args) if any(isinstance(a, sympy.Float) for a in args) and not isinstance( r, sympy.Float ): @@ -117,7 +117,7 @@ def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: return inner -def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]: +def fuzzy_eq(x: bool | None, y: bool | None) -> bool | None: if None in (x, y): return None return x == y @@ -216,9 +216,7 @@ def _sympystr(self, printer: sympy.printing.StrPrinter) -> str: # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod - def eval( - cls, base: sympy.Integer, divisor: sympy.Integer - ) -> Union[sympy.Basic, None]: + def eval(cls, base: sympy.Integer, divisor: sympy.Integer) -> sympy.Basic | None: # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full # Assert triggered by inequality solver # assert base.is_integer, base @@ -324,7 +322,7 @@ class ModularIndexing(sympy.Function): @classmethod def eval( cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer - ) -> Optional[sympy.Basic]: + ) -> sympy.Basic | None: if base == 0 or modulus == 1: return sympy.S.Zero if ( @@ -373,7 +371,7 @@ def eval( return None - def _eval_is_nonnegative(self) -> Optional[bool]: + def _eval_is_nonnegative(self) -> bool | None: # pyrefly: ignore [missing-attribute] p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] @@ -387,23 +385,21 @@ class Where(sympy.Function): nargs: tuple[int, ...] 
= (3,) precedence: int = 35 # lower precedence than add - def _eval_is_integer(self) -> Optional[bool]: + def _eval_is_integer(self) -> bool | None: return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] - def _eval_is_nonnegative(self) -> Optional[bool]: + def _eval_is_nonnegative(self) -> bool | None: return ( True if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] else None ) - def _eval_is_positive(self) -> Optional[bool]: + def _eval_is_positive(self) -> bool | None: return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] @classmethod - def eval( - cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic - ) -> Optional[sympy.Basic]: + def eval(cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic) -> sympy.Basic | None: if c == sympy.true: return p elif c == sympy.false: @@ -419,7 +415,7 @@ class PythonMod(sympy.Function): is_integer: bool = True @classmethod - def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]: + def eval(cls, p: sympy.Expr, q: sympy.Expr) -> sympy.Expr | None: # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint # Triggered by sympy.solvers.inequalities.reduce_inequalities # assert p.is_integer, p @@ -465,10 +461,10 @@ def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]: return None # NB: args[1] for PythonMod - def _eval_is_nonnegative(self) -> Optional[bool]: + def _eval_is_nonnegative(self) -> bool | None: return True if self.args[1].is_positive else None # type: ignore[attr-defined] - def _eval_is_nonpositive(self) -> Optional[bool]: + def _eval_is_nonpositive(self) -> bool | None: return True if self.args[1].is_negative else None # type: ignore[attr-defined] def _ccode(self, printer) -> str: @@ -664,7 +660,7 @@ def __new__(cls, *original_args, **assumptions): @classmethod def _satisfy_unique_summations_symbols( cls, args - ) -> Optional[set[sympy.core.symbol.Symbol]]: + ) -> set[sympy.core.symbol.Symbol] | None: """ One common case in some models is building expressions of the form max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...). @@ -719,8 +715,8 @@ def _satisfy_unique_summations_symbols( @classmethod def _unique_symbols( - cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None - ) -> Optional[set[sympy.core.symbol.Symbol]]: + cls, args, initial_set: set[sympy.core.symbol.Symbol] | None = None + ) -> set[sympy.core.symbol.Symbol] | None: """ Return seen_symbols if all atoms in all args are all unique symbols, else returns None. 
initial_set can be used to represent initial value for seen_symbols diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 6dc496a0ddb13..dd8cacc053950 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -10,7 +10,7 @@ import functools import logging -from typing import Any, Union +from typing import Any import sympy from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom @@ -184,7 +184,7 @@ def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64): def sympy_interp( analysis, env: dict[sympy.Symbol, Any], - expr: Union[sympy.Expr, SympyBoolean], + expr: sympy.Expr | SympyBoolean, *, index_dtype=torch.int64, missing_handler=None, diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 915d0e5461f1e..4078ae95315f9 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -1,5 +1,4 @@ import sys -from typing import Optional import sympy from sympy.printing.precedence import PRECEDENCE, precedence @@ -23,7 +22,7 @@ def _print_Mul(self, expr: sympy.Expr) -> str: def _print_Not(self, expr: sympy.Expr) -> str: return f"not ({self._print(expr.args[0])})" - def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: + def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str: return self.stringify(expr.args, " + ", precedence(expr)) def _print_Relational(self, expr: sympy.Expr) -> str: @@ -310,7 +309,7 @@ def _print_Piecewise(self, expr: sympy.Expr) -> str: # Convert Piecewise(expr_cond_pairs) to nested ternary expressions # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) # becomes: e1 if c1 else (e2 if c2 else (... else eN)) - result: Optional[str] = None + result: str | None = None for expr_i, cond_i in reversed(expr.args): expr_str = self._print(expr_i) if cond_i == True: # noqa: E712 @@ -349,7 +348,7 @@ def _print_Piecewise(self, expr: sympy.Expr) -> str: # Convert Piecewise(expr_cond_pairs) to nested ternary operators # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) - result: Optional[str] = None + result: str | None = None for expr_i, cond_i in reversed(expr.args): expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) if cond_i == True: # noqa: E712 diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index e9b4a91429a4d..874de07e6ca7d 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import math import operator -from typing import NoReturn, Union +from typing import NoReturn import sympy @@ -359,7 +359,7 @@ class TensorReferenceAnalysis: # function isn't traced correctly. Here for completeness. 
@staticmethod def constant(c, dtype): - d: Union[int, float, bool] + d: int | float | bool if dtype is torch.int64: d = int(c) elif dtype is torch.double: diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 840957f4109cb..3bd5e1484601f 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import sympy @@ -20,7 +19,7 @@ INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le) -def mirror_rel_op(type: type) -> Optional[type[sympy.Rel]]: +def mirror_rel_op(type: type) -> type[sympy.Rel] | None: return _MIRROR_REL_OP.get(type) @@ -43,7 +42,7 @@ def try_solve( thing: sympy.Basic, trials: int = 5, floordiv_inequality: bool = True, -) -> Optional[tuple[sympy.Rel, sympy.Expr]]: +) -> tuple[sympy.Rel, sympy.Expr] | None: mirror = mirror_rel_op(type(expr)) # Ignore unsupported expressions: diff --git a/torch/utils/_sympy/symbol.py b/torch/utils/_sympy/symbol.py index cd25478e6ed18..61a7c147458e0 100644 --- a/torch/utils/_sympy/symbol.py +++ b/torch/utils/_sympy/symbol.py @@ -14,7 +14,6 @@ from collections.abc import Iterable from enum import auto, Enum -from typing import Union import sympy @@ -88,7 +87,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol: # This type is a little wider than it should be, because free_symbols says # that it contains Basic, rather than Symbol -def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool: +def symbol_is_type(sym: sympy.Basic, prefix: SymT | Iterable[SymT]) -> bool: if not isinstance(sym, sympy.Symbol): raise AssertionError("expected sympy.Symbol") name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK @@ -98,5 +97,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> boo return name_str.startswith(tuple(prefix_str[p] for p in prefix)) -def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool: +def free_symbol_is_type(e: sympy.Expr, prefix: SymT | Iterable[SymT]) -> bool: return any(symbol_is_type(v, prefix) for v in e.free_symbols) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index ef7c1696480b5..2016203ece67c 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -10,7 +10,6 @@ from collections.abc import Callable from typing import ( Generic, - Optional, overload, SupportsFloat, TYPE_CHECKING, @@ -325,16 +324,16 @@ def unknown_bool() -> ValueRanges[SympyBoolean]: @overload @staticmethod # work around the fact that bool and int overlap - def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR: # type: ignore[overload-overlap] + def wrap(arg: ExprIn | ExprVR) -> ExprVR: # type: ignore[overload-overlap] ... @overload @staticmethod - def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR: # type: ignore[misc] + def wrap(arg: BoolIn | BoolVR) -> BoolVR: # type: ignore[misc] ... 
@staticmethod - def wrap(arg: Union[AllIn, AllVR]) -> AllVR: + def wrap(arg: AllIn | AllVR) -> AllVR: if isinstance(arg, ValueRanges): return arg if isinstance(arg, float) and math.isnan(arg): @@ -343,29 +342,29 @@ def wrap(arg: Union[AllIn, AllVR]) -> AllVR: return ValueRanges(arg, arg) # type: ignore[arg-type] @staticmethod - def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + def increasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: """Increasing: x <= y => f(x) <= f(y).""" x = ValueRanges.wrap(x) return ValueRanges(fn(x.lower), fn(x.upper)) @overload @staticmethod - def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: ... + def decreasing_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: ... @overload @staticmethod - def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR: # type: ignore[misc] + def decreasing_map(x: BoolIn | BoolVR, fn: BoolFn) -> BoolVR: # type: ignore[misc] ... @staticmethod - def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR: + def decreasing_map(x: AllIn | AllVR, fn: AllFn) -> AllVR: """Decreasing: x <= y => f(x) >= f(y).""" x = ValueRanges.wrap(x) # consistently either Expr or Bool, but we don't know it here return ValueRanges(fn(x.upper), fn(x.lower)) # type: ignore[arg-type] @staticmethod - def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + def monotone_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: """It's increasing or decreasing.""" x = ValueRanges.wrap(x) l = fn(x.lower) @@ -373,7 +372,7 @@ def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: return ValueRanges(min(l, u), max(l, u)) @staticmethod - def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: + def convex_min_zero_map(x: ExprIn | ExprVR, fn: ExprFn) -> ExprVR: """Fn is convex and has a minimum at 0.""" x = ValueRanges.wrap(x) if 0 in x: @@ -387,23 +386,23 @@ def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR: @overload @staticmethod def coordinatewise_increasing_map( - x: Union[ExprIn, ExprVR], - y: Union[ExprIn, ExprVR], + x: ExprIn | ExprVR, + y: ExprIn | ExprVR, fn: ExprFn2, ) -> ExprVR: ... @overload @staticmethod def coordinatewise_increasing_map( # type: ignore[misc] - x: Union[BoolIn, BoolVR], - y: Union[BoolIn, BoolVR], + x: BoolIn | BoolVR, + y: BoolIn | BoolVR, fn: BoolFn2, ) -> BoolVR: ... @staticmethod def coordinatewise_increasing_map( - x: Union[AllIn, AllVR], - y: Union[AllIn, AllVR], + x: AllIn | AllVR, + y: AllIn | AllVR, fn: AllFn2, ) -> AllVR: """ @@ -1037,7 +1036,7 @@ def trunc(x): def bound_sympy( - expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None + expr: sympy.Expr, ranges: dict[sympy.Symbol, ValueRanges] | None = None ) -> ValueRanges: log.debug( "bound_sympy(%s)%s", diff --git a/torch/utils/_thunk.py b/torch/utils/_thunk.py index a332babfdf4ce..b5ab598077f4e 100644 --- a/torch/utils/_thunk.py +++ b/torch/utils/_thunk.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import Generic, Optional, TypeVar +from typing import Generic, TypeVar R = TypeVar("R") @@ -12,8 +12,8 @@ class Thunk(Generic[R]): function once it is forced. 
""" - f: Optional[Callable[[], R]] - r: Optional[R] + f: Callable[[], R] | None + r: R | None __slots__ = ["f", "r"] diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index 39a302ea5ca25..f5415002092a2 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -5,7 +5,6 @@ import tempfile import traceback from types import TracebackType -from typing import Optional # This file contains utilities for ensuring dynamically compile()'d @@ -234,7 +233,7 @@ def format_all(tbs): import torch._C._profiler # Directly populate tracebacks that already have cached summaries - rs: list[Optional[list[str]]] = [] + rs: list[list[str] | None] = [] delayed_idxs = [] for i, tb in enumerate(tbs): if tb.tb is None: diff --git a/torch/utils/_typing_utils.py b/torch/utils/_typing_utils.py index ffb6b383e4e6b..f28c9f94100b7 100644 --- a/torch/utils/_typing_utils.py +++ b/torch/utils/_typing_utils.py @@ -1,6 +1,6 @@ """Miscellaneous utilities to aid with typing.""" -from typing import Optional, TypeVar +from typing import TypeVar # Helper to turn Optional[T] into T when we know None either isn't @@ -8,7 +8,7 @@ T = TypeVar("T") -def not_none(obj: Optional[T]) -> T: +def not_none(obj: T | None) -> T: if obj is None: raise TypeError("Invariant encountered: value was None when it should not be") return obj diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index b31eb49a60601..2300306d22d2d 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional, Union import torch from torch._C import _get_privateuse1_backend_name, _rename_privateuse1_backend @@ -90,7 +89,7 @@ def _check_register_once(module, attr) -> None: def _normalization_device( - custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None + custom_backend_name: str, device: int | str | torch.device | None = None ) -> int: def _get_current_device_index(): _get_device_index = "current_device" @@ -137,7 +136,7 @@ def wrap_tensor_backend(self: torch.Tensor) -> bool: def wrap_tensor_to( self: torch.Tensor, - device: Optional[Union[int, torch.device]] = None, + device: int | torch.device | None = None, non_blocking=False, **kwargs, ) -> torch.Tensor: @@ -188,7 +187,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) - def wrap_module_to( self: torch.nn.modules.module.T, - device: Optional[Union[int, torch.device]] = None, + device: int | torch.device | None = None, ) -> torch.nn.modules.module.T: r"""Move all model parameters and buffers to the custom device. @@ -268,7 +267,7 @@ def wrap_module_to( def _generate_storage_methods_for_privateuse1_backend( - custom_backend_name: str, unsupported_dtype: Optional[list[torch.dtype]] = None + custom_backend_name: str, unsupported_dtype: list[torch.dtype] | None = None ) -> None: # Attribute is registered in the _StorageBase class # and UntypedStorage obtains through inheritance. @@ -355,7 +354,7 @@ def generate_methods_for_privateuse1_backend( for_module: bool = True, for_packed_sequence: bool = True, for_storage: bool = False, - unsupported_dtype: Optional[list[torch.dtype]] = None, + unsupported_dtype: list[torch.dtype] | None = None, ) -> None: r""" Automatically generate attributes and methods for the custom backend after rename privateuse1 backend. 
diff --git a/torch/utils/benchmark/op_fuzzers/sparse_unary.py b/torch/utils/benchmark/op_fuzzers/sparse_unary.py index 07d2aeeeabaf2..18921becd078c 100644 --- a/torch/utils/benchmark/op_fuzzers/sparse_unary.py +++ b/torch/utils/benchmark/op_fuzzers/sparse_unary.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import numpy as np import torch @@ -20,7 +20,7 @@ )) class UnaryOpSparseFuzzer(Fuzzer): - def __init__(self, seed: Optional[int], dtype: _dtype | None = None, cuda: bool = False) -> None: + def __init__(self, seed: int | None, dtype: _dtype | None = None, cuda: bool = False) -> None: if dtype is None: dtype = getattr(torch, 'float32', None) super().__init__( diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py index 10fe1d898de0f..d4f328d19083f 100644 --- a/torch/utils/benchmark/utils/common.py +++ b/torch/utils/benchmark/utils/common.py @@ -8,7 +8,7 @@ import tempfile import textwrap import time -from typing import cast, Any, Optional +from typing import cast, Any from collections.abc import Iterable, Iterator import uuid @@ -34,10 +34,10 @@ class TaskSpec: stmt: str setup: str global_setup: str = "" - label: Optional[str] = None - sub_label: Optional[str] = None - description: Optional[str] = None - env: Optional[str] = None + label: str | None = None + sub_label: str | None = None + description: str | None = None + env: str | None = None num_threads: int = 1 @property @@ -82,7 +82,7 @@ class Measurement: number_per_run: int raw_times: list[float] task_spec: TaskSpec - metadata: Optional[dict[Any, Any]] = None # Reserved for user payloads. + metadata: dict[Any, Any] | None = None # Reserved for user payloads. def __post_init__(self) -> None: self._sorted_times: tuple[float, ...] = () @@ -297,7 +297,7 @@ def set_torch_threads(n: int) -> Iterator[None]: torch.set_num_threads(prior_num_threads) -def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str: +def _make_temp_dir(prefix: str | None = None, gc_dev_shm: bool = False) -> str: """Create a temporary directory. The caller is responsible for cleanup. 
This function is conceptually similar to `tempfile.mkdtemp`, but with diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index e9a0966c6e966..c1e232e6e0426 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -3,7 +3,6 @@ import collections import enum import itertools as it -from typing import Optional from torch.utils.benchmark.utils import common from torch import tensor as _tensor @@ -29,7 +28,7 @@ class Colorize(enum.Enum): class _Column: def __init__( self, - grouped_results: list[tuple[Optional[common.Measurement], ...]], + grouped_results: list[tuple[common.Measurement | None, ...]], time_scale: float, time_unit: str, trim_significant_figures: bool, @@ -60,7 +59,7 @@ def __init__( def get_results_for(self, group): return self._grouped_results[group] - def num_to_str(self, value: Optional[float], estimated_sigfigs: int, spread: Optional[float]): + def num_to_str(self, value: float | None, estimated_sigfigs: int, spread: float | None): if value is None: return " " * len(self.num_to_str(1, estimated_sigfigs, None)) @@ -175,17 +174,17 @@ def __init__( self.rows, self.columns = self.populate_rows_and_columns() @staticmethod - def row_fn(m: common.Measurement) -> tuple[int, Optional[str], str]: + def row_fn(m: common.Measurement) -> tuple[int, str | None, str]: return m.num_threads, m.env, m.as_row_name @staticmethod - def col_fn(m: common.Measurement) -> Optional[str]: + def col_fn(m: common.Measurement) -> str | None: return m.description def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, ...]]: rows: list[_Row] = [] columns: list[_Column] = [] - ordered_results: list[list[Optional[common.Measurement]]] = [ + ordered_results: list[list[common.Measurement | None]] = [ [None for _ in self.column_keys] for _ in self.row_keys ] @@ -205,7 +204,7 @@ def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, .. 
prior_num_threads = -1 prior_env = "" row_group = -1 - rows_by_group: list[list[list[Optional[common.Measurement]]]] = [] + rows_by_group: list[list[list[common.Measurement | None]]] = [] for (num_threads, env, _), row in zip(self.row_keys, ordered_results, strict=True): thread_transition = (num_threads != prior_num_threads) if thread_transition: diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index d8881354ddaf2..dd15a582a2749 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, cast, Optional, Union +from typing import Any, cast from collections.abc import Callable import torch @@ -40,11 +40,11 @@ def _disable_tensor_cores() -> None: torch.set_float32_matmul_precision(_default_float_32_precision) def bench_loop( - model: Union[torch.nn.Module, Callable], - sample_input: Union[torch.Tensor, Any], + model: torch.nn.Module | Callable, + sample_input: torch.Tensor | Any, num_iters: int = 5, - optimizer: Optional[torch.optim.Optimizer] = None, - loss_fn: Optional[Callable] = None, + optimizer: torch.optim.Optimizer | None = None, + loss_fn: Callable | None = None, ): # Define the statement and setup for the benchmark if optimizer and loss_fn: @@ -74,13 +74,13 @@ def bench_loop( return round(avg_time, 2) def benchmark_compile( - model: Union[torch.nn.Module, Callable], - sample_input: Union[torch.Tensor, Any], + model: torch.nn.Module | Callable, + sample_input: torch.Tensor | Any, num_iters: int = 5, - backend: Optional[str] = None, - mode: Optional[str] = "default", - optimizer: Optional[torch.optim.Optimizer] = None, - loss_fn : Union[torch.nn.Module, Callable, None] = None, + backend: str | None = None, + mode: str | None = "default", + optimizer: torch.optim.Optimizer | None = None, + loss_fn : torch.nn.Module | Callable | None = None, ): """ Use this utility to benchmark torch.compile @@ -119,11 +119,11 @@ def benchmark_compile( def bench_all( - model : Union[torch.nn.Module, Callable], - sample_input: Union[torch.Tensor, Any], + model : torch.nn.Module | Callable, + sample_input: torch.Tensor | Any, num_iters : int = 5, - optimizer: Optional[torch.optim.Optimizer] = None, - loss_fn : Union[torch.nn.Module, Callable, None] = None, + optimizer: torch.optim.Optimizer | None = None, + loss_fn : torch.nn.Module | Callable | None = None, ): """ This is a simple utility that can be used to benchmark torch.compile @@ -155,7 +155,7 @@ def bench_all( for backend in torch._dynamo.list_backends(): if backend == "inductor": - mode_options = cast(list[Optional[str]], list(torch._inductor.list_mode_options().keys())) + [None] + mode_options = cast(list[str | None], list(torch._inductor.list_mode_options().keys())) + [None] for mode in mode_options: if mode == "default": continue diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py index 969eb6abb6954..a298146ce17c7 100644 --- a/torch/utils/benchmark/utils/cpp_jit.py +++ b/torch/utils/benchmark/utils/cpp_jit.py @@ -5,7 +5,7 @@ import shutil import textwrap import threading -from typing import Any, Optional +from typing import Any import torch from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType @@ -29,7 +29,7 @@ # ```` # `setup` and `stmt` do not change, so we can reuse the executable from the # first pass through the loop. 
-_BUILD_ROOT: Optional[str] = None +_BUILD_ROOT: str | None = None def _get_build_root() -> str: global _BUILD_ROOT @@ -64,7 +64,7 @@ def _get_build_root() -> str: # analysis and the shims no longer justify their maintenance and code # complexity costs) back testing paths will be removed. -CXX_FLAGS: Optional[list[str]] +CXX_FLAGS: list[str] | None if hasattr(torch.__config__, "_cxx_flags"): try: CXX_FLAGS = torch.__config__._cxx_flags().strip().split() @@ -89,7 +89,7 @@ def _get_build_root() -> str: EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include")) -COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None +COMPAT_CALLGRIND_BINDINGS: CallgrindModuleType | None = None def get_compat_bindings() -> CallgrindModuleType: with LOCK: global COMPAT_CALLGRIND_BINDINGS diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index 06f37bd8f3a35..38f771d23632e 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import itertools as it -from typing import Any, Optional, Union +from typing import Any from collections.abc import Callable import torch @@ -25,9 +25,9 @@ class FuzzedParameter: def __init__( self, name: str, - minval: Optional[Union[int, float]] = None, - maxval: Optional[Union[int, float]] = None, - distribution: Optional[Union[str, dict[Any, float]]] = None, + minval: int | float | None = None, + maxval: int | float | None = None, + distribution: str | dict[Any, float] | None = None, strict: bool = False, ) -> None: """ @@ -188,17 +188,17 @@ class FuzzedTensor: def __init__( self, name: str, - size: tuple[Union[str, int], ...], - steps: Optional[tuple[Union[str, int], ...]] = None, + size: tuple[str | int, ...], + steps: tuple[str | int, ...] 
| None = None, probability_contiguous: float = 0.5, - min_elements: Optional[int] = None, - max_elements: Optional[int] = None, - max_allocation_bytes: Optional[int] = None, - dim_parameter: Optional[str] = None, - roll_parameter: Optional[str] = None, + min_elements: int | None = None, + max_elements: int | None = None, + max_allocation_bytes: int | None = None, + dim_parameter: str | None = None, + roll_parameter: str | None = None, dtype=torch.float32, cuda=False, - tensor_constructor: Optional[Callable] = None + tensor_constructor: Callable | None = None ) -> None: """ Args: @@ -353,10 +353,10 @@ def nullable_greater(left, right): class Fuzzer: def __init__( self, - parameters: list[Union[FuzzedParameter, list[FuzzedParameter]]], - tensors: list[Union[FuzzedTensor, list[FuzzedTensor]]], - constraints: Optional[list[Callable]] = None, - seed: Optional[int] = None + parameters: list[FuzzedParameter | list[FuzzedParameter]], + tensors: list[FuzzedTensor | list[FuzzedTensor]], + constraints: list[Callable] | None = None, + seed: int | None = None ) -> None: """ Args: @@ -422,9 +422,9 @@ def rejection_rate(self): return self._rejections / self._total_generated def _generate(self, state): - strict_params: dict[str, Union[float, int, ParameterAlias]] = {} + strict_params: dict[str, float | int | ParameterAlias] = {} for _ in range(1000): - candidate_params: dict[str, Union[float, int, ParameterAlias]] = {} + candidate_params: dict[str, float | int | ParameterAlias] = {} for p in self._parameters: if p.strict: if p.name in strict_params: diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index 49afb5ea9ad06..a2a573b9b44fd 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -from typing import Optional, Union from numbers import Number import torch from torch.utils.benchmark import FuzzedTensor @@ -9,14 +8,14 @@ class FuzzedSparseTensor(FuzzedTensor): def __init__( self, name: str, - size: tuple[Union[str, int], ...], - min_elements: Optional[int] = None, - max_elements: Optional[int] = None, - dim_parameter: Optional[str] = None, - sparse_dim: Optional[str] = None, - nnz: Optional[str] = None, - density: Optional[str] = None, - coalesced: Optional[str] = None, + size: tuple[str | int, ...], + min_elements: int | None = None, + max_elements: int | None = None, + dim_parameter: str | None = None, + sparse_dim: str | None = None, + nnz: str | None = None, + density: str | None = None, + coalesced: str | None = None, dtype=torch.float32, cuda=False ) -> None: diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index 09dbb4b5a0863..f131261b8f36d 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -2,7 +2,7 @@ import enum import timeit import textwrap -from typing import overload, Any, NoReturn, Optional, Union +from typing import overload, Any, NoReturn from collections.abc import Callable import torch @@ -52,7 +52,7 @@ def __init__( self._stmt: str = textwrap.dedent(stmt) self._setup: str = textwrap.dedent(setup) self._global_setup: str = textwrap.dedent(global_setup) - self._timeit_module: Optional[TimeitModuleType] = None + self._timeit_module: TimeitModuleType | None = None def timeit(self, number: int) -> float: if self._timeit_module is None: @@ -181,13 +181,13 @@ def __init__( setup: str = "pass", global_setup: str = "", timer: Callable[[], float] = timer, - 
globals: Optional[dict[str, Any]] = None, - label: Optional[str] = None, - sub_label: Optional[str] = None, - description: Optional[str] = None, - env: Optional[str] = None, + globals: dict[str, Any] | None = None, + label: str | None = None, + sub_label: str | None = None, + description: str | None = None, + env: str | None = None, num_threads: int = 1, - language: Union[Language, str] = Language.PYTHON, + language: Language | str = Language.PYTHON, ) -> None: if not isinstance(stmt, str): raise ValueError("Currently only a `str` stmt is supported.") @@ -277,7 +277,7 @@ def timeit(self, number: int = 1000000) -> common.Measurement: def repeat(self, repeat: int = -1, number: int = -1) -> None: raise NotImplementedError("See `Timer.blocked_autorange.`") - def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None: + def autorange(self, callback: Callable[[int, float], NoReturn] | None = None) -> None: raise NotImplementedError("See `Timer.blocked_autorange.`") def _threaded_measurement_loop( @@ -286,8 +286,8 @@ def _threaded_measurement_loop( time_hook: Callable[[], float], stop_hook: Callable[[list[float]], bool], min_run_time: float, - max_run_time: Optional[float] = None, - callback: Optional[Callable[[int, float], NoReturn]] = None + max_run_time: float | None = None, + callback: Callable[[int, float], NoReturn] | None = None ) -> list[float]: total_time = 0.0 can_stop = False @@ -325,7 +325,7 @@ def _estimate_block_size(self, min_run_time: float) -> int: def blocked_autorange( self, - callback: Optional[Callable[[int, float], NoReturn]] = None, + callback: Callable[[int, float], NoReturn] | None = None, min_run_time: float = 0.2, ) -> common.Measurement: """Measure many replicates while keeping timer overhead to a minimum. @@ -389,7 +389,7 @@ def adaptive_autorange( *, min_run_time: float = 0.01, max_run_time: float = 10.0, - callback: Optional[Callable[[int, float], NoReturn]] = None, + callback: Callable[[int, float], NoReturn] | None = None, ) -> common.Measurement: """Similar to `blocked_autorange` but also checks for variablility in measurements and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached. @@ -472,7 +472,7 @@ def collect_callgrind( self, number: int = 100, *, - repeats: Optional[int] = None, + repeats: int | None = None, collect_baseline: bool = True, retain_out_file: bool = False, ) -> Any: diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index ef9c1936b3570..f38363f6dea89 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -12,7 +12,7 @@ import textwrap from typing import ( cast, Any, NamedTuple, - Optional, Union, TYPE_CHECKING) + Union, TYPE_CHECKING) from collections.abc import Callable from collections.abc import Iterator @@ -55,7 +55,7 @@ class FunctionCounts: # For normal use, torch._tensor_str.PRINT_OPTS.linewidth determines # the print settings. This is simply to allow hermetic unit tests. - _linewidth: Optional[int] = None + _linewidth: int | None = None def __iter__(self) -> Iterator[FunctionCount]: yield from self._data @@ -64,7 +64,7 @@ def __len__(self) -> int: return len(self._data) def __getitem__(self, item: Any) -> Union[FunctionCount, "FunctionCounts"]: - data: Union[FunctionCount, tuple[FunctionCount, ...]] = self._data[item] + data: FunctionCount | tuple[FunctionCount, ...] 
= self._data[item] return ( FunctionCounts(cast(tuple[FunctionCount, ...], data), self.inclusive, truncate_rows=False) if isinstance(data, tuple) else data @@ -105,7 +105,7 @@ def __sub__( ) -> "FunctionCounts": return self._merge(other, operator.neg) - def __mul__(self, other: Union[int, float]) -> "FunctionCounts": + def __mul__(self, other: int | float) -> "FunctionCounts": return self._from_dict({ fn: int(c * other) for c, fn in self._data }, self.inclusive) @@ -178,7 +178,7 @@ class CallgrindStats: baseline_exclusive_stats: FunctionCounts stmt_inclusive_stats: FunctionCounts stmt_exclusive_stats: FunctionCounts - stmt_callgrind_out: Optional[str] + stmt_callgrind_out: str | None def __repr__(self) -> str: base_stats = self.baseline_exclusive_stats @@ -311,11 +311,11 @@ class CopyIfCallgrind: See `GlobalsBridge` for why this matters. """ - def __init__(self, value: Any, *, setup: Optional[str] = None) -> None: + def __init__(self, value: Any, *, setup: str | None = None) -> None: for method, supported_types in _GLOBALS_ALLOWED_TYPES.items(): if any(isinstance(value, t) for t in supported_types): self._value: Any = value - self._setup: Optional[str] = setup + self._setup: str | None = setup self._serialization: Serialization = method break else: @@ -334,7 +334,7 @@ def value(self) -> Any: return self._value @property - def setup(self) -> Optional[str]: + def setup(self) -> str | None: return self._setup @property @@ -485,7 +485,7 @@ def construct(self) -> str: class _ValgrindWrapper: def __init__(self) -> None: - self._bindings_module: Optional[CallgrindModuleType] = None + self._bindings_module: CallgrindModuleType | None = None valgrind_symbols = ( "_valgrind_supported_platform", "_valgrind_toggle", @@ -511,7 +511,7 @@ def __init__(self) -> None: check=False, ).returncode - self._build_type: Optional[str] = None + self._build_type: str | None = None build_search = re.search("BUILD_TYPE=(.+),", torch.__config__.show()) # type: ignore[no-untyped-call] if build_search is not None: self._build_type = build_search.groups()[0].split(",")[0] @@ -576,7 +576,7 @@ def _invoke( collect_baseline: bool, is_python: bool, retain_out_file: bool, - ) -> tuple[tuple[FunctionCounts, FunctionCounts, Optional[str]], ...]: + ) -> tuple[tuple[FunctionCounts, FunctionCounts, str | None], ...]: """Core invocation method for Callgrind collection. Valgrind operates by effectively replacing the CPU with an emulated @@ -732,7 +732,7 @@ class ScanState(enum.Enum): raise AssertionError(f"Failed to parse {fpath}") return FunctionCounts(tuple(sorted(fn_counts, reverse=True)), inclusive=inclusive) - def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, Optional[str]]: + def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, str | None]: if i == repeats and not collect_baseline: # Null baseline. return ( @@ -742,7 +742,7 @@ def read_results(i: int) -> tuple[FunctionCounts, FunctionCounts, Optional[str]] ) fpath = f"{callgrind_out}.{i + 1}" # Callgrind one-indexes files. - callgrind_out_contents: Optional[str] = None + callgrind_out_contents: str | None = None if retain_out_file: with open(fpath) as f: callgrind_out_contents = f.read() @@ -767,7 +767,7 @@ def _construct_script( collect_baseline: bool, error_log: str, stat_log: str, - bindings: Optional[CallgrindModuleType], + bindings: CallgrindModuleType | None, ) -> str: def block_stmt(stmt: str, indent: int = 0) -> str: """Partially unroll benchmark loop. 
@@ -914,7 +914,7 @@ def check_result(completed_process): ) -CALLGRIND_SINGLETON: Optional[_ValgrindWrapper] = None +CALLGRIND_SINGLETON: _ValgrindWrapper | None = None def wrapper_singleton() -> _ValgrindWrapper: global CALLGRIND_SINGLETON if CALLGRIND_SINGLETON is None: diff --git a/torch/utils/bundled_inputs.py b/torch/utils/bundled_inputs.py index ccb56172a077b..e91129a03864b 100644 --- a/torch/utils/bundled_inputs.py +++ b/torch/utils/bundled_inputs.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # mypy: allow-untyped-defs -from typing import Any, TypeVar, Optional, NamedTuple, Union +from typing import Any, TypeVar, NamedTuple from collections.abc import Callable, Sequence import textwrap import torch @@ -40,10 +40,10 @@ class InflatableArg(NamedTuple): def bundle_inputs( model: torch.jit.ScriptModule, - inputs: Union[Optional[Sequence[tuple[Any, ...]]], dict[Callable, Optional[Sequence[tuple[Any, ...]]]]], - info: Optional[Union[list[str], dict[Callable, list[str]]]] = None, + inputs: Sequence[tuple[Any, ...]] | None | dict[Callable, Sequence[tuple[Any, ...]] | None], + info: list[str] | dict[Callable, list[str]] | None = None, *, - _receive_inflate_expr: Optional[list[str]] = None, + _receive_inflate_expr: list[str] | None = None, ) -> torch.jit.ScriptModule: """Create and return a copy of the specified model with inputs attached. @@ -130,9 +130,9 @@ def bundle_inputs( def augment_model_with_bundled_inputs( model: torch.jit.ScriptModule, - inputs: Optional[Sequence[tuple[Any, ...]]] = None, - _receive_inflate_expr: Optional[list[str]] = None, # For debugging. - info: Optional[list[str]] = None, # Optional argument to provide info about forward or its inputs + inputs: Sequence[tuple[Any, ...]] | None = None, + _receive_inflate_expr: list[str] | None = None, # For debugging. + info: list[str] | None = None, # Optional argument to provide info about forward or its inputs skip_size_check=False, ) -> None: """Add bundled sample inputs to a model for the forward function. @@ -184,9 +184,9 @@ def augment_model_with_bundled_inputs( def augment_many_model_functions_with_bundled_inputs( model: torch.jit.ScriptModule, - inputs: dict[Callable, Optional[Sequence[tuple[Any, ...]]]], - _receive_inflate_expr: Optional[list[str]] = None, # For debugging. - info: Optional[dict[Callable, list[str]]] = None, # Optional argument to provide info about the function or its inputs + inputs: dict[Callable, Sequence[tuple[Any, ...]] | None], + _receive_inflate_expr: list[str] | None = None, # For debugging. + info: dict[Callable, list[str]] | None = None, # Optional argument to provide info about the function or its inputs skip_size_check=False, ) -> None: """Add bundled sample inputs to a model for an arbitrary list of public functions. @@ -366,7 +366,7 @@ def get_bundled_inputs_functions_and_info(self): def _inflate_expr( arg: T, ref: str, inflate_helper_fn_name: str, skip_size_check: bool = False -) -> tuple[Union[T, torch.Tensor], str, Optional[str]]: +) -> tuple[T | torch.Tensor, str, str | None]: # Allow custom inflation expressions any object. # For example, calling custom image-decoding ops. # Or just use "{}" as the format string to ignore size limits. 
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index fc16c38b8e3e4..35e1848c695e1 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -23,7 +23,6 @@ import torch._appdirs from .file_baton import FileBaton from ._cpp_extension_versioner import ExtensionVersioner -from typing import Optional, Union from typing_extensions import deprecated from torch.torch_version import TorchVersion, Version @@ -82,7 +81,7 @@ "verify_ninja_availability", "remove_extension_h_precompiler_headers", "get_cxx_compiler", "check_compiler_is_gcc"] # Taken directly from python stdlib < 3.9 # See https://github.com/pytorch/pytorch/issues/48617 -def _nt_quote_args(args: Optional[list[str]]) -> list[str]: +def _nt_quote_args(args: list[str] | None) -> list[str]: """Quote command-line arguments for DOS/Windows conventions. Just wraps every argument which contains blanks in double quotes, and @@ -93,7 +92,7 @@ def _nt_quote_args(args: Optional[list[str]]) -> list[str]: return [] return [f'"{arg}"' if ' ' in arg else arg for arg in args] -def _find_cuda_home() -> Optional[str]: +def _find_cuda_home() -> str | None: """Find the CUDA install path.""" # Guess #1 cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') @@ -119,7 +118,7 @@ def _find_cuda_home() -> Optional[str]: logger.warning("No CUDA runtime is found, using CUDA_HOME='%s'", cuda_home) return cuda_home -def _find_rocm_home() -> Optional[str]: +def _find_rocm_home() -> str | None: """Find the ROCm install path.""" # Guess #1 rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') @@ -141,7 +140,7 @@ def _find_rocm_home() -> Optional[str]: logger.warning("No ROCm runtime is found, using ROCM_HOME='%s'", rocm_home) return rocm_home -def _find_sycl_home() -> Optional[str]: +def _find_sycl_home() -> str | None: sycl_home = None icpx_path = shutil.which('icpx') # Guess 1: for source code build developer/user, we'll have icpx in PATH, @@ -1547,7 +1546,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str return paths -def library_paths(device_type: str = "cpu", torch_include_dirs: bool = True, cross_target_platform: Optional[str] = None) -> list[str]: +def library_paths(device_type: str = "cpu", torch_include_dirs: bool = True, cross_target_platform: str | None = None) -> list[str]: """ Get the library paths required to build a C++ or CUDA extension. 
@@ -1605,7 +1604,7 @@ def library_paths(device_type: str = "cpu", torch_include_dirs: bool = True, cro def load(name, - sources: Union[str, list[str]], + sources: str | list[str], extra_cflags=None, extra_cuda_cflags=None, extra_sycl_cflags=None, @@ -1613,8 +1612,8 @@ def load(name, extra_include_paths=None, build_directory=None, verbose=False, - with_cuda: Optional[bool] = None, - with_sycl: Optional[bool] = None, + with_cuda: bool | None = None, + with_sycl: bool | None = None, is_python_module=True, is_standalone=False, keep_intermediates=True): @@ -2098,11 +2097,11 @@ def _jit_compile(name, extra_include_paths, build_directory: str, verbose: bool, - with_cuda: Optional[bool], - with_sycl: Optional[bool], + with_cuda: bool | None, + with_sycl: bool | None, is_python_module, is_standalone, - keep_intermediates=True) -> Union[types.ModuleType, str]: + keep_intermediates=True) -> types.ModuleType | str: if is_python_module and is_standalone: raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.") @@ -2204,8 +2203,8 @@ def _write_ninja_file_and_compile_objects( sycl_dlink_post_cflags, build_directory: str, verbose: bool, - with_cuda: Optional[bool], - with_sycl: Optional[bool]) -> None: + with_cuda: bool | None, + with_sycl: bool | None) -> None: verify_ninja_availability() compiler = get_cxx_compiler() @@ -2262,8 +2261,8 @@ def _write_ninja_file_and_build_library( extra_include_paths, build_directory: str, verbose: bool, - with_cuda: Optional[bool], - with_sycl: Optional[bool], + with_cuda: bool | None, + with_sycl: bool | None, is_standalone: bool = False) -> None: verify_ninja_availability() @@ -2392,7 +2391,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): return extra_ldflags -def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: +def _get_cuda_arch_flags(cflags: list[str] | None = None) -> list[str]: """ Determine CUDA arch flags to use. @@ -2498,7 +2497,7 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: return sorted(set(flags)) -def _get_rocm_arch_flags(cflags: Optional[list[str]] = None) -> list[str]: +def _get_rocm_arch_flags(cflags: list[str] | None = None) -> list[str]: # If cflags is given, there may already be user-provided arch flags in it # (from `extra_compile_args`). If user also specified -fgpu-rdc or -fno-gpu-rdc, we # assume they know what they're doing. Otherwise, we force -fno-gpu-rdc default. @@ -2562,7 +2561,7 @@ def _get_build_directory(name: str, verbose: bool) -> str: return build_directory -def _get_num_workers(verbose: bool) -> Optional[int]: +def _get_num_workers(verbose: bool) -> int | None: max_jobs = os.environ.get('MAX_JOBS') if max_jobs is not None and max_jobs.isdigit(): if verbose: diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index cb051f6642dcf..733e84a9afae6 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -13,7 +13,6 @@ import copy import re from collections.abc import Callable -from typing import Optional, Union import torch @@ -119,7 +118,7 @@ def default_convert(data): def collate( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): r""" General collate function that handles collection type of element within each batch. 
@@ -247,7 +246,7 @@ def collate( def collate_tensor_fn( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): elem = batch[0] out = None @@ -279,7 +278,7 @@ def collate_tensor_fn( def collate_numpy_array_fn( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): elem = batch[0] # array of string classes and object @@ -292,7 +291,7 @@ def collate_numpy_array_fn( def collate_numpy_scalar_fn( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): return torch.as_tensor(batch) @@ -300,7 +299,7 @@ def collate_numpy_scalar_fn( def collate_float_fn( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): return torch.tensor(batch, dtype=torch.float64) @@ -308,7 +307,7 @@ def collate_float_fn( def collate_int_fn( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): return torch.tensor(batch) @@ -316,12 +315,12 @@ def collate_int_fn( def collate_str_fn( batch, *, - collate_fn_map: Optional[dict[Union[type, tuple[type, ...]], Callable]] = None, + collate_fn_map: dict[type | tuple[type, ...], Callable] | None = None, ): return batch -default_collate_fn_map: dict[Union[type, tuple[type, ...]], Callable] = { +default_collate_fn_map: dict[type | tuple[type, ...], Callable] = { torch.Tensor: collate_tensor_fn } with contextlib.suppress(ImportError): diff --git a/torch/utils/data/_utils/worker.py b/torch/utils/data/_utils/worker.py index c2d9294db86d9..611aee4766bf4 100644 --- a/torch/utils/data/_utils/worker.py +++ b/torch/utils/data/_utils/worker.py @@ -9,7 +9,7 @@ import queue import random from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING import torch from torch._utils import ExceptionWrapper @@ -98,7 +98,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({', '.join(items)})" -def get_worker_info() -> Optional[WorkerInfo]: +def get_worker_info() -> WorkerInfo | None: r"""Returns the information about the current :class:`~torch.utils.data.DataLoader` iterator worker process. @@ -140,7 +140,7 @@ class _IterableDatasetStopIteration: @dataclass(frozen=True) class _ResumeIteration: - seed: Optional[int] = None + seed: int | None = None # The function `_generate_state` is adapted from `numpy.random.SeedSequence` @@ -349,7 +349,7 @@ def _worker_loop( # processing steps. 
continue idx, index = r - data: Union[_IterableDatasetStopIteration, ExceptionWrapper] + data: _IterableDatasetStopIteration | ExceptionWrapper if init_exception is not None: data = init_exception init_exception = None diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 467e8c655d2bc..35e70d686a34f 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -17,7 +17,7 @@ import threading import warnings from collections.abc import Callable -from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Generic, NoReturn, TYPE_CHECKING, TypeVar from typing_extensions import Self import torch @@ -233,34 +233,34 @@ class DataLoader(Generic[_T_co]): """ dataset: Dataset[_T_co] - batch_size: Optional[int] + batch_size: int | None num_workers: int pin_memory: bool drop_last: bool timeout: float - sampler: Union[Sampler, Iterable] + sampler: Sampler | Iterable pin_memory_device: str - prefetch_factor: Optional[int] - _iterator: Optional[_BaseDataLoaderIter] + prefetch_factor: int | None + _iterator: _BaseDataLoaderIter | None __initialized = False def __init__( self, dataset: Dataset[_T_co], - batch_size: Optional[int] = 1, - shuffle: Optional[bool] = None, - sampler: Union[Sampler, Iterable, None] = None, - batch_sampler: Union[Sampler[list], Iterable[list], None] = None, + batch_size: int | None = 1, + shuffle: bool | None = None, + sampler: Sampler | Iterable | None = None, + batch_sampler: Sampler[list] | Iterable[list] | None = None, num_workers: int = 0, - collate_fn: Optional[_collate_fn_t] = None, + collate_fn: _collate_fn_t | None = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, - worker_init_fn: Optional[_worker_init_fn_t] = None, + worker_init_fn: _worker_init_fn_t | None = None, multiprocessing_context=None, generator=None, *, - prefetch_factor: Optional[int] = None, + prefetch_factor: int | None = None, persistent_workers: bool = False, pin_memory_device: str = "", in_order: bool = True, diff --git a/torch/utils/data/datapipes/_decorator.py b/torch/utils/data/datapipes/_decorator.py index 507e00259c4c7..0289668c03abc 100644 --- a/torch/utils/data/datapipes/_decorator.py +++ b/torch/utils/data/datapipes/_decorator.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Callable from functools import wraps -from typing import Any, get_type_hints, Optional, Union +from typing import Any, get_type_hints from torch.utils.data.datapipes._typing import _DataPipeMeta from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe @@ -73,11 +73,11 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: class non_deterministic: - cls: Optional[type[IterDataPipe]] = None + cls: type[IterDataPipe] | None = None # TODO: Lambda for picking deterministic_fn: Callable[..., bool] - def __init__(self, arg: Union[type[IterDataPipe], Callable[..., bool]]) -> None: + def __init__(self, arg: type[IterDataPipe] | Callable[..., bool]) -> None: # 1. 
Decorator doesn't have any argument if isinstance(arg, type): # type: ignore[arg-type] if not issubclass(arg, IterDataPipe): # type: ignore[arg-type] diff --git a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py index 410683bcfbd70..9cfc5c268a174 100644 --- a/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py +++ b/torch/utils/data/datapipes/dataframe/dataframe_wrapper.py @@ -1,9 +1,9 @@ # mypy: allow-untyped-defs -from typing import Any, Optional +from typing import Any _pandas: Any = None -_WITH_PANDAS: Optional[bool] = None +_WITH_PANDAS: bool | None = None def _try_import_pandas() -> bool: diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index e8b03ff3b2afa..463f7384aa6c4 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, NoReturn, Optional +from typing import Any, NoReturn from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.dataframe.structures import DataChunkDF @@ -454,7 +454,7 @@ def __getattr__(self, attrname): # ? @functional_datapipe("trace_as_dataframe") class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe): # type: ignore[misc] - source_datapipe: Optional[Any] = None + source_datapipe: Any | None = None # TODO(VitalyFedyunin): Must implement all special functions of datapipes diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index 4b3913bc82369..51c1689008530 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -1,7 +1,7 @@ import functools import pickle from collections.abc import Callable, Iterable, Iterator -from typing import Optional, TypeVar +from typing import TypeVar from torch.utils._import_utils import import_dill from torch.utils.data.datapipes._hook_iterator import _SnapshotState @@ -125,14 +125,14 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): """ functions: dict[str, Callable] = {} - reduce_ex_hook: Optional[Callable] = None - getstate_hook: Optional[Callable] = None - str_hook: Optional[Callable] = None - repr_hook: Optional[Callable] = None - _valid_iterator_id: Optional[int] = None + reduce_ex_hook: Callable | None = None + getstate_hook: Callable | None = None + str_hook: Callable | None = None + repr_hook: Callable | None = None + _valid_iterator_id: int | None = None _number_of_samples_yielded: int = 0 _snapshot_state: _SnapshotState = _SnapshotState.NotStarted - _fast_forward_iterator: Optional[Iterator] = None + _fast_forward_iterator: Iterator | None = None def __iter__(self) -> Iterator[_T_co]: # pyrefly: ignore [bad-return] @@ -281,10 +281,10 @@ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): """ functions: dict[str, Callable] = {} - reduce_ex_hook: Optional[Callable] = None - getstate_hook: Optional[Callable] = None - str_hook: Optional[Callable] = None - repr_hook: Optional[Callable] = None + reduce_ex_hook: Callable | None = None + getstate_hook: Callable | None = None + str_hook: Callable | None = None + repr_hook: Callable | None = None def __getattr__(self, attribute_name): if attribute_name in MapDataPipe.functions: @@ -408,7 +408,7 @@ class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataP def __init__(self, datapipe: IterDataPipe[_T_co]) -> None: 
super().__init__(datapipe) # pyrefly: ignore [invalid-type-var] - self._datapipe_iter: Optional[Iterator[_T_co]] = None + self._datapipe_iter: Iterator[_T_co] | None = None def __iter__(self) -> "_IterDataPipeSerializationWrapper": self._datapipe_iter = iter(self._datapipe) diff --git a/torch/utils/data/datapipes/datapipe.pyi.in b/torch/utils/data/datapipes/datapipe.pyi.in index a7b7bac21f50d..084f253b5ddbe 100644 --- a/torch/utils/data/datapipes/datapipe.pyi.in +++ b/torch/utils/data/datapipes/datapipe.pyi.in @@ -6,7 +6,7 @@ # classes/objects here, even though we are not injecting extra code into them at the moment. from collections.abc import Callable, Iterable, Iterator -from typing import Any, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeVar from torch.utils.data import Dataset, default_collate, IterableDataset from torch.utils.data.datapipes._hook_iterator import _SnapshotState @@ -48,13 +48,13 @@ class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta): class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): functions: dict[str, Callable] = ... - reduce_ex_hook: Optional[Callable] = ... - getstate_hook: Optional[Callable] = ... - str_hook: Optional[Callable] = ... - repr_hook: Optional[Callable] = ... + reduce_ex_hook: Callable | None = ... + getstate_hook: Callable | None = ... + str_hook: Callable | None = ... + repr_hook: Callable | None = ... _number_of_samples_yielded: int = ... _snapshot_state: _SnapshotState = _SnapshotState.Iterating # noqa: PYI015 - _fast_forward_iterator: Optional[Iterator] = ... + _fast_forward_iterator: Iterator | None = ... def __getattr__(self, attribute_name: Any): ... @classmethod def register_function(cls, function_name: Any, function: Any) -> None: ... diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py index 23fd20f602567..90f9d80a2e7fe 100644 --- a/torch/utils/data/datapipes/gen_pyi.py +++ b/torch/utils/data/datapipes/gen_pyi.py @@ -2,7 +2,7 @@ import os from collections import defaultdict from pathlib import Path -from typing import Any, Union +from typing import Any from typing_extensions import deprecated @@ -225,7 +225,7 @@ def process_signature(line: str) -> list[str]: def get_method_definitions( - file_path: Union[str, list[str]], + file_path: str | list[str], files_to_exclude: set[str], deprecated_files: set[str], default_output_type: str, diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 2e3bb18c80bb7..af1d9792c097b 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -2,7 +2,7 @@ import functools from collections import namedtuple from collections.abc import Callable, Iterator, Sized -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar import torch from torch.utils.data._utils.collate import default_collate @@ -226,10 +226,10 @@ class CollatorIterDataPipe(MapperIterDataPipe): def __init__( self, datapipe: IterDataPipe, - conversion: Union[ - Callable[..., Any], dict[Union[str, Any], Union[Callable, Any]], None - ] = default_collate, - collate_fn: Optional[Callable] = None, + conversion: Callable[..., Any] + | dict[str | Any, Callable | Any] + | None = default_collate, + collate_fn: Callable | None = None, ) -> None: # TODO(VitalyFedyunin): Replace `Callable[..., Any]` with `Callable[[IColumn], Any]` # TODO(VitalyFedyunin): Replace with `Dict[Union[str, IColumn], Union[Callable, Enum]]` diff --git 
a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index 6b4f134ef917d..79a774c5e63db 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import random from collections.abc import Iterator, Sized -from typing import Optional, TypeVar +from typing import TypeVar import torch from torch.utils.data.datapipes._decorator import functional_datapipe @@ -35,8 +35,8 @@ def __init__( self, datapipe: IterDataPipe, sampler: type[Sampler] = SequentialSampler, - sampler_args: Optional[tuple] = None, - sampler_kwargs: Optional[dict] = None, + sampler_args: tuple | None = None, + sampler_kwargs: dict | None = None, ) -> None: if not isinstance(datapipe, Sized): raise AssertionError( @@ -99,7 +99,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]): buffer_size: int _buffer: list[_T_co] _enabled: bool - _seed: Optional[int] + _seed: int | None _rng: random.Random def __init__( diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 4682a483170f5..4915e4c3d7c52 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import deque from collections.abc import Callable, Iterator, Sized -from typing import Any, Literal, Optional, TypeVar +from typing import Any, Literal, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes._hook_iterator import _SnapshotState @@ -101,7 +101,7 @@ def __new__( datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000, - copy: Optional[Literal["shallow", "deep"]] = None, + copy: Literal["shallow", "deep"] | None = None, ): if num_instances < 1: raise ValueError( @@ -147,10 +147,10 @@ def __init__( datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000, - copy: Optional[Literal["shallow", "deep"]] = None, + copy: Literal["shallow", "deep"] | None = None, ) -> None: self.main_datapipe = datapipe - self._datapipe_iterator: Optional[Iterator[Any]] = None + self._datapipe_iterator: Iterator[Any] | None = None self.num_instances = num_instances self.buffer: deque = deque() self.buffer_size = buffer_size @@ -177,7 +177,7 @@ def __init__( ] * num_instances # Indicate the indices of the next element to get self.slowest_ptr = 0 # The index to read by the slowest child self.leading_ptr = 0 # The index to read by the fastest child - self.end_ptr: Optional[int] = None # The index to stop child + self.end_ptr: int | None = None # The index to stop child self._child_stop: list[bool] = [True for _ in range(num_instances)] def __len__(self) -> int: @@ -420,7 +420,7 @@ def __new__( cls, datapipe: IterDataPipe, num_instances: int, - classifier_fn: Callable[[_T_co], Optional[int]], + classifier_fn: Callable[[_T_co], int | None], drop_none: bool = False, buffer_size: int = 1000, ): @@ -452,13 +452,13 @@ def __init__( self, datapipe: IterDataPipe[_T_co], num_instances: int, - classifier_fn: Callable[[_T_co], Optional[int]], + classifier_fn: Callable[[_T_co], int | None], drop_none: bool, buffer_size: int, ) -> None: # pyrefly: ignore [invalid-type-var] self.main_datapipe = datapipe - self._datapipe_iterator: Optional[Iterator[Any]] = None + self._datapipe_iterator: Iterator[Any] | None = None self.num_instances = num_instances self.buffer_size = buffer_size if self.buffer_size < 0: @@ -582,7 
+582,7 @@ def __setstate__(self, state): self._child_stop = [True for _ in range(self.num_instances)] self.main_datapipe_exhausted = False - def _cleanup(self, instance_id: Optional[int] = None) -> None: + def _cleanup(self, instance_id: int | None = None) -> None: ids = ( range(self.num_instances) if instance_id is None diff --git a/torch/utils/data/datapipes/iter/filelister.py b/torch/utils/data/datapipes/iter/filelister.py index 2b3d16bed2a66..352d3c01e12d2 100644 --- a/torch/utils/data/datapipes/iter/filelister.py +++ b/torch/utils/data/datapipes/iter/filelister.py @@ -1,5 +1,4 @@ from collections.abc import Iterator, Sequence -from typing import Union from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -36,8 +35,8 @@ class FileListerIterDataPipe(IterDataPipe[str]): def __init__( self, - root: Union[str, Sequence[str], IterDataPipe] = ".", - masks: Union[str, list[str]] = "", + root: str | Sequence[str] | IterDataPipe = ".", + masks: str | list[str] = "", *, recursive: bool = False, abspath: bool = False, @@ -50,7 +49,7 @@ def __init__( if not isinstance(root, IterDataPipe): root = IterableWrapperIterDataPipe(root) self.datapipe: IterDataPipe = root - self.masks: Union[str, list[str]] = masks + self.masks: str | list[str] = masks self.recursive: bool = recursive self.abspath: bool = abspath self.non_deterministic: bool = non_deterministic diff --git a/torch/utils/data/datapipes/iter/fileopener.py b/torch/utils/data/datapipes/iter/fileopener.py index 1d8efef4849bf..e77f7a4c8e660 100644 --- a/torch/utils/data/datapipes/iter/fileopener.py +++ b/torch/utils/data/datapipes/iter/fileopener.py @@ -1,6 +1,5 @@ from collections.abc import Iterable, Iterator from io import IOBase -from typing import Optional from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -48,13 +47,13 @@ def __init__( self, datapipe: Iterable[str], mode: str = "r", - encoding: Optional[str] = None, + encoding: str | None = None, length: int = -1, ) -> None: super().__init__() self.datapipe: Iterable[str] = datapipe self.mode: str = mode - self.encoding: Optional[str] = encoding + self.encoding: str | None = encoding if self.mode not in ("b", "t", "rb", "rt", "r"): raise ValueError(f"Invalid mode {mode}") diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 16ae0965f3cff..b773f06823a76 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from collections import defaultdict from collections.abc import Callable, Iterator, Sized -from typing import Any, NoReturn, Optional, TypeVar +from typing import Any, NoReturn, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe @@ -199,8 +199,8 @@ def __init__( *, keep_key: bool = False, buffer_size: int = 10000, - group_size: Optional[int] = None, - guaranteed_group_size: Optional[int] = None, + group_size: int | None = None, + guaranteed_group_size: int | None = None, drop_remaining: bool = False, ) -> None: _check_unpickable_fn(group_key_fn) diff --git a/torch/utils/data/datapipes/iter/streamreader.py b/torch/utils/data/datapipes/iter/streamreader.py index ece25b3467cdb..1129c06548e1f 100644 --- a/torch/utils/data/datapipes/iter/streamreader.py +++ 
b/torch/utils/data/datapipes/iter/streamreader.py @@ -1,6 +1,5 @@ from collections.abc import Iterator from io import IOBase -from typing import Optional from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe @@ -31,7 +30,7 @@ class StreamReaderIterDataPipe(IterDataPipe[tuple[str, bytes]]): """ def __init__( - self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: Optional[int] = None + self, datapipe: IterDataPipe[tuple[str, IOBase]], chunk: int | None = None ) -> None: self.datapipe = datapipe self.chunk = chunk diff --git a/torch/utils/data/datapipes/map/combinatorics.py b/torch/utils/data/datapipes/map/combinatorics.py index 4876ce3fd1cbc..af4792fc805b8 100644 --- a/torch/utils/data/datapipes/map/combinatorics.py +++ b/torch/utils/data/datapipes/map/combinatorics.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import random from collections.abc import Iterator -from typing import Optional, TypeVar +from typing import TypeVar import torch from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe @@ -53,14 +53,14 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]): datapipe: MapDataPipe[_T_co] _enabled: bool - _seed: Optional[int] + _seed: int | None _rng: random.Random def __init__( self, datapipe: MapDataPipe[_T_co], *, - indices: Optional[list] = None, + indices: list | None = None, ) -> None: super().__init__() self.datapipe = datapipe diff --git a/torch/utils/data/datapipes/map/utils.py b/torch/utils/data/datapipes/map/utils.py index 360f66b3137c7..a5b9075f1dbbc 100644 --- a/torch/utils/data/datapipes/map/utils.py +++ b/torch/utils/data/datapipes/map/utils.py @@ -1,7 +1,7 @@ import copy import warnings from collections.abc import Mapping, Sequence -from typing import Any, TypeVar, Union +from typing import Any, TypeVar from torch.utils.data.datapipes.datapipe import MapDataPipe @@ -36,10 +36,10 @@ class SequenceWrapperMapDataPipe(MapDataPipe[_T]): 100 """ - sequence: Union[Sequence[_T], Mapping[Any, _T]] + sequence: Sequence[_T] | Mapping[Any, _T] def __init__( - self, sequence: Union[Sequence[_T], Mapping[Any, _T]], deepcopy: bool = True + self, sequence: Sequence[_T] | Mapping[Any, _T], deepcopy: bool = True ) -> None: if deepcopy: try: diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 7f27c2f37fc93..6032de7166af7 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable, Iterable from io import IOBase -from typing import Any, NoReturn, Optional, Union +from typing import Any, NoReturn from torch.utils._import_utils import dill_available @@ -25,9 +25,7 @@ DILL_AVAILABLE = dill_available() -def validate_input_col( - fn: Callable, input_col: Optional[Union[int, tuple, list]] -) -> None: +def validate_input_col(fn: Callable, input_col: int | tuple | list | None) -> None: """ Check that function used in a callable datapipe works with the input column. 
@@ -166,7 +164,7 @@ def _check_unpickable_fn(fn: Callable) -> None: return -def match_masks(name: str, masks: Union[str, list[str]]) -> bool: +def match_masks(name: str, masks: str | list[str]) -> bool: # empty mask matches any input name if not masks: return True @@ -182,7 +180,7 @@ def match_masks(name: str, masks: Union[str, list[str]]) -> bool: def get_file_pathnames_from_root( root: str, - masks: Union[str, list[str]], + masks: str | list[str], recursive: bool = False, abspath: bool = False, non_deterministic: bool = False, @@ -219,7 +217,7 @@ def onerror(err: OSError) -> NoReturn: def get_file_binaries_from_pathnames( - pathnames: Iterable, mode: str, encoding: Optional[str] = None + pathnames: Iterable, mode: str, encoding: str | None = None ): if not isinstance(pathnames, Iterable): pathnames = [ diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index c800dd6a05826..19ec449f040dd 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -10,7 +10,7 @@ # targets fail to typecheck with: # TypeError: Cannot create a consistent method resolution order (MRO) for # bases Iterable, Generic -from typing import cast, Generic, Iterable, Optional, TypeVar, Union # noqa: UP035 +from typing import cast, Generic, Iterable, TypeVar # noqa: UP035 from typing_extensions import deprecated # No 'default_generator' in torch/__init__.pyi @@ -228,7 +228,7 @@ class StackDataset(Dataset[_T_stack]): **kwargs (Dataset): Datasets for stacking returned as dict. """ - datasets: Union[tuple, dict] + datasets: tuple | dict def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None: if args: @@ -418,8 +418,8 @@ def __len__(self) -> int: def random_split( dataset: Dataset[_T], - lengths: Sequence[Union[int, float]], - generator: Optional[Generator] = default_generator, + lengths: Sequence[int | float], + generator: Generator | None = default_generator, ) -> list[Subset[_T]]: r""" Randomly split a dataset into non-overlapping new datasets of given lengths. diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index b2f4eb04e8e24..5179d7698ffee 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -1,6 +1,6 @@ import math from collections.abc import Iterator -from typing import Optional, TypeVar +from typing import TypeVar import torch import torch.distributed as dist @@ -66,8 +66,8 @@ class DistributedSampler(Sampler[_T_co]): def __init__( self, dataset: Dataset, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, + num_replicas: int | None = None, + rank: int | None = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False, diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index 052db781d6a8d..d1e7e679ad5d5 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -3,7 +3,7 @@ import pickle import warnings from collections.abc import Collection -from typing import Optional, Union +from typing import Union from torch.utils._import_utils import dill_available from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe @@ -106,7 +106,7 @@ def traverse_dps(datapipe: DataPipe) -> DataPipeGraph: return _traverse_helper(datapipe, only_datapipe=True, cache=cache) -def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph: +def traverse(datapipe: DataPipe, only_datapipe: bool | None = None) -> DataPipeGraph: r""" Traverse the DataPipes and their attributes to extract the DataPipe graph. 
diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py index 9030150116800..03096398a6738 100644 --- a/torch/utils/data/graph_settings.py +++ b/torch/utils/data/graph_settings.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import inspect import warnings -from typing import Any, Optional +from typing import Any from typing_extensions import deprecated import torch @@ -94,9 +94,7 @@ def _is_shuffle_datapipe(datapipe: DataPipe) -> bool: ) -def apply_shuffle_settings( - datapipe: DataPipe, shuffle: Optional[bool] = None -) -> DataPipe: +def apply_shuffle_settings(datapipe: DataPipe, shuffle: bool | None = None) -> DataPipe: r""" Traverse the graph of ``DataPipes`` to find and set shuffle attribute. diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index f36f15ee09589..aa13bb8e0a3e1 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import itertools from collections.abc import Iterable, Iterator, Sequence, Sized -from typing import Generic, Optional, TypeVar, Union +from typing import Generic, TypeVar import torch @@ -132,7 +132,7 @@ def __init__( self, data_source: Sized, replacement: bool = False, - num_samples: Optional[int] = None, + num_samples: int | None = None, generator=None, ) -> None: self.data_source = data_source @@ -307,7 +307,7 @@ class BatchSampler(Sampler[list[int]]): def __init__( self, - sampler: Union[Sampler[int], Iterable[int]], + sampler: Sampler[int] | Iterable[int], batch_size: int, drop_last: bool, ) -> None: diff --git a/torch/utils/data/typing.ipynb b/torch/utils/data/typing.ipynb index 0f546a2b3c3b5..b25d82d421d9b 100644 --- a/torch/utils/data/typing.ipynb +++ b/torch/utils/data/typing.ipynb @@ -208,7 +208,7 @@ "\n", "T = TypeVar('T', int, str) # equals to Union[int, str]\n", "class DP(IterDataPipe[tuple[T, str]]):\n", - " def __iter__(self) -> Iterator[tuple[Union[int, str], str]]:\n", + " def __iter__(self) -> Iterator[tuple[int | str, str]]:\n", " pass\n", "print(DP.type)" ] @@ -313,7 +313,7 @@ "\n", "class DP(IterDataPipe):\n", " @argument_validation\n", - " def __init__(self, dp: IterDataPipe[Union[int, tuple]]) -> None:\n", + " def __init__(self, dp: IterDataPipe[int | tuple]) -> None:\n", " self.dp = dp\n", "\n", " def __iter__(self):\n", diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py index f63cc89cc26ea..223cca54dafed 100644 --- a/torch/utils/dlpack.py +++ b/torch/utils/dlpack.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import torch import enum @@ -58,8 +58,8 @@ class DLDeviceType(enum.IntEnum): def from_dlpack( ext_tensor: Any, *, - device: Optional[_Device] = None, - copy: Optional[bool] = None + device: _Device | None = None, + copy: bool | None = None ) -> 'torch.Tensor': """from_dlpack(ext_tensor) -> Tensor diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 41e5bc056e258..7d08a14158300 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -2,7 +2,7 @@ import torch from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from .module_tracker import ModuleTracker -from typing import Any, Optional, Union, TypeVar +from typing import Any, TypeVar from collections.abc import Callable from collections.abc import Iterator from typing_extensions import ParamSpec @@ -314,7 +314,7 @@ def _unpack_flash_attention_nested_shapes( cum_seq_k, max_q, max_k, -) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, 
...]]]]: +) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]: """ Given inputs to a flash_attention_(forward|backward) kernel, this will handle behavior for NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for @@ -366,7 +366,7 @@ def _unpack_efficient_attention_nested_shapes( cu_seqlens_k, max_seqlen_q, max_seqlen_k, -) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], Optional[tuple[int, ...]]]]: +) -> Iterator[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...] | None]]: """ Given inputs to a efficient_attention_(forward|backward) kernel, this will handle behavior for NestedTensor inputs by effectively unbinding the NestedTensor and yielding the shapes for @@ -661,15 +661,15 @@ class FlopCounterMode: def __init__( self, - mods: Optional[Union[torch.nn.Module, list[torch.nn.Module]]] = None, + mods: torch.nn.Module | list[torch.nn.Module] | None = None, depth: int = 2, display: bool = True, - custom_mapping: Optional[dict[Any, Any]] = None) -> None: + custom_mapping: dict[Any, Any] | None = None) -> None: super().__init__() self.flop_counts: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int)) self.depth = depth self.display = display - self.mode: Optional[_FlopCounterMode] = None + self.mode: _FlopCounterMode | None = None if custom_mapping is None: custom_mapping = {} if mods is not None: diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 29d02cb30d338..1d2f5964fcaf8 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -35,7 +35,6 @@ from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS from .cuda_to_hip_mappings import MATH_TRANSPILATIONS -from typing import Optional from collections.abc import Iterator from collections.abc import Mapping, Iterable from enum import Enum @@ -1115,7 +1114,7 @@ def hipify( hip_clang_launch: bool = False, is_pytorch_extension: bool = False, hipify_extra_files_only: bool = False, - clean_ctx: Optional[GeneratedFileCleaner] = None + clean_ctx: GeneratedFileCleaner | None = None ) -> HipifyFinalResult: if project_directory == "": project_directory = os.getcwd() diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py index 819f19d5b71ea..1ad0a65204a47 100644 --- a/torch/utils/mobile_optimizer.py +++ b/torch/utils/mobile_optimizer.py @@ -4,7 +4,7 @@ import torch from enum import Enum from torch._C import _MobileOptimizerType as MobileOptimizerType -from typing import Optional, AnyStr +from typing import AnyStr class LintCode(Enum): BUNDLED_INPUT = 1 @@ -14,8 +14,8 @@ class LintCode(Enum): def optimize_for_mobile( script_module: torch.jit.ScriptModule, - optimization_blocklist: Optional[set[MobileOptimizerType]] = None, - preserved_methods: Optional[list[AnyStr]] = None, + optimization_blocklist: set[MobileOptimizerType] | None = None, + preserved_methods: list[AnyStr] | None = None, backend: str = 'CPU') -> torch.jit.RecursiveScriptModule: """ Optimize a torch script module for mobile deployment. 
diff --git a/torch/utils/serialization/config.py b/torch/utils/serialization/config.py index 0a3fba9f5b82f..c3e6729c68583 100644 --- a/torch/utils/serialization/config.py +++ b/torch/utils/serialization/config.py @@ -12,7 +12,7 @@ class load: mmap: bool = False endianness: _Optional["_LoadEndianess"] = None # MAP_PRIVATE = 2 - mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2 + mmap_flags: int | None = None if sys.platform == "win32" else 2 calculate_storage_offsets: bool = False diff --git a/torch/utils/tensorboard/_proto_graph.py b/torch/utils/tensorboard/_proto_graph.py index c32be5b2cae36..b79ba0dfac048 100644 --- a/torch/utils/tensorboard/_proto_graph.py +++ b/torch/utils/tensorboard/_proto_graph.py @@ -1,6 +1,5 @@ import torch -from typing import Optional, Union from collections.abc import Sequence from tensorboard.compat.proto.node_def_pb2 import NodeDef from tensorboard.compat.proto.attr_value_pb2 import AttrValue @@ -8,7 +7,7 @@ # pyrefly: ignore [not-a-type] -def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[str]) -> dict[str, AttrValue]: +def attr_value_proto(dtype: object, shape: Sequence[int] | None, s: str | None) -> dict[str, AttrValue]: """Create a dict of objects matching a NodeDef's attr field. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto @@ -38,10 +37,10 @@ def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto: def node_proto( name: str, op: str = "UnSpecified", - input: Optional[Union[list[str], str]] = None, - dtype: Optional[torch.dtype] = None, - shape: Optional[tuple[int, ...]] = None, - outputsize: Optional[Sequence[int]] = None, + input: list[str] | str | None = None, + dtype: torch.dtype | None = None, + shape: tuple[int, ...] | None = None, + outputsize: Sequence[int] | None = None, attributes: str = "", ) -> NodeDef: # pyrefly: ignore [not-a-type] """Create an object matching a NodeDef. diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 74befc366c199..41469a2855421 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -4,7 +4,7 @@ import os import struct -from typing import Any, Optional +from typing import Any import torch import numpy as np @@ -249,7 +249,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): ssi.hparams[k].number_value = v if k in hparam_domain_discrete: - domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue( + domain_discrete: struct_pb2.ListValue | None = struct_pb2.ListValue( values=[ struct_pb2.Value(number_value=d) for d in hparam_domain_discrete[k] diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 2dd8ac3db667b..008ae59e94e6a 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -3,7 +3,7 @@ import os import time -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Union import torch @@ -733,9 +733,9 @@ def add_figure( self, tag: str, figure: Union["Figure", list["Figure"]], - global_step: Optional[int] = None, + global_step: int | None = None, close: bool = True, - walltime: Optional[float] = None, + walltime: float | None = None, ) -> None: """Render matplotlib figure into an image and add it to summary. 
diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 0002b40025c18..43f3410fe5e46 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import gc import sys -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple import types import weakref import json @@ -256,7 +256,7 @@ def format_sequence(obj): class Node(NamedTuple): label: str - context: Optional[str] + context: str | None root: bool referrents: list[tuple[str, int]] From b91a2ab892cab2f9951834a176efa103ac1c0599 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Sun, 9 Nov 2025 13:38:14 +0000 Subject: [PATCH 268/651] [2/N] Use context managers (#167404) This PR fixes more context manager usage in Python code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167404 Approved by: https://github.com/mlazos --- test/custom_backend/test_custom_backend.py | 8 +- test/custom_operator/test_custom_ops.py | 10 +- test/dynamo/test_logging.py | 6 +- test/export/test_nativert.py | 165 +++++++++--------- test/test_serialization.py | 44 ++--- torch/_inductor/utils.py | 1 - torch/profiler/profiler.py | 29 ++- .../_internal/distributed/distributed_test.py | 19 +- .../utils/valgrind_wrapper/timer_interface.py | 5 +- torch/utils/data/dataloader.py | 2 +- torch/utils/data/datapipes/utils/common.py | 2 +- torch/utils/tensorboard/summary.py | 31 ++-- torch/utils/viz/_cycles.py | 2 +- 13 files changed, 148 insertions(+), 176 deletions(-) diff --git a/test/custom_backend/test_custom_backend.py b/test/custom_backend/test_custom_backend.py index 269cc98418c86..d0e518c4cd125 100644 --- a/test/custom_backend/test_custom_backend.py +++ b/test/custom_backend/test_custom_backend.py @@ -1,6 +1,5 @@ # Owner(s): ["module: unknown"] -import os import tempfile from backend import get_custom_backend_library_path, Model, to_custom_backend @@ -41,14 +40,11 @@ def test_save_load(self): self.test_execute() # Save and load. - f = tempfile.NamedTemporaryFile(delete=False) - try: + with tempfile.NamedTemporaryFile() as f: f.close() torch.jit.save(self.model, f.name) loaded = torch.jit.load(f.name) - finally: - os.unlink(f.name) - self.model = loaded + self.model = loaded # Test execution again. self.test_execute() diff --git a/test/custom_operator/test_custom_ops.py b/test/custom_operator/test_custom_ops.py index e66ca04ec5c32..8d43ed10d0ede 100644 --- a/test/custom_operator/test_custom_ops.py +++ b/test/custom_operator/test_custom_ops.py @@ -1,6 +1,5 @@ # Owner(s): ["module: unknown"] -import os.path import sys import tempfile import unittest @@ -144,16 +143,13 @@ def test_saving_and_loading_script_module_with_custom_op(self): # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile # opens the file, and it cannot be opened multiple times in Windows. To support Windows, # close the file after creation and try to remove it manually. 
- file = tempfile.NamedTemporaryFile(delete=False) - try: + with tempfile.NamedTemporaryFile() as file: file.close() model.save(file.name) loaded = torch.jit.load(file.name) - finally: - os.unlink(file.name) - output = loaded.forward(torch.ones(5)) - self.assertTrue(output.allclose(torch.ones(5) + 1)) + output = loaded.forward(torch.ones(5)) + self.assertTrue(output.allclose(torch.ones(5) + 1)) if __name__ == "__main__": diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 162bc5c111d07..f472705101e35 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -861,7 +861,7 @@ def fn(a): def test_logs_out(self): import tempfile - with tempfile.NamedTemporaryFile(delete=False) as tmp: + with tempfile.NamedTemporaryFile(delete=True) as tmp: file_path = _as_posix_path(tmp.name) """ NamedTemporaryFile will include a file open operation. @@ -888,10 +888,6 @@ def fn(a): file_path, encoding="utf-8" ) as fd: # encoding file to UTF-8 for Windows. lines = fd.read() - fd.close() - os.remove( - file_path - ) # Delete temp file manually, due to setup NamedTemporaryFile as delete=False. orig_maxDiff = unittest.TestCase.maxDiff unittest.TestCase.maxDiff = None try: diff --git a/test/export/test_nativert.py b/test/export/test_nativert.py index 20f61ad03fffb..6a40c98638901 100644 --- a/test/export/test_nativert.py +++ b/test/export/test_nativert.py @@ -2,7 +2,6 @@ import copy -import pathlib import tempfile import unittest @@ -97,55 +96,55 @@ def run_with_nativert(ep): MODEL_NAME = "forward" # TODO Does named tempfile have collision? - with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: torch.export.pt2_archive._package.package_pt2( f, exported_programs={MODEL_NAME: ep_infer} ) filename = f.name - try: - ep_args, ep_kwargs = ep_infer.example_inputs - ep_args_copied, ep_kwargs_copied = ( - copy.deepcopy(ep_args), - copy.deepcopy(ep_kwargs), - ) - torch.manual_seed(0) try: - flat_expected = pytree.tree_leaves( - ep_infer.module()(*ep_args_copied, **ep_kwargs_copied) - ) - except Exception as e: - raise unittest.case.SkipTest(str(e)) from e - - model_runner = PyModelRunner(filename, MODEL_NAME) - torch.manual_seed(0) - if _is_supported_types((ep_args, ep_kwargs)): - results = model_runner.run(*ep_args, **ep_kwargs) - else: - results = model_runner.run_with_flat_inputs_and_outputs( - *pytree.tree_leaves((ep_args, ep_kwargs)) + ep_args, ep_kwargs = ep_infer.example_inputs + ep_args_copied, ep_kwargs_copied = ( + copy.deepcopy(ep_args), + copy.deepcopy(ep_kwargs), ) - flat_results = pytree.tree_leaves(results) - assert len(flat_results) == len(flat_expected) - for result, expected in zip(flat_results, flat_expected): - assert type(result) is type(expected) - if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor): - assert result.shape == expected.shape - assert result.dtype == expected.dtype - assert result.device == expected.device - torch.testing.assert_close(result, expected, equal_nan=True) + torch.manual_seed(0) + try: + flat_expected = pytree.tree_leaves( + ep_infer.module()(*ep_args_copied, **ep_kwargs_copied) + ) + except Exception as e: + raise unittest.case.SkipTest(str(e)) from e + + model_runner = PyModelRunner(filename, MODEL_NAME) + torch.manual_seed(0) + if _is_supported_types((ep_args, ep_kwargs)): + results = model_runner.run(*ep_args, **ep_kwargs) else: - assert result == expected - except RuntimeError as e: - # User need to register pytree type on the cpp side, 
which - # cannot be tested in python unittest. - if "Unknown pytree node type" in str(e): - pass - else: - raise e - finally: - pathlib.Path(filename).unlink(missing_ok=True) - return ep + results = model_runner.run_with_flat_inputs_and_outputs( + *pytree.tree_leaves((ep_args, ep_kwargs)) + ) + flat_results = pytree.tree_leaves(results) + assert len(flat_results) == len(flat_expected) + for result, expected in zip(flat_results, flat_expected): + assert type(result) is type(expected) + if isinstance(result, torch.Tensor) and isinstance( + expected, torch.Tensor + ): + assert result.shape == expected.shape + assert result.dtype == expected.dtype + assert result.device == expected.device + torch.testing.assert_close(result, expected, equal_nan=True) + else: + assert result == expected + except RuntimeError as e: + # User need to register pytree type on the cpp side, which + # cannot be tested in python unittest. + if "Unknown pytree node type" in str(e): + pass + else: + raise e + return ep def mocked_nativert_export_strict(*args, **kwargs): @@ -287,7 +286,7 @@ def test_aoti(self, device, m, sample_inputs): ) # package everything needed for the NativeRT to execute the AOTI delegate - with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: package_nativert_with_aoti_delegate( f, MODEL_NAME, @@ -298,50 +297,48 @@ def test_aoti(self, device, m, sample_inputs): ) filename = f.name - try: - ep_args, ep_kwargs = aoti_delegate_ep.example_inputs - ep_args_copied, ep_kwargs_copied = ( - copy.deepcopy(ep_args), - copy.deepcopy(ep_kwargs), - ) - torch.manual_seed(0) try: - flat_expected = pytree.tree_leaves( - aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied) + ep_args, ep_kwargs = aoti_delegate_ep.example_inputs + ep_args_copied, ep_kwargs_copied = ( + copy.deepcopy(ep_args), + copy.deepcopy(ep_kwargs), ) - except Exception as e: - raise unittest.case.SkipTest(str(e)) from e - - model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}") - torch.manual_seed(0) - if _is_supported_types((ep_args, ep_kwargs)): - results = model_runner.run(*ep_args, **ep_kwargs) - else: - results = model_runner.run_with_flat_inputs_and_outputs( - *pytree.tree_leaves((ep_args, ep_kwargs)) - ) - flat_results = pytree.tree_leaves(results) - assert len(flat_results) == len(flat_expected) - for result, expected in zip(flat_results, flat_expected): - assert type(result) is type(expected) - if isinstance(result, torch.Tensor) and isinstance( - expected, torch.Tensor - ): - assert result.shape == expected.shape - assert result.dtype == expected.dtype - assert result.device == expected.device - torch.testing.assert_close(result, expected, equal_nan=True) + torch.manual_seed(0) + try: + flat_expected = pytree.tree_leaves( + aoti_delegate_ep.module()(*ep_args_copied, **ep_kwargs_copied) + ) + except Exception as e: + raise unittest.case.SkipTest(str(e)) from e + + model_runner = PyModelRunner(filename, f"{MODEL_NAME}-{BACKEND_ID}") + torch.manual_seed(0) + if _is_supported_types((ep_args, ep_kwargs)): + results = model_runner.run(*ep_args, **ep_kwargs) else: - assert result == expected - except RuntimeError as e: - # User need to register pytree type on the cpp side, which - # cannot be tested in python unittest. 
- if "Unknown pytree node type" in str(e): - pass - else: - raise e - finally: - pathlib.Path(filename).unlink(missing_ok=True) + results = model_runner.run_with_flat_inputs_and_outputs( + *pytree.tree_leaves((ep_args, ep_kwargs)) + ) + flat_results = pytree.tree_leaves(results) + assert len(flat_results) == len(flat_expected) + for result, expected in zip(flat_results, flat_expected): + assert type(result) is type(expected) + if isinstance(result, torch.Tensor) and isinstance( + expected, torch.Tensor + ): + assert result.shape == expected.shape + assert result.dtype == expected.dtype + assert result.device == expected.device + torch.testing.assert_close(result, expected, equal_nan=True) + else: + assert result == expected + except RuntimeError as e: + # User need to register pytree type on the cpp side, which + # cannot be tested in python unittest. + if "Unknown pytree node type" in str(e): + pass + else: + raise e if is_fbcode(): diff --git a/test/test_serialization.py b/test/test_serialization.py index d292c13993cfe..2755ae29a7ffa 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -313,15 +313,17 @@ def test_serialization_fake_zip(self): def test_serialization_gzip(self): # Test serialization with gzip file b = self._test_serialization_data() - f1 = tempfile.NamedTemporaryFile(delete=False) - f2 = tempfile.NamedTemporaryFile(delete=False) - torch.save(b, f1) - with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) - - with gzip.open(f2.name, 'rb') as f: - c = torch.load(f) - self._test_serialization_assert(b, c) + with tempfile.NamedTemporaryFile() as f1, tempfile.NamedTemporaryFile(delete=False) as f2: + torch.save(b, f1) + f1.seek(0) + with gzip.open(f2.name, 'wb') as f_out: + shutil.copyfileobj(f1, f_out) + + with gzip.open(f2.name, 'rb') as f: + c = torch.load(f) + self._test_serialization_assert(b, c) + f2.close() + os.unlink(f2.name) @unittest.skipIf( not TEST_DILL or HAS_DILL_AT_LEAST_0_3_1, @@ -382,19 +384,19 @@ def test_serialization_dill(self): def test_serialization_offset_gzip(self): a = torch.randn(5, 5) i = 41 - f1 = tempfile.NamedTemporaryFile(delete=False) f2 = tempfile.NamedTemporaryFile(delete=False) - with open(f1.name, 'wb') as f: - pickle.dump(i, f) - torch.save(a, f) - with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) - - with gzip.open(f2.name, 'rb') as f: - j = pickle.load(f) - b = torch.load(f) - self.assertTrue(torch.equal(a, b)) - self.assertEqual(i, j) + with tempfile.NamedTemporaryFile() as f1: + pickle.dump(i, f1) + torch.save(a, f1) + f1.seek(0) + with gzip.open(f2.name, 'wb') as f_out: + shutil.copyfileobj(f1, f_out) + + with gzip.open(f2.name, 'rb') as f: + j = pickle.load(f) + b = torch.load(f) + self.assertTrue(torch.equal(a, b)) + self.assertEqual(i, j) def _test_serialization_sparse(self, weights_only): def _test_serialization(conversion): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 11be081db1be7..05b1b9bd33a8d 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2651,7 +2651,6 @@ def pass_execution_and_save( with tempfile.NamedTemporaryFile( mode="w", encoding="utf-8", - delete=False, ) as f: before_io = io.StringIO() after_io = io.StringIO() diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 645667bb81bb6..6aa4383a58cdd 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -271,13 +271,11 @@ def export_chrome_trace(self, path: 
str): "Profiler must be initialized before exporting chrome trace" ) if path.endswith(".gz"): - with tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) as fp: - fp.close() + with tempfile.NamedTemporaryFile("w+b", suffix=".json") as fp: retvalue = self.profiler.export_chrome_trace(fp.name) - with open(fp.name, "rb") as fin: - with gzip.open(path, "wb") as fout: - fout.writelines(fin) - os.remove(fp.name) + fp.seek(0) + with gzip.open(path, "wb") as fout: + fout.writelines(fp) return retvalue else: return self.profiler.export_chrome_trace(path) @@ -448,15 +446,14 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non if path.endswith(".html"): self.mem_tl.export_memory_timeline_html(path, device) elif path.endswith(".gz"): - fp = tempfile.NamedTemporaryFile("w+t", suffix=".json", delete=False) - fp.close() - if path.endswith("raw.json.gz"): - self.mem_tl.export_memory_timeline_raw(fp.name, device) - else: - self.mem_tl.export_memory_timeline(fp.name, device) - with open(fp.name) as fin, gzip.open(path, "wt") as fout: - fout.writelines(fin) - os.remove(fp.name) + with tempfile.NamedTemporaryFile("w+t", suffix=".json") as fp: + fp.close() + if path.endswith("raw.json.gz"): + self.mem_tl.export_memory_timeline_raw(fp.name, device) + else: + self.mem_tl.export_memory_timeline(fp.name, device) + with open(fp.name) as fin, gzip.open(path, "wt") as fout: + fout.writelines(fin) else: self.mem_tl.export_memory_timeline(path, device) @@ -946,7 +943,7 @@ def build_execution_trace_obs_from_env() -> Optional["ExecutionTraceObserver"]: """ if os.environ.get("ENABLE_PYTORCH_EXECUTION_TRACE", "0") == "1": try: - fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) + fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) # noqa:SIM115 except Exception as e: warn( f"Execution trace will not be recorded. Exception on creating default temporary file: {e}", diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 503e15af4bb3e..478d3c978120b 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -215,19 +215,16 @@ def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False): def get_profiler_nccl_meta(prof): """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms" We will need to test metadata obtained from profiler here""" - tf = tempfile.NamedTemporaryFile(mode="w+t", suffix=".json", delete=False) - tf.close() - trace_file = tf.name + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".json") as tf: + tf.close() + trace_file = tf.name - prof.export_chrome_trace(trace_file) - with open(trace_file) as f: - events = json.load(f)["traceEvents"] - print(f"Trace saved to {trace_file}") + prof.export_chrome_trace(trace_file) + with open(trace_file) as f: + events = json.load(f)["traceEvents"] + print(f"Trace saved to {trace_file}") - # Comment to debug - os.remove(trace_file) - - return [e for e in events if e.get("name") == "record_param_comms"] + return [e for e in events if e.get("name") == "record_param_comms"] # Base error message substring on unfinished reductions. 
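The hunks in this commit follow a single pattern: temporary files and log files that were cleaned up by hand (`delete=False` plus `os.unlink`, or a bare `open()` closed in a `finally`) now have their lifetime owned by a `with` block. A minimal sketch of the shape, where `process` is a hypothetical consumer that only needs the file path:

    import os
    import tempfile

    def process(path: str) -> int:
        # hypothetical consumer that only needs the path, not the handle
        return os.path.getsize(path)

    # Before: f = tempfile.NamedTemporaryFile(delete=False); use f.name; os.unlink(f.name) in a finally block.
    # After: the context manager creates and removes the file.
    with tempfile.NamedTemporaryFile(suffix=".json") as tmp:
        tmp.write(b"{}")
        tmp.flush()
        size = process(tmp.name)
    # the temporary file is removed when the block exits

On Windows a `NamedTemporaryFile` cannot be reopened by name while the handle is still open (the comment kept in test_custom_ops.py spells this out), which is why several hunks still call `close()` inside the `with` block before handing `.name` to another opener.
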
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index f38363f6dea89..17ecea8bbb559 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -607,8 +607,7 @@ def _invoke( def run(args: list[str], **kwargs: Any) -> tuple[CompletedProcessType, str]: # https://thraxil.org/users/anders/posts/2008/03/13/Subprocess-Hanging-PIPE-is-your-enemy/ - f_stdout_stderr = open(stdout_stderr_log, "wb") - try: + with open(stdout_stderr_log, "wb") as f_stdout_stderr: invocation = subprocess.run( args, stdout=f_stdout_stderr, @@ -617,8 +616,6 @@ def run(args: list[str], **kwargs: Any) -> tuple[CompletedProcessType, str]: ) with open(stdout_stderr_log) as f: return invocation, f.read() - finally: - f_stdout_stderr.close() try: if is_python: diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 35e70d686a34f..1f8f0d70c9c2f 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1321,7 +1321,7 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): # test. # See NOTE [ DataLoader on Linux and open files limit ] fds_limit_margin = 10 - [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + [tempfile.NamedTemporaryFile() for _ in range(fds_limit_margin)] # noqa: SIM115 except OSError as e: if e.errno == errno.EMFILE: raise RuntimeError( diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 6032de7166af7..4fcc617b3b722 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -232,7 +232,7 @@ def get_file_binaries_from_pathnames( raise TypeError( f"Expected string type for pathname, but got {type(pathname)}" ) - yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding)) + yield pathname, StreamWrapper(open(pathname, mode, encoding=encoding)) # noqa:SIM115 def validate_pathname_binary_tuple(data: tuple[str, IOBase]) -> None: diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 41469a2855421..3e538ddc4c02d 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import json import logging -import os import struct from typing import Any @@ -695,27 +694,23 @@ def make_video(tensor, fps): # encode sequence of images into gif string clip = mpy.ImageSequenceClip(list(tensor), fps=fps) - filename = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name - try: # newer version of moviepy use logger instead of progress_bar argument. - clip.write_gif(filename, verbose=False, logger=None) - except TypeError: - try: # older version of moviepy does not support progress_bar argument. - clip.write_gif(filename, verbose=False, progress_bar=False) + with tempfile.NamedTemporaryFile(suffix=".gif") as f: + filename = f.name + try: # newer version of moviepy use logger instead of progress_bar argument. + clip.write_gif(filename, verbose=False, logger=None) except TypeError: - clip.write_gif(filename, verbose=False) + try: # older version of moviepy does not support progress_bar argument. 
+ clip.write_gif(filename, verbose=False, progress_bar=False) + except TypeError: + clip.write_gif(filename, verbose=False) - with open(filename, "rb") as f: + f.seek(0) tensor_string = f.read() - try: - os.remove(filename) - except OSError: - logger.warning("The temporary file used by moviepy cannot be deleted.") - - # pyrefly: ignore [missing-attribute] - return Summary.Image( - height=h, width=w, colorspace=c, encoded_image_string=tensor_string - ) + # pyrefly: ignore [missing-attribute] + return Summary.Image( + height=h, width=w, colorspace=c, encoded_image_string=tensor_string + ) def audio(tag, tensor, sample_rate=44100): diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 43f3410fe5e46..8abb547d500f8 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -498,7 +498,7 @@ def warn_tensor_cycles(): logger.info("Watching Python reference cycles for CUDA Tensors.") def write_and_log(html) -> None: - with NamedTemporaryFile('w', suffix='.html', delete=False) as f: + with NamedTemporaryFile('w', suffix='.html') as f: f.write(html) logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name) return observe_tensor_cycles(write_and_log) From afb014541bd2050f888d9fc314c274fb5452f641 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Sun, 9 Nov 2025 23:13:56 +0000 Subject: [PATCH 269/651] Separately handle null data_ptr storages when creating unique ID (#167405) ## Summary Previously fake/functionalized tensors that have `null` storage_ptr could segfault when checking for `.expired()` on weak storage ref, so handle `nullptr` storages separately, without checking their weakrefs. Diagnosis and PR created by codex ------ [Codex Task](https://chatgpt.com/codex/tasks/task_e_690ea8790054832f90eaffb37ee0d8c8) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167405 Approved by: https://github.com/Skylion007 --- .../standalone/execution_trace_observer.cpp | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 918cc554c5b16..5edc59c893d7a 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -122,29 +122,47 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT ID get_tensor_storage_ID(const c10::Storage& t_storage) { const std::lock_guard lock(gMutex); - const void* raw_data_ptr = t_storage.data(); - auto iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); - if (iter == data_ptr_to_weak_storage_ptr.end()) { + const void* raw_data_ptr = nullptr; + bool should_track_liveness = false; + // FakeTensor/FunctionalTensor may clear the Storage handle entirely or use + // a nullptr data pointer. Treat both cases as a shared cache key but avoid + // touching the weak-ref table so they can reuse the same ID without + // tripping the liveness check. 
+ if (t_storage.unsafeGetStorageImpl()) { + raw_data_ptr = t_storage.data(); + should_track_liveness = raw_data_ptr != nullptr; + } + + auto id_iter = data_ptr_to_storage_id.find(raw_data_ptr); + if (!should_track_liveness) { + if (id_iter != data_ptr_to_storage_id.end()) { + return id_iter->second; + } ID id = storage_id_++; data_ptr_to_storage_id.emplace(raw_data_ptr, id); + return id; + } + + auto weak_iter = data_ptr_to_weak_storage_ptr.find(raw_data_ptr); + if (weak_iter == data_ptr_to_weak_storage_ptr.end()) { + ID id = storage_id_++; + data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); data_ptr_to_weak_storage_ptr.emplace( raw_data_ptr, t_storage.getWeakStorageImpl()); return id; - } else { - // check if the storage is still alive - if (iter->second.expired()) { - ID id = storage_id_++; - // std::unorder_map does not change if the key is already in the map. - // So we need to remove the key and insert the key with the new value. - data_ptr_to_storage_id.erase(raw_data_ptr); - data_ptr_to_storage_id[raw_data_ptr] = id; - data_ptr_to_weak_storage_ptr.insert_or_assign( - raw_data_ptr, t_storage.getWeakStorageImpl()); - return id; - } else { - return data_ptr_to_storage_id[raw_data_ptr]; - } } + + if (weak_iter->second.expired()) { + ID id = storage_id_++; + data_ptr_to_storage_id.insert_or_assign(raw_data_ptr, id); + data_ptr_to_weak_storage_ptr.insert_or_assign( + raw_data_ptr, t_storage.getWeakStorageImpl()); + return id; + } + + id_iter = data_ptr_to_storage_id.find(raw_data_ptr); + TORCH_INTERNAL_ASSERT(id_iter != data_ptr_to_storage_id.end()); + return id_iter->second; } // Observer run state. From a4c7856112fd303838d584f41d99ff202246d4a1 Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Mon, 10 Nov 2025 00:29:07 +0000 Subject: [PATCH 270/651] [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167340) Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036, which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs. 
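For reference, the kernel stays opt-in: it is only added as an autotune choice when the CUTEDSL backend is requested and the checks in `use_blackwell_cutedsl_grouped_mm` (Blackwell device, CuTeDSL importable, bf16, 2D-A/3D-B with offsets, 16B alignment) pass. A minimal usage sketch along the lines of the new unit test; the shapes are illustrative and it assumes a Blackwell GPU with the `nvidia-cutlass-dsl` package installed:

    import torch
    from torch._inductor import config

    A = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)        # packed (sum_M, K)
    B = torch.randn(8, 64, 128, device="cuda", dtype=torch.bfloat16)     # (G, K, N)
    offs = torch.arange(32, 257, 32, device="cuda", dtype=torch.int32)   # cumulative group ends

    with config.patch({"max_autotune": True, "max_autotune_gemm_backends": "CUTEDSL"}):
        compiled = torch.compile(lambda a, b, o: torch._grouped_mm(a, b, offs=o), dynamic=False)
        out = compiled(A, B, offs)
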
Test Plan: Inductor test (fbcode): `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"` Tritonbench (fbcode): `clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Tritonbench(oss): `clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Unit Tests(oss): `clear; python test/inductor/test_cutedsl_grouped_mm.py` Differential Revision: D86537373 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167340 Approved by: https://github.com/jananisriram --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 + setup.py | 33 ++ test/inductor/test_cutedsl_grouped_mm.py | 154 ++++++++ torch/_inductor/config.py | 4 + torch/_inductor/kernel/mm_common.py | 7 + torch/_inductor/kernel/mm_grouped.py | 90 +++-- .../templates/cutedsl_mm_grouped.py.jinja | 333 ++++++++++++++++++ .../_inductor/template_heuristics/cutedsl.py | 141 ++++++++ torch/_inductor/utils.py | 78 ++++ 10 files changed, 810 insertions(+), 33 deletions(-) create mode 100644 test/inductor/test_cutedsl_grouped_mm.py create mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja create mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 37adb0282c999..ffd7e55d2337b 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index d1b3b17445dac..3b4323051073a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py +torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index 31e78d0245d93..99af489cc2114 100644 --- a/setup.py +++ b/setup.py @@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +def mirror_inductor_external_kernels() -> None: + """ + Copy 
external kernels into Inductor so they are importable. + """ + paths = [ + ( + CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", + CWD + / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", + ), + ] + for new_path, orig_path in paths: + # Create the dirs involved in new_path if they don't exist + if not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + + # Copy the files from the orig location to the new location + if orig_path.is_file(): + shutil.copyfile(orig_path, new_path) + continue + if orig_path.is_dir(): + if new_path.exists(): + # copytree fails if the tree exists already, so remove it. + shutil.rmtree(new_path) + shutil.copytree(orig_path, new_path) + continue + raise RuntimeError( + "Check the file paths in `mirror_inductor_external_kernels()`" + ) + + # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1615,6 +1646,7 @@ def main() -> None: mirror_files_into_torchgen() if RUN_BUILD_DEPS: build_deps() + mirror_inductor_external_kernels() ( ext_modules, @@ -1649,6 +1681,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", + "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py new file mode 100644 index 0000000000000..c26def3a54099 --- /dev/null +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -0,0 +1,154 @@ +# Owner(s): ["module: inductor"] + + +import unittest + +import torch +from torch import Tensor +from torch._inductor import config +from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch +from torch._inductor.test_case import run_tests, TestCase as InductorTestCase +from torch._inductor.utils import ensure_cute_available +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +@unittest.skipIf( + not (ensure_cute_available() and is_datacenter_blackwell_arch()), + "CuTeDSL library or Blackwell device not available", +) +@instantiate_parametrized_tests +class TestCuTeDSLGroupedGemm(InductorTestCase): + def _get_inputs( + self, + group_size: int, + M_hint: int, + K: int, + N: int, + device: str, + dtype: torch.dtype, + alignment: int = 16, + ) -> tuple[Tensor, Tensor, Tensor]: + # --- Random, tile-aligned M sizes --- + M_sizes = ( + torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) + * alignment + ) + + M_total = torch.sum(M_sizes).item() + + # --- Construct input tensors --- + A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 + B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 + + # --- Build offsets (no leading zero, strictly increasing) --- + offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) + + return (A, B, offsets) + + @parametrize("group_size", (2, 8)) + @parametrize("M_hint", (256, 1024)) + @parametrize("K", (64, 128)) + @parametrize("N", (128, 256)) + def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): + device = "cuda" + dtype = torch.bfloat16 + + A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # Eager execution + c_eager = 
grouped_gemm_fn(A, B, offsets) + + # Test with Cute backend + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) + @parametrize("layout_B", ("contiguous", "broadcasted")) + def test_grouped_gemm_assorted_layouts( + self, + layout_A: str, + layout_B: str, + ): + device = "cuda" + dtype = torch.bfloat16 + + G, K, N = 8, 64, 128 + M_sizes = [128] * G + sum_M = sum(M_sizes) + offsets = torch.tensor( + [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device + ) + + A_base = torch.randn(sum_M, K, device=device, dtype=dtype) + A = A_base + + if layout_A == "offset": + # allocate bigger buffer than needed, use nonzero storage offset + storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) + offset = 128 # skip first 128 elements + A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) + elif layout_A == "padded": + # simulate row pitch > K (row_stride = K + pad) + row_pitch = K + 8 + storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) + A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) + elif layout_A == "view": + A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) + A = A_storage.view(sum_M, K) + assert A._base is not None + assert A.shape == (sum_M, K) + + B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 + + if layout_B == "broadcasted": + # Broadcast B across groups (zero stride along G) + B = B[0].expand(G, K, N) + assert B.stride(0) == 0 + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # --- eager --- + c_eager = grouped_gemm_fn(A, B, offsets) + + # --- compiled (CUTE backend) --- + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index bd0ff91616b37..094850eced44c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -550,6 +550,10 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +cutedsl_enable_autotuning: bool = ( + os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" +) + # DEPRECATED. This setting is ignored. 
autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index b95073e769f31..eb22b95af2afc 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence +from functools import partial +from pathlib import Path from typing import Any import torch @@ -12,6 +14,7 @@ from .. import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox +from ..utils import load_template log = logging.getLogger(__name__) @@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True + + +_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 881c14fd43d0d..c81ec607661bc 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,13 @@ # mypy: allow-untyped-defs import logging -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters +from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl +from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -22,11 +24,13 @@ get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, + load_kernel_template, persistent_grouped_mm_grid, ) @@ -513,6 +517,11 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) +cutedsl_grouped_mm_template = CuteDSLTemplate( + name="grouped_gemm_cutedsl", + source=load_kernel_template("cutedsl_mm_grouped"), +) + def grouped_mm_args( mat1: TensorBox, @@ -714,43 +723,44 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. 
+ if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -788,6 +798,22 @@ def _tuned_grouped_mm_common( **config.kwargs, ) + if use_blackwell_cutedsl_grouped_mm( + mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result + ): + for config in get_groupgemm_configs(): + kwargs = dict( + ACC_DTYPE="cutlass.Float32", + ) + + cutedsl_grouped_mm_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + **asdict(config), + ) + input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 0000000000000..989f297c5f80f --- /dev/null +++ b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. 
At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. +# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. + """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. 
+ """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = input_b.shape + + dev = 
input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + 
max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 0000000000000..db337b9d8a271 --- /dev/null +++ b/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. 
+ """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. + """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 05b1b9bd33a8d..f98d3385b1846 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1911,6 +1911,84 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() +@functools.lru_cache(maxsize=1) +def ensure_cute_available() -> bool: + """Check if CuTeDSL is importable; cache the result for reuse. + + Call ensure_cute_available.cache_clear() after installing CuTeDSL + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("cutlass.cute") is not None + except ImportError: + return False + + +def use_blackwell_cutedsl_grouped_mm( + mat_a: Any, + mat_b: Any, + layout: Layout, + a_is_2d: bool, + b_is_2d: bool, + offs: Optional[Any], + bias: Optional[Any], + scale_result: Optional[Any], +) -> bool: + """ + Returns True if we can use the blackwell kernel for grouped mm. + Required conditions: + 1. CuTeDSL backend is enabled + 2. CuTeDSL is available + 3. We are on a blackwell arch + 4. The dtype is bf16 + 5. Max autotune or max autotune gemm is enabled + 6. A, B, and the output are 16B aligned + 7. We are not using dynamic shapes + 8. A is 2d + 9. B is 3d + 10. Offsets are provided + 11. 
Bias and Scale are not provided + """ + if not ensure_cute_available(): + return False + + if not _use_autotune_backend("CUTEDSL"): + return False + + from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch + + if not is_gpu(layout.device.type): + return False + + if not is_datacenter_blackwell_arch(): + return False + + layout_dtypes = [torch.bfloat16] + if not _use_template_for_gpu(layout, layout_dtypes): + return False + + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Checks for 16B ptr and stride alignment + if not can_use_tma(mat_a, mat_b, output_layout=layout): + return False + + if any(is_dynamic(x) for x in [mat_a, mat_b]): + return False + + if not a_is_2d or b_is_2d: + return False + + if offs is None: + return False + + if bias is not None or scale_result is not None: + return False + + return True + + def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From abf31db2cc039ee299337bad6f7f11577c877481 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 7 Nov 2025 17:11:31 +0000 Subject: [PATCH 271/651] Introduce a new API torch.accelerator.get_memory_info (#156812) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Motivation `torch.cuda.mem_get_info` and `torch.xpu.mem_get_info` are widely used in other popular repos, such as - https://github.com/sgl-project/sglang/blob/076313bd099ac1ee484ee77009eaae864eacf396/python/sglang/srt/utils.py#L378, - https://github.com/huggingface/accelerate/blob/7ecc2d7f394fc0686062a18d46128a8bd97c7dad/src/accelerate/utils/modeling.py#L822, - https://github.com/vllm-project/vllm/blob/7ba34b1241ada58f8212f350a8b17382cb412cf2/vllm/worker/worker.py#L150. - This PR introduces a unified API `torch.accelerator.get_memory_info` to cover this scenario. Pull Request resolved: https://github.com/pytorch/pytorch/pull/156812 Approved by: https://github.com/albanD --- aten/src/ATen/DeviceAccelerator.h | 5 +++++ c10/core/CachingDeviceAllocator.h | 4 ++++ c10/cuda/CUDACachingAllocator.h | 7 +++++++ c10/xpu/XPUCachingAllocator.cpp | 4 ++++ docs/source/accelerator.md | 1 + torch/_C/__init__.pyi.in | 1 + torch/accelerator/__init__.py | 4 +++- torch/accelerator/memory.py | 34 +++++++++++++++++++++++++++++++ torch/csrc/DeviceAccelerator.cpp | 7 +++++++ 9 files changed, 66 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index f23b35047fcc8..2cc4cff7cd1f2 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { at::getDeviceAllocator(device_type)->resetPeakStats(device_index); } +TORCH_API inline std::pair getMemoryInfo( + c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index); +} } // namespace at::accelerator namespace at { diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index 0bec03ae417fa..b69ef1bad5f08 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -96,6 +96,10 @@ struct C10_API DeviceAllocator : public c10::Allocator { // Resets peak memory usage statistics for the specified device virtual void resetPeakStats(c10::DeviceIndex device) = 0; + + // Return the free memory size and total memory size in bytes for the + // specified device. 
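+  // Usage sketch, mirroring the at::accelerator::getMemoryInfo helper added in
+  // this patch (illustrative only):
+  //   auto [free, total] = at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);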
+ virtual std::pair getMemoryInfo(c10::DeviceIndex device) = 0; }; // This function is used to get the DeviceAllocator for a specific device type diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 8fee00dd621dc..29f0452e8a22d 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator { c10::DeviceIndex device, std::shared_ptr pps) = 0; virtual std::string name() = 0; + std::pair getMemoryInfo(c10::DeviceIndex device) override { + c10::DeviceGuard device_guard({at::kCUDA, device}); + size_t free = 0; + size_t total = 0; + C10_CUDA_CHECK(cudaMemGetInfo(&free, &total)); + return {free, total}; + } }; // Allocator object, statically initialized diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 17a2669e7290f..8c0eb7e18dcd2 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1240,6 +1240,10 @@ class XPUAllocator : public DeviceAllocator { c10::xpu::get_raw_device(dev_to_access)); } + std::pair getMemoryInfo(DeviceIndex device) override { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented yet."); + } + double getMemoryFraction(DeviceIndex device) { assertValidDevice(device); return device_allocators[device]->getMemoryFraction(); diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index ce593a9acf518..c5904563ee711 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -40,6 +40,7 @@ :nosignatures: empty_cache + get_memory_info max_memory_allocated max_memory_reserved memory_allocated diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 559230350bcc9..3fdf6302115b6 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2491,6 +2491,7 @@ def _accelerator_emptyCache() -> None: ... def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... def _accelerator_resetPeakStats(device_index: _int) -> None: ... +def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... # Defined in torch/csrc/jit/python/python_tracer.cpp diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 4d1a78df1f74c..f07e36fe9ad0b 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -10,6 +10,7 @@ from ._utils import _device_t, _get_device_index from .memory import ( empty_cache, + get_memory_info, max_memory_allocated, max_memory_reserved, memory_allocated, @@ -25,9 +26,10 @@ "current_device_idx", # deprecated "current_device_index", "current_stream", - "empty_cache", "device_count", "device_index", + "empty_cache", + "get_memory_info", "is_available", "max_memory_allocated", "max_memory_reserved", diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py index d98be36321119..513e497f3883c 100644 --- a/torch/accelerator/memory.py +++ b/torch/accelerator/memory.py @@ -8,6 +8,7 @@ __all__ = [ "empty_cache", + "get_memory_info", "max_memory_allocated", "max_memory_reserved", "memory_allocated", @@ -87,6 +88,9 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: If not given, use :func:`torch.accelerator.current_device_index` by default. If a :class:`torch.device` or str is provided, its type must match the current :ref:`accelerator` device type. 
+ + Returns: + OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values. """ if not torch._C._accelerator_isAllocatorInitialized(): return OrderedDict() @@ -117,6 +121,9 @@ def memory_allocated(device_index: _device_t = None, /) -> int: If not given, use :func:`torch.accelerator.current_device_index` by default. If a :class:`torch.device` or str is provided, its type must match the current :ref:`accelerator` device type. + + Returns: + int: the current memory occupied by live tensors (in bytes) within the current process. """ return memory_stats(device_index).get("allocated_bytes.all.current", 0) @@ -134,6 +141,9 @@ def max_memory_allocated(device_index: _device_t = None, /) -> int: If not given, use :func:`torch.accelerator.current_device_index` by default. If a :class:`torch.device` or str is provided, its type must match the current :ref:`accelerator` device type. + + Returns: + int: the peak memory occupied by live tensors (in bytes) within the current process. """ return memory_stats(device_index).get("allocated_bytes.all.peak", 0) @@ -147,6 +157,9 @@ def memory_reserved(device_index: _device_t = None, /) -> int: If not given, use :func:`torch.accelerator.current_device_index` by default. If a :class:`torch.device` or str is provided, its type must match the current :ref:`accelerator` device type. + + Returns: + int: the current memory reserved by PyTorch (in bytes) within the current process. """ return memory_stats(device_index).get("reserved_bytes.all.current", 0) @@ -164,6 +177,9 @@ def max_memory_reserved(device_index: _device_t = None, /) -> int: If not given, use :func:`torch.accelerator.current_device_index` by default. If a :class:`torch.device` or str is provided, its type must match the current :ref:`accelerator` device type. + + Returns: + int: the peak memory reserved by PyTorch (in bytes) within the current process. """ return memory_stats(device_index).get("reserved_bytes.all.peak", 0) @@ -200,3 +216,21 @@ def reset_peak_memory_stats(device_index: _device_t = None, /) -> None: """ device_index = _get_device_index(device_index, optional=True) return torch._C._accelerator_resetPeakStats(device_index) + + +def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]: + r"""Return the current device memory information for a given device index. + + Args: + device_index (:class:`torch.device`, str, int, optional): the index of the device to target. + If not given, use :func:`torch.accelerator.current_device_index` by default. + If a :class:`torch.device` or str is provided, its type must match the current + :ref:`accelerator` device type. + + Returns: + tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes. + The first value is the free memory on the device (available across all processes and applications), + The second value is the device's total hardware memory capacity. 
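+
+    Example::
+
+        >>> # xdoctest: +SKIP("requires an available accelerator device")
+        >>> free_bytes, total_bytes = torch.accelerator.get_memory_info()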
+ """ + device_index = _get_device_index(device_index, optional=True) + return torch._C._accelerator_getMemoryInfo(device_index) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index b6176f11aaf6e..14e54851178f5 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -138,6 +138,13 @@ void initModule(PyObject* module) { at::accelerator::resetPeakStats(device_index); }); + m.def("_accelerator_getMemoryInfo", [](c10::DeviceIndex device_index) { + const auto device_type = at::accelerator::getAccelerator(true).value(); + torch::utils::maybe_initialize_device(device_type); + py::gil_scoped_release no_gil; + return at::accelerator::getMemoryInfo(device_index); + }); + m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); From fe6615e3977c395f12ebd139da0969fd73c659ec Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Sun, 9 Nov 2025 13:20:03 -0800 Subject: [PATCH 272/651] Swap pallas test shard to 12.8 (#167428) Getting some weird failures building cuda13, lets stick to what we know works Pull Request resolved: https://github.com/pytorch/pytorch/pull/167428 Approved by: https://github.com/jansel --- .ci/docker/build.sh | 4 ++-- .github/workflows/docker-builds.yml | 2 +- .github/workflows/inductor-unittest.yml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 90c87b55ea416..203ab597a75bc 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -260,8 +260,8 @@ case "$tag" in HALIDE=yes TRITON=yes ;; - pytorch-linux-jammy-cuda13.0-py3.12-pallas) - CUDA_VERSION=13.0.0 + pytorch-linux-jammy-cuda12.8-py3.12-pallas) + CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 PALLAS=yes diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index f632d4a858abb..66f3bc21755c4 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -67,7 +67,7 @@ jobs: pytorch-linux-jammy-py3.10-gcc11, pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, - pytorch-linux-jammy-cuda13.0-py3.12-pallas, + pytorch-linux-jammy-cuda12.8-py3.12-pallas, pytorch-linux-jammy-xpu-n-1-py3, pytorch-linux-noble-xpu-n-py3, pytorch-linux-noble-xpu-n-py3-inductor-benchmarks, diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index af9829c96f506..ca9b57cab2ddb 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -86,8 +86,8 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - build-environment: linux-jammy-py3.12-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas + build-environment: linux-jammy-cuda12.8-py3.12-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas cuda-arch-list: '8.9' runner: linux.8xlarge.memory runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" From 2c78080ec00f2ea8b7fe5072ced4940b17ca7789 Mon Sep 17 00:00:00 2001 From: Slawomir Siwek Date: Mon, 10 Nov 2025 03:10:22 +0000 Subject: [PATCH 273/651] Register functorch XPU/HPU dispatch keys (#167095) Fixes TestOperatorsXPU.test_data_write_errors_under_transform_xpu https://github.com/intel/torch-xpu-ops/issues/2237 Tests on other devices throw runtime error "_mutating directly with `.data` inside functorch transform is not allowed._", but XPU/HPU fails earlier on 
`_has_compatible_shallow_copy_type`. This check is not met only when calling tensor.data inside functorch call. ```cpp bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) { return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type( from.key_set()); } ``` ### t.data | Tensor | Device | Dispatch Keys | |--------|---------|---------------| | `self` | `xpu` | `XPU, ADInplaceOrView, AutogradXPU, AutocastXPU` | | `from` | `cpu` | `CPU, ADInplaceOrView, AutogradCPU, AutocastCPU` | ### t.data inside functorch transform | Tensor | Device | Dispatch Keys | |--------|---------|---------------| | `self` | `xpu` | `ADInplaceOrView, AutogradOther, FuncTorchGradWrapper` | | `from` | `cpu` | `CPU, ADInplaceOrView, AutogradCPU, AutocastCPU, FuncTorchGradWrapper` | ### t.data inside functorch transform + XPU dispatch key | Tensor | Device | Dispatch Keys | |--------|---------|---------------| | `self` | `xpu` | `XPU, ADInplaceOrView, AutogradXPU, AutocastXPU, FuncTorchGradWrapper` | | `from` | `cpu` | `CPU, ADInplaceOrView, AutogradCPU, AutocastCPU, FuncTorchGradWrapper` | Pull Request resolved: https://github.com/pytorch/pytorch/pull/167095 Approved by: https://github.com/guangyey, https://github.com/albanD --- aten/src/ATen/functorch/BatchedTensorImpl.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h index 985b289b3fe02..14be24d63e65a 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -157,6 +157,8 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ DispatchKey::Negative, DispatchKey::Conjugate, DispatchKey::XLA, + DispatchKey::XPU, + DispatchKey::HPU, DispatchKey::CUDA, DispatchKey::CPU, DispatchKey::PrivateUse1, From a058bbdd6fd34936fb6e61feaf827a5068c3e109 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Mon, 10 Nov 2025 04:02:56 +0000 Subject: [PATCH 274/651] [xpu][test] Enable profiler test for XPU (#165423) Fixes #165130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165423 Approved by: https://github.com/EikanWang, https://github.com/atalman, https://github.com/mlazos --- .ci/pytorch/build.sh | 6 ++++-- .ci/pytorch/test.sh | 2 ++ test/inductor/test_codecache.py | 4 ++++ test/profiler/test_execution_trace.py | 4 ++-- test/profiler/test_profiler.py | 10 +++++++++- test/run_test.py | 6 ------ test/test_xpu.py | 3 ++- 7 files changed, 23 insertions(+), 12 deletions(-) diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index d66aa1120fb30..071f14700def4 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -168,14 +168,16 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then # shellcheck disable=SC1091 source /opt/intel/oneapi/compiler/latest/env/vars.sh # shellcheck disable=SC1091 + source /opt/intel/oneapi/umf/latest/env/vars.sh + # shellcheck disable=SC1091 source /opt/intel/oneapi/ccl/latest/env/vars.sh # shellcheck disable=SC1091 source /opt/intel/oneapi/mpi/latest/env/vars.sh + # shellcheck disable=SC1091 + source /opt/intel/oneapi/pti/latest/env/vars.sh # Enable XCCL build export USE_XCCL=1 export USE_MPI=0 - # XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA - export USE_KINETO=0 export TORCH_XPU_ARCH_LIST=pvc fi diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index ffd7e55d2337b..821a714553b49 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -208,6 +208,8 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then source 
/opt/intel/oneapi/ccl/latest/env/vars.sh # shellcheck disable=SC1091 source /opt/intel/oneapi/mpi/latest/env/vars.sh + # shellcheck disable=SC1091 + source /opt/intel/oneapi/pti/latest/env/vars.sh # Check XPU status before testing timeout 30 xpu-smi discovery || true fi diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 46f1ca031bf83..4b9030b5cae4b 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -206,6 +206,10 @@ def f(x): .decode() .strip() ) + # XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261 + if torch.xpu.is_available(): + wrapper_path = wrapper_path.splitlines()[-1] + hit = hit.splitlines()[-1] self.assertEqual(hit, "1") with open(wrapper_path) as f: diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 3a174b1d66a67..dbd5d89ad6a61 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -482,8 +482,8 @@ def fn(a, b, c): @unittest.skipIf(IS_WINDOWS, "torch.compile does not support WINDOWS") @unittest.skipIf( - (not has_triton()) or (not TEST_CUDA and not TEST_XPU), - "need triton and device(CUDA or XPU) availability to run", + (not has_triton()) or (not TEST_CUDA), + "need triton and device CUDA availability to run", ) @skipCPUIf(True, "skip CPU device for testing profiling triton") def test_triton_fx_graph_with_et(self, device): diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 43216274f9271..831f99aafff0a 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -2005,6 +2005,10 @@ def _test_chrome_trace_basic_helper(self, with_cuda=False): report = json.load(f) self._validate_basic_json(report["traceEvents"], with_cuda) + @unittest.skipIf( + torch.xpu.is_available(), + "XPU Trace event ends too late! 
Refer https://github.com/intel/torch-xpu-ops/issues/2263", + ) @unittest.skipIf(not kineto_available(), "Kineto is required") @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_basic_chrome_trace(self): @@ -2158,7 +2162,10 @@ def test_user_annotation(self): @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_basic_profile(self): # test a really basic profile to make sure no erroneous aten ops are run - x = torch.randn(4, device="cuda") + acc = torch.accelerator.current_accelerator() + self.assertIsNotNone(acc) + device = acc.type + x = torch.randn(4, device=device) with torch.profiler.profile(with_stack=True) as p: x *= 2 names = [e.name for e in p.events()] @@ -2225,6 +2232,7 @@ def test_lazy_build_tree(self): @unittest.skipIf( torch.cuda.is_available(), "CUDA complains about forking after init" ) + @unittest.skipIf(torch.xpu.is_available(), "XPU complains about forking after init") @unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows") def test_forked_process(self): # Induce a pid cache by running the profiler with payload diff --git a/test/run_test.py b/test/run_test.py index 2abf324ad43d6..7ce37a7514604 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -263,13 +263,7 @@ def __contains__(self, item): XPU_BLOCKLIST = [ "test_autograd", - "profiler/test_cpp_thread", - "profiler/test_execution_trace", "profiler/test_memory_profiler", - "profiler/test_profiler", - "profiler/test_profiler_tree", - "profiler/test_record_function", - "profiler/test_torch_tidy", "test_openreg", ] diff --git a/test/test_xpu.py b/test/test_xpu.py index 0e60842605396..6b92dc4c96b38 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -206,7 +206,8 @@ def test_multi_process(model, input): test_multi_process(model, input) print(torch.xpu.device_count()) """ - rc = check_output(test_script) + # XPU have extra lines, so get the last line, refer https://github.com/intel/torch-xpu-ops/issues/2261 + rc = check_output(test_script).splitlines()[-1] self.assertEqual(rc, str(torch.xpu.device_count())) def test_streams(self): From e545ba2d3428b8d793c9869e5abc1eed7c7fbbdb Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 7 Nov 2025 18:50:25 -0800 Subject: [PATCH 275/651] [DTensor] Fix Conv behavior for replicate stategy (#167402) Pass `dim_map` to `_requires_data_exchange` and return False if both spatial and channels dimensions are replicated Modify `test_conv1d` and `test_conv3d` to check values rather than just shape, and replicate `conv3d` across batch dimension In general, feels like current Convolution implementation was written to work only if tensor is sharded across last dimention Pull Request resolved: https://github.com/pytorch/pytorch/pull/167402 Approved by: https://github.com/ezyang --- .../tensor/test_convolution_ops.py | 16 +++++------ torch/distributed/tensor/_tp_conv.py | 28 +++++++++++++------ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index de4343eef6a4e..6c3623f84a4ee 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -204,14 +204,16 @@ def test_conv_backward_none_grad_inp(self): self.assertTrue(b_dt.grad is not None) self.assertTrue(x_dt.grad is None) - def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]: + def _run_single_arg_fwd( + self, model, arg, placements=None + ) -> tuple[torch.Tensor, torch.Tensor]: """Given model and arg, 
runs fwd model local and distbuted given device_mesh""" device_mesh = self.build_device_mesh() model_copy = copy.deepcopy(model).to(device=self.device_type) dist_model = distribute_module(model, device_mesh, _conv_fn) - arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()]) + arg_dt = DTensor.from_local(arg, device_mesh, placements) out_dt = dist_model(arg_dt.to(device=self.device_type)) - out = model_copy(arg) + out = model_copy(arg_dt.full_tensor()) return (out_dt.full_tensor(), out) @with_comms @@ -219,22 +221,20 @@ def test_conv1d(self): model = nn.Conv1d(64, 64, 3, padding=1) x = torch.randn(1, 64, 8, device=self.device_type) out_dt, out = self._run_single_arg_fwd(model, x) - self.assertEqual(out_dt.shape, out.shape) + self.assertEqual(out_dt, out) @with_comms def test_conv3d(self): model = nn.Conv3d(64, 64, 3, padding=1) x = torch.randn(1, 64, 8, 8, 8, device=self.device_type) - out_dt, out = self._run_single_arg_fwd(model, x) - self.assertEqual(out_dt.shape, out.shape) + out_dt, out = self._run_single_arg_fwd(model, x, [Shard(0)]) + self.assertEqual(out_dt, out) DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class( DistConvolutionOpsTest, # Send / recv ops are not supported skipped_tests=[ - "test_conv1d", - "test_conv3d", "test_conv_backward_none_grad_inp", "test_depthwise_convolution", "test_downsampling_convolution", diff --git a/torch/distributed/tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py index 2b3f126c7e506..275cb07934b50 100644 --- a/torch/distributed/tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -11,7 +11,10 @@ aten = torch.ops.aten -def _requires_data_exchange(padding): +def _requires_data_exchange(padding, dim_map) -> bool: + # Data exchange is not need if only sharded across batch dim + if all(x == -1 for x in dim_map[1:]): + return False # TODO: whether there requires data exchange is currently determined by padding return padding[-1] != 0 @@ -107,6 +110,7 @@ def tp_convolution( op_call: torch._ops.OpOverload, local_tensor_args: tuple[object, ...], local_tensor_kwargs: dict[str, object], + dim_map: list[int], ) -> object: assert op_call == aten.convolution.default assert len(local_tensor_args) == 9 @@ -120,7 +124,7 @@ def tp_convolution( assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) assert isinstance(padding, list) - if not _requires_data_exchange(padding): + if not _requires_data_exchange(padding, dim_map): local_results = op_call(*local_tensor_args, **local_tensor_kwargs) return local_results else: @@ -160,6 +164,7 @@ def tp_convolution_backward( op_call: torch._ops.OpOverload, local_tensor_args: tuple[object, ...], local_tensor_kwargs: dict[str, object], + dim_map: list[int], ) -> object: assert op_call == aten.convolution_backward.default assert len(local_tensor_args) == 11 @@ -174,7 +179,7 @@ def tp_convolution_backward( assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) assert isinstance(padding, list) - if not _requires_data_exchange(padding): + if not _requires_data_exchange(padding, dim_map): local_results = op_call(*local_tensor_args, **local_tensor_kwargs) return local_results else: @@ -239,15 +244,18 @@ def convolution_handler( dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) output_sharding = op_info.output_sharding assert output_sharding is not None, "output sharding should not be None" + output_spec = output_sharding.output_spec + assert isinstance(output_spec, dtensor.DTensorSpec) # local propagation local_results = 
tp_convolution( - op_call, tuple(op_info.local_args), op_info.local_kwargs + op_call, + tuple(op_info.local_args), + op_info.local_kwargs, + output_spec.dim_map, ) - return dtensor.DTensor._op_dispatcher.wrap( - local_results, output_sharding.output_spec - ) + return dtensor.DTensor._op_dispatcher.wrap(local_results, output_spec) def convolution_backward_handler( @@ -270,10 +278,14 @@ def convolution_backward_handler( dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) output_sharding = op_info.output_sharding assert output_sharding is not None, "output sharding should not be None" + assert isinstance(op_info.flat_args_schema[0], dtensor.DTensorSpec) # local propagation local_results = tp_convolution_backward( - op_call, tuple(op_info.local_args), op_info.local_kwargs + op_call, + tuple(op_info.local_args), + op_info.local_kwargs, + op_info.flat_args_schema[0].dim_map, ) return dtensor.DTensor._op_dispatcher.wrap( From 50af6f339338e962a65bfae582eb3d40491b6a4f Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Mon, 10 Nov 2025 05:25:31 +0000 Subject: [PATCH 276/651] [MPS] erfinv for sparse mps (#166711) Should be merged after #166708 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166711 Approved by: https://github.com/Skylion007, https://github.com/malfet --- aten/src/ATen/native/native_functions.yaml | 6 +++--- torch/testing/_internal/common_methods_invocations.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 633d66f669b65..b2f469b8ceb5b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9832,7 +9832,7 @@ structured_delegate: erfinv.out variants: method, function dispatch: - SparseCPU, SparseCUDA: erfinv_sparse + SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr tags: pointwise @@ -9841,7 +9841,7 @@ structured_delegate: erfinv.out variants: method dispatch: - SparseCPU, SparseCUDA: erfinv_sparse_ + SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_ SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_ tags: pointwise @@ -9851,7 +9851,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS: erfinv_out - SparseCPU, SparseCUDA: erfinv_sparse_out + SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out tags: pointwise diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 0413c9bf6b6e0..ecd2235b1445f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -20320,6 +20320,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): torch.float32: 1e-4}),), dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16), + supports_sparse=True, supports_sparse_csr=True, supports_sparse_csc=True, supports_sparse_bsr=True, From 47db55258b047380157f1e129edc6bf5326fce98 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Mon, 10 Nov 2025 05:27:46 +0000 Subject: [PATCH 277/651] [MPS] sparse sparse mm (#167013) Sparse sparse mm op implementation Pull Request resolved: https://github.com/pytorch/pytorch/pull/167013 Approved by: https://github.com/malfet --- aten/src/ATen/native/native_functions.yaml | 1 + 
.../native/sparse/mps/SparseMPSTensorMath.mm | 113 ++++++++++++++++++ test/test_sparse.py | 5 +- 3 files changed, 116 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b2f469b8ceb5b..e6b96c0b12240 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4292,6 +4292,7 @@ dispatch: SparseCPU: sparse_sparse_matmul_cpu SparseCUDA: sparse_sparse_matmul_cuda + SparseMPS: sparse_sparse_matmul_mps autogen: _sparse_sparse_matmul.out - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) diff --git a/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm b/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm index 5dbee4e38af7b..d0e2ab0ffd799 100644 --- a/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm +++ b/aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm @@ -10,6 +10,10 @@ #include #else #include +#include +#include +#include +#include #include #include #include @@ -888,5 +892,114 @@ static void sparse_mask_intersection_out_mps_kernel( /*coalesce_mask=*/false); } +Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) { + TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(), + "sparse_sparse_matmul_mps: both inputs must be sparse COO tensors"); + TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(), + "sparse_sparse_matmul_mps: both inputs must be on MPS device"); + TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2, + "sparse_sparse_matmul_mps: both inputs must be 2D matrices"); + TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0, + "sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)"); + TORCH_CHECK(mat1_.size(1) == mat2_.size(0), + "mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")"); + TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(), + "sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(), + " does not match mat2 dtype ", mat2_.scalar_type()); + + const auto device = mat1_.device(); + + auto A = mat1_.coalesce(); + auto B = mat2_.coalesce(); + + const auto I = A.size(0); + const auto K = A.size(1); + const auto N = B.size(1); + + const auto nnzA = A._nnz(); + const auto nnzB = B._nnz(); + + // Early empty result, return an empty, coalesced tensor + if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) { + auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong)); + auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type())); + auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options()); + out._coalesced_(true); + return out; + } + + const auto computeDtype = at::result_type(mat1_, mat2_); + + auto A_idx = A._indices().contiguous(); + auto A_val = A._values().to(computeDtype).contiguous(); + auto A_i = A_idx.select(0, 0).contiguous(); + auto A_k = A_idx.select(0, 1).contiguous(); + + auto B_idx = B._indices().contiguous(); + auto B_val = B._values().to(computeDtype).contiguous(); + auto B_k = B_idx.select(0, 0).contiguous(); + auto B_j = B_idx.select(0, 1).contiguous(); + + // csr-style row pointers for B by k (the shared dimension) + Tensor row_ptr_B; + { + auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong)); + row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong)); + build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B); + } + + auto row_ptr_B_lo = 
row_ptr_B.narrow(0, 0, K); + auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K); + auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo); + + auto counts = deg_B.index_select(0, A_k); + + const int64_t P = counts.sum().item(); + if (P == 0) { + auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong)); + auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type())); + auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options()); + out._coalesced_(true); + return out; + } + + auto group_ids = repeat_interleave_mps(counts); + + // exclusive cumsum of counts + auto offsets = cumsum(counts, /*dim=*/0).sub(counts); + auto offsets_gather = offsets.index_select(0, group_ids); + auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather); + + // Map each output element to its source B row and position + auto k_per_out = A_k.index_select(0, group_ids); + auto start_in_B = row_ptr_B.index_select(0, k_per_out); + auto seg_index = start_in_B.add(within); + + // Assemble candidate coo pairs and values + auto i_out = A_i.index_select(0, group_ids).contiguous(); + auto j_out = B_j.index_select(0, seg_index).contiguous(); + auto vA_out = A_val.index_select(0, group_ids).contiguous(); + auto vB_out = B_val.index_select(0, seg_index).contiguous(); + auto v_out = vA_out.mul(vB_out); + + // build (2, P) indices + auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous(); + out_indices.select(0, 0).copy_(i_out); + out_indices.select(0, 1).copy_(j_out); + + auto result = _sparse_coo_tensor_unsafe( + out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype)); + + result = result.coalesce(); + + if (result.scalar_type() != mat1_.scalar_type()) { + auto cast_vals = result._values().to(mat1_.scalar_type()); + auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options()); + out._coalesced_(true); + return out; + } + return result; +} + REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel); } // namespace at::native \ No newline at end of file diff --git a/test/test_sparse.py b/test/test_sparse.py index 11e1629e374ba..76d7814137d60 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -3728,7 +3728,6 @@ def test_log_softmax_float(self, device, dtype): @coalescedonoff @dtypes(*floating_and_complex_types()) @dtypesIfMPS(*all_mps_types()) - @expectedFailureMPS @dtypesIfCUDA(*floating_types_and(*[torch.half] if SM53OrLater and not TEST_WITH_ROCM else [], *[torch.bfloat16] if SM80OrLater and not TEST_WITH_ROCM else [], torch.complex64, @@ -3825,9 +3824,9 @@ def fn(sparse_dims, nnz, shape_a, shape_b): def different_dtypes(): a, i_a, v_a = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced) b, i_b, v_b = self._gen_sparse(2, 10, [2, 2], dtype, device, coalesced) - r2 = torch.sparse.mm(a.to(torch.float64), a.to(torch.float32)) + r2 = torch.sparse.mm(a.to(torch.float32), a.to(torch.float16)) - self.assertRaisesRegex(RuntimeError, 'mat1 dtype Double does not match mat2 dtype Float', different_dtypes) + self.assertRaisesRegex(RuntimeError, 'mat1 dtype Float does not match mat2 dtype Half', different_dtypes) def test_backward_noncontiguous(): # Sparse.mm backward used to wrong with non-contiguous grads, From 3cfbf98ea9d937d23f3700168b22706c957308ce Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Fri, 7 Nov 2025 17:11:32 +0000 Subject: [PATCH 278/651] [xpu][feature] Add XPU support on torch.accelerator.get_memory_info (#162564) # Motivation 
Support XPU for `torch.accelerator.get_memory_info`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162564 Approved by: https://github.com/albanD ghstack dependencies: #156812 --- c10/xpu/XPUCachingAllocator.cpp | 40 +++++++++++++++++++++++---------- test/test_accelerator.py | 6 +++++ torch/csrc/xpu/Module.cpp | 19 ++-------------- torch/xpu/memory.py | 1 + 4 files changed, 37 insertions(+), 29 deletions(-) diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 8c0eb7e18dcd2..ba748449b29e3 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -926,15 +926,14 @@ class DeviceCachingAllocator { (release_cached_blocks() && alloc_block(params, true)); } if (!block_found) { - c10::xpu::DeviceProp device_prop; - c10::xpu::get_device_properties(&device_prop, device); - auto device_total = device_prop.global_mem_size; + const auto& raw_device = c10::xpu::get_raw_device(device); + const auto device_total = + raw_device.get_info(); // Estimate the available device memory when the SYCL runtime does not // support the corresponding aspect (ext_intel_free_memory). - size_t device_free = device_prop.global_mem_size - + size_t device_free = device_total - stats.reserved_bytes[static_cast(StatType::AGGREGATE)] .current; - auto& raw_device = c10::xpu::get_raw_device(device); // TODO: Remove the aspect check once the SYCL runtime bug is fixed on // affected devices. if (raw_device.has(sycl::aspect::ext_intel_free_memory)) { @@ -1052,21 +1051,37 @@ class DeviceCachingAllocator { } } + std::pair getMemoryInfo() { + const auto& device = c10::xpu::get_raw_device(device_index); + const size_t total = device.get_info(); + TORCH_CHECK( + device.has(sycl::aspect::ext_intel_free_memory), + "The device (", + device.get_info(), + ") doesn't support querying the available free memory. 
", + "You can file an issue at https://github.com/pytorch/pytorch/issues ", + "to help us prioritize its implementation."); + const size_t free = + device.get_info(); + return {free, total}; + } + double getMemoryFraction() { if (!set_fraction) { return 1.0; } - c10::xpu::DeviceProp device_prop; - c10::xpu::get_device_properties(&device_prop, device_index); + const auto device_total = + xpu::get_raw_device(device_index) + .get_info(); return static_cast(allowed_memory_maximum) / - static_cast(device_prop.global_mem_size); + static_cast(device_total); } void setMemoryFraction(double fraction) { - c10::xpu::DeviceProp device_prop; - c10::xpu::get_device_properties(&device_prop, device_index); - auto device_total = device_prop.global_mem_size; + const auto device_total = + xpu::get_raw_device(device_index) + .get_info(); allowed_memory_maximum = static_cast(fraction * device_total); set_fraction = true; } @@ -1241,7 +1256,8 @@ class XPUAllocator : public DeviceAllocator { } std::pair getMemoryInfo(DeviceIndex device) override { - TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented yet."); + assertValidDevice(device); + return device_allocators[device]->getMemoryInfo(); } double getMemoryFraction(DeviceIndex device) { diff --git a/test/test_accelerator.py b/test/test_accelerator.py index d44c8b0d350c9..7daebc01adfe9 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -239,6 +239,12 @@ def test_memory_stats(self): self.assertEqual(torch.accelerator.max_memory_allocated(), prev_max_allocated) self.assertEqual(torch.accelerator.max_memory_reserved(), prev_max_reserved) + @unittest.skipIf(TEST_MPS, "MPS doesn't support torch.accelerator memory API!") + def test_get_memory_info(self): + free_bytes, total_bytes = torch.accelerator.get_memory_info() + self.assertGreaterEqual(free_bytes, 0) + self.assertGreaterEqual(total_bytes, 0) + if __name__ == "__main__": run_tests() diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 44d11a5bd9741..b3d1dd929a216 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -386,23 +386,8 @@ static void bindGetDeviceProperties(PyObject* module) { static void initXpuMethodBindings(PyObject* module) { auto m = py::handle(module).cast(); m.def("_xpu_getMemoryInfo", [](c10::DeviceIndex device_index) { -#if SYCL_COMPILER_VERSION >= 20250000 - auto total = at::xpu::getDeviceProperties(device_index)->global_mem_size; - auto& device = c10::xpu::get_raw_device(device_index); - TORCH_CHECK( - device.has(sycl::aspect::ext_intel_free_memory), - "The device (", - at::xpu::getDeviceProperties(device_index)->name, - ") doesn't support querying the available free memory. ", - "You can file an issue at https://github.com/pytorch/pytorch/issues ", - "to help us prioritize its implementation."); - auto free = device.get_info(); - return std::make_tuple(free, total); -#else - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "torch.xpu.mem_get_info requires PyTorch to be built with SYCL compiler version 2025.0.0 or newer."); -#endif + py::gil_scoped_release no_gil; + return at::getDeviceAllocator(at::kXPU)->getMemoryInfo(device_index); }); m.def( "_xpu_getStreamFromExternal", diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py index 069d93cefa9b6..3a9c7d7c83ee4 100644 --- a/torch/xpu/memory.py +++ b/torch/xpu/memory.py @@ -190,6 +190,7 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]: int: the memory available on the device in units of bytes. 
int: the total memory on the device in units of bytes """ + _lazy_init() device = _get_device_index(device, optional=True) return torch._C._xpu_getMemoryInfo(device) From 52e744d68a6ea5319d0e8f7bc0e3a2f3cc599ae1 Mon Sep 17 00:00:00 2001 From: zpcore Date: Sun, 9 Nov 2025 23:17:16 -0800 Subject: [PATCH 279/651] [DTensor] Support convert StridedShard to shard order and vice versa (#166740) We plan to use `StridedShard` to express `shard_order`. This PR adds the function to support the conversion between `StridedShard` and `shard_order`. I moved some test related function into torch/testing/_internal/common_utils.py. We may only care about **_dtensor_spec.py** and **test_utils.py** in this PR for the review. ### How to convert shard order to StridedShard: Considering the example: - placements = $[x_0, x_1, x_2, x_3, x_4]$, all $x_?$ are shard on the same tensor dim. Let's see how the shard order will impact the split_factor (sf). We loop from right to left in the placements to construct the split_factor by assuming different shard order. Starting from $x_4$, this should be a normal shard. Then $x_3$. There are two possibilities, $x_3$'s order can be before $x_4$. If so, $x_3$'s sf=1, because $x_3$ is before $x_4$ in the placements. Else $x_3$'s order is after $x_4$, then the $x_3$'s sf should be the mesh dim size of $x_4$, which is $T(x_4)$: image We can use this method to decide on the split factor for $x_2$, $x_1$ and so on. ### How to convert StridedShard to shard order: This follows the same method above. We check all possible paths and use the real split_factor to see which path matchs the split_factor. If no such matches, the StridedShard is unable to be converted to shard order. --- Pull Request resolved: https://github.com/pytorch/pytorch/pull/166740 Approved by: https://github.com/ezyang --- test/distributed/tensor/test_redistribute.py | 170 +++------------- test/distributed/tensor/test_utils.py | 61 ++++++ torch/distributed/tensor/_dtensor_spec.py | 181 ++++++++++++++++++ .../distributed/_tensor/common_dtensor.py | 125 ++++++++++++ 4 files changed, 390 insertions(+), 147 deletions(-) diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py index 23593462f0a29..381660e47927d 100644 --- a/test/distributed/tensor/test_redistribute.py +++ b/test/distributed/tensor/test_redistribute.py @@ -2,7 +2,6 @@ # Owner(s): ["oncall: distributed"] import contextlib -import copy import itertools import unittest @@ -22,9 +21,8 @@ ) from torch.distributed.tensor._collective_utils import shard_dim_alltoall from torch.distributed.tensor._dtensor_spec import ShardOrderEntry -from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.debug import CommDebugMode -from torch.distributed.tensor.placement_types import _StridedShard +from torch.distributed.tensor.placement_types import _StridedShard, MaskPartial from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -35,7 +33,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import ( create_local_tensor_test_class, DTensorTestBase, + generate_shard_orders, + make_full_tensor, map_local_tensor_for_rank, + patched_distribute_tensor as _distribute_tensor, + redistribute, with_comms, ) from torch.utils._debug_mode import DebugMode @@ -785,88 +787,6 @@ def _extract_redistribute_trace_from_debug_mode(self, s: str) -> str: else: return "" - # TODO(zpcore): remove once the native redistribute supports 
shard_order arg - def redistribute( - self, - dtensor_input, - device_mesh, - placements, - shard_order, - use_graph_based_transform=True, - ): - """ - wrapper function to support shard_order for redistribution - This is a simpler version of Redistribute, only considers the forward. - """ - if placements is None: - placements = self._shard_order_to_placement(shard_order, device_mesh) - placements = tuple(placements) - old_spec = dtensor_input._spec - new_spec = copy.deepcopy(old_spec) - new_spec.placements = placements - if shard_order is not None: - new_spec.shard_order = shard_order - else: - new_spec.shard_order = () - if old_spec == new_spec: - return dtensor_input - dtensor_input = DTensor.from_local( - redistribute_local_tensor( - dtensor_input.to_local(), - old_spec, - new_spec, - use_graph_based_transform=use_graph_based_transform, - ), - device_mesh, - ) - dtensor_input._spec = copy.deepcopy(new_spec) - return dtensor_input # returns DTensor - - # TODO(zpcore): remove once the native distribute_tensor supports - # shard_order arg - def distribute_tensor( - self, - input_tensor, - device_mesh, - placements, - shard_order, - use_graph_based_transform=True, - ): - """wrapper function to support shard_order for tensor distribution""" - if placements is None: - placements = self._shard_order_to_placement(shard_order, device_mesh) - placements = tuple(placements) - tensor_dt = distribute_tensor(input_tensor, device_mesh, placements) - # fix the shard order - return self.redistribute( - tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform - ) - - # TODO(zpcore): remove once the native redistribute supports shard_order arg - def full_tensor(self, dtensor_input): - """wrapper function to support DTensor.full_tensor""" - return self.redistribute( - dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=() - ).to_local() - - def _shard_order_to_placement(self, shard_order, mesh): - """convert shard_order to placement with only Replicate() and Shard()""" - placements = [Replicate() for _ in range(mesh.ndim)] - if shard_order is not None: - for entry in shard_order: - tensor_dim = entry.tensor_dim - mesh_dims = entry.mesh_dims - for mesh_dim in mesh_dims: - placements[mesh_dim] = Shard(tensor_dim) - return tuple(placements) - - def _convert_shard_order_dict_to_ShardOrder(self, shard_order): - """Convert shard_order dict to ShardOrder""" - return tuple( - ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims)) - for tensor_dim, mesh_dims in shard_order.items() - ) - @with_comms def test_ordered_redistribute(self): """Test ordered redistribution with various sharding syntaxes""" @@ -927,13 +847,11 @@ def test_ordered_redistribute(self): for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate( sharding_src_dst_pairs_with_expected_trace ): - sharded_dt = self.distribute_tensor( + sharded_dt = _distribute_tensor( input_data.clone(), mesh, src_placement, shard_order=src_order ) with DebugMode(record_torchfunction=False) as debug_mode: - sharded_dt = self.redistribute( - sharded_dt, mesh, dst_placement, dst_order - ) + sharded_dt = redistribute(sharded_dt, mesh, dst_placement, dst_order) trace_str = self._extract_redistribute_trace_from_debug_mode( debug_mode.debug_string() ) @@ -957,49 +875,11 @@ def test_ordered_redistribute(self): trace_str, """S(0)[0]S(0)[1]R->S(0)S(1)R->RS(1)R->RS(1)S(0)""", ) - expected_dt = self.distribute_tensor( + expected_dt = _distribute_tensor( input_data.clone(), mesh, dst_placement, 
shard_order=dst_order ) self.assertEqual(sharded_dt.to_local(), expected_dt.to_local()) - def generate_shard_orders(self, mesh, tensor_rank): - # Generate all possible sharding placement of tensor with rank - # `tensor_rank` over mesh. - def _split_list(lst: list, N: int): - def compositions(n, k): - if k == 1: - yield [n] - else: - for i in range(1, n - k + 2): - for tail in compositions(n - i, k - 1): - yield [i] + tail - - length = len(lst) - for comp in compositions(length, N): - result = [] - start = 0 - for size in comp: - result.append(lst[start : start + size]) - start += size - yield result - - all_mesh = list(range(mesh.ndim)) - all_device_order = list(itertools.permutations(all_mesh)) - for device_order in all_device_order: - # split on device orders, and assign each device order segment to a tensor dim - for num_split in range(1, mesh.ndim + 1): - for splitted_list in _split_list(list(range(mesh.ndim)), num_split): - for tensor_dims in itertools.combinations( - range(tensor_rank), len(splitted_list) - ): - shard_order = {} - assert len(tensor_dims) == len(splitted_list) - for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list): - shard_order[tensor_dim] = device_order[ - mesh_dims[0] : mesh_dims[-1] + 1 - ] - yield self._convert_shard_order_dict_to_ShardOrder(shard_order) - @with_comms def test_generate_shard_orders(self): """Check if `generate_shard_orders` generates unique sharding combinations""" @@ -1012,7 +892,7 @@ def test_generate_shard_orders(self): ] for test_input in test_inputs: all_combinations = [] - for shard_order in self.generate_shard_orders( + for shard_order in generate_shard_orders( test_input["mesh"], test_input["tensor_rank"] ): all_combinations.append(shard_order) # noqa: PERF402 @@ -1062,12 +942,12 @@ def test_ordered_distribute_all_combination(self): input_data = torch.randn(tensor_shape, device=self.device_type) tensor_rank = input_data.ndim with maybe_disable_local_tensor_mode(): - shard_orders = self.generate_shard_orders(mesh, tensor_rank) + shard_orders = generate_shard_orders(mesh, tensor_rank) for shard_order in shard_orders: - sharded_dt = self.distribute_tensor( + sharded_dt = _distribute_tensor( input_data.clone(), mesh, placements=None, shard_order=shard_order ) - self.assertEqual(self.full_tensor(sharded_dt), input_data) + self.assertEqual(make_full_tensor(sharded_dt), input_data) # 2. Verify the correctness of redistribution from DTensor to DTensor. 
# This test repeatedly redistributes a DTensor to various ordered @@ -1078,20 +958,20 @@ def test_ordered_distribute_all_combination(self): tensor_rank = input_data.ndim prev_sharded_dt = None with maybe_disable_local_tensor_mode(): - shard_orders = self.generate_shard_orders(mesh, tensor_rank) + shard_orders = generate_shard_orders(mesh, tensor_rank) for shard_order in shard_orders: if prev_sharded_dt is None: - prev_sharded_dt = self.distribute_tensor( + prev_sharded_dt = _distribute_tensor( input_data.clone(), mesh, placements=None, shard_order=shard_order, ) else: - sharded_dt = self.redistribute( + sharded_dt = redistribute( prev_sharded_dt, mesh, placements=None, shard_order=shard_order ) - self.assertEqual(self.full_tensor(sharded_dt), input_data) + self.assertEqual(make_full_tensor(sharded_dt), input_data) prev_sharded_dt = sharded_dt @with_comms @@ -1136,13 +1016,13 @@ def _is_valid_placement(placements, tensor_rank): local_tensor = torch.randn(shape, device=self.device_type) full_tensor = DTensor.from_local(local_tensor, mesh, placements) with maybe_disable_local_tensor_mode(): - shard_orders = self.generate_shard_orders(mesh, len(shape)) + shard_orders = generate_shard_orders(mesh, len(shape)) for shard_order in shard_orders: - sharded_dt = self.redistribute( + sharded_dt = redistribute( full_tensor, mesh, placements=None, shard_order=shard_order ) self.assertEqual( - self.full_tensor(sharded_dt), self.full_tensor(full_tensor) + make_full_tensor(sharded_dt), make_full_tensor(full_tensor) ) @unittest.skip( @@ -1152,24 +1032,20 @@ def _is_valid_placement(placements, tensor_rank): @with_comms def test_ordered_redistribute_for_special_placement(self): """Test ordered redistribution with special placement""" - from torch.distributed.tensor._ops._embedding_ops import _MaskPartial - torch.manual_seed(21) mesh = init_device_mesh(self.device_type, (8,)) input_data = torch.randn((8, 8), device=self.device_type) src_placement = [Shard(1)] tgt_placement = [ - (_MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),) + (MaskPartial(offset_shape=torch.Size([10, 20]), offset_dim=0),) ] - sharded_dt = self.distribute_tensor( + sharded_dt = _distribute_tensor( input_data.clone(), mesh, src_placement, shard_order=(ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)),), ) - sharded_dt = self.redistribute( - sharded_dt, mesh, tgt_placement, shard_order=None - ) + sharded_dt = redistribute(sharded_dt, mesh, tgt_placement, shard_order=None) @with_comms def test_shard_order_same_data_as_strided_shard(self): @@ -1179,7 +1055,7 @@ def test_shard_order_same_data_as_strided_shard(self): strided_placement = [_StridedShard(-2, split_factor=2), Shard(-2)] x_strided_dt = distribute_tensor(x, device_mesh, strided_placement) # specify right-to-left order use ordered shard - x_ordered_dt = self.distribute_tensor( + x_ordered_dt = _distribute_tensor( x, device_mesh, placements=[Shard(0), Shard(0)], diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index 09a6ca817a75b..d129765506feb 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -34,6 +34,10 @@ from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, + generate_shard_orders, + LocalDTensorTestBase, + patched_distribute_tensor as _distribute_tensor, + shard_order_to_placement, with_comms, ) @@ -774,6 +778,63 @@ def test_2d_mesh_uneven_strided_shard(self): 
self.assertEqual(dtensor.full_tensor(), tensor) +class Test_StridedShard_with_shard_order(LocalDTensorTestBase): + @property + def world_size(self) -> int: + return 32 + + @with_comms + def test_StridedShard_to_shard_order(self): + with LocalTensorMode(ranks=self.world_size): + mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(2, 2, 2, 2, 2)) + shard_iter = generate_shard_orders(mesh, 3) + # It takes ~4.8h to complete total 2520 shard order combinations here + # using LocalTensor. So we only randomly pick 25 shard orders to test. + all_shard_order = list(shard_iter) + import random + + random.seed(42) + shard_order_choices = random.sample( + all_shard_order, min(25, len(all_shard_order)) + ) + + x = torch.randn(32, 32, 32) + for shard_order in shard_order_choices: + a = _distribute_tensor(x, mesh, None, shard_order) + + placement_without_stridedshard = shard_order_to_placement( + shard_order, mesh + ) + placements_with_stridedshard = ( + DTensorSpec._convert_shard_order_to_StridedShard( + shard_order, placement_without_stridedshard, mesh + ) + ) + b = distribute_tensor(x, mesh, placements_with_stridedshard) + shard_order_from_stridedshard = ( + DTensorSpec._maybe_convert_StridedShard_to_shard_order( + placements_with_stridedshard, mesh + ) + ) + self.assertEqual(shard_order, shard_order_from_stridedshard) + self.assertEqual(a.to_local(), b.to_local()) + + @with_comms + def test_StridedShard_not_convertible_to_shard_order(self): + with LocalTensorMode(ranks=self.world_size): + mesh = DeviceMesh("cpu", torch.arange(self.world_size).view(4, 8)) + unconvertible_placements_list = [ + [_StridedShard(0, split_factor=2), _StridedShard(1, split_factor=2)], + [_StridedShard(0, split_factor=2), Shard(1)], + [_StridedShard(1, split_factor=16), Shard(1)], + ] + for placements in unconvertible_placements_list: + shard_order = DTensorSpec._maybe_convert_StridedShard_to_shard_order( + tuple(placements), mesh + ) + self.assertIsNone(shard_order) + + class Test2DStridedLocalShard(DTensorTestBase): @property def world_size(self): diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index 5e7d7b3c842d2..ca51cdf70c058 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -1,4 +1,5 @@ import itertools +import math from collections import defaultdict from dataclasses import dataclass from typing import Any, cast, NamedTuple, Optional @@ -7,6 +8,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import ( _StridedShard, + MaskPartial, Partial, Placement, Replicate, @@ -127,6 +129,185 @@ def compute_default_shard_order( ) return default_shard_order + @staticmethod + def _convert_shard_order_to_StridedShard( + shard_order: ShardOrder, placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> tuple[Placement, ...]: + """ + Convert ShardOrder to placements with _StridedShard. + + This function converts a ShardOrder specification into a tuple of Placement objects, + using _StridedShard when a tensor dimension is sharded across multiple mesh dimensions + in a non-default order. The split_factor of each _StridedShard is determined by the + product of mesh dimension sizes that appear earlier in the shard order but later in + the placement tuple. + + Args: + shard_order: ShardOrder specification indicating which tensor dimensions are + sharded on which mesh dimensions and in what execution order. 
+ placements: Tuple of Placement objects that does not contain _StridedShard. + mesh: DeviceMesh containing the size information for each mesh dimension. + + Returns: + Updated tuple of Placement objects with Shard or _StridedShard placements. + + Algorithm: + For each ShardOrderEntry in shard_order: + - For each mesh dimension in the entry's mesh_dims (in order): + - Calculate split_factor as the product of mesh sizes for all mesh dimensions + that appear: + 1. Earlier in the shard order (lower index in mesh_dims), and + 2. Later in the placement tuple (higher mesh dimension index) + - If split_factor == 1: use normal Shard + - Otherwise: use _StridedShard with the calculated split_factor + + Example: + >>> # xdoctest: +SKIP("Requires DeviceMesh") + >>> # Tensor dimension 0 sharded on mesh dims [2, 0, 1] in that order + >>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2 + >>> shard_order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),) + >>> placements = (Shard(0), Shard(0), Shard(0)) + >>> # For mesh_dim=2 (index 0 in mesh_dims): no earlier dims, split_factor=1 + >>> # -> placements[2] = Shard(0) + >>> # For mesh_dim=0 (index 1 in mesh_dims): mesh_dim=2 is earlier and has index 2>0 + >>> # -> split_factor = mesh.size(2) = 2 + >>> # -> placements[0] = _StridedShard(0, split_factor=2) + >>> # For mesh_dim=1 (index 2 in mesh_dims): mesh_dim=2 is earlier and has index 2>1 + >>> # -> split_factor = mesh.size(2) = 2 + >>> # -> placements[1] = _StridedShard(0, split_factor=2) + >>> # Result: (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0)) + """ + placements_list = list(placements) + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + for idx in range(len(mesh_dims)): + # TODO(zpcore): split_factor from `view` and `shard order` + # should be able to be multiplied into one. Need to loosen the + # condition here. + mesh_dim = mesh_dims[idx] + if type(placements[mesh_dim]) is not Shard: + raise ValueError( + f"Only Shard placement can be converted to _StridedShard, " + f"found {placements[mesh_dim]} in {placements=}." + ) + split_factor = math.prod( + mesh.size(i) for i in mesh_dims[:idx] if i > mesh_dim + ) + if split_factor == 1: + # use normal Shard + placements_list[mesh_dim] = Shard(tensor_dim) + else: + placements_list[mesh_dim] = _StridedShard( + tensor_dim, split_factor=split_factor + ) + return tuple(placements_list) + + @staticmethod + def _maybe_convert_StridedShard_to_shard_order( + placements: tuple[Placement, ...], mesh: DeviceMesh + ) -> Optional[ShardOrder]: + """ + Try to convert _StridedShard placements to ShardOrder. + + This is the inverse of `_convert_shard_order_to_StridedShard`. It reconstructs the shard + order by examining the split_factor of each _StridedShard and determining its position + in the execution order. If the _StridedShard configuration cannot be represented as a + valid ShardOrder (i.e., there's no shard order that produces the observed split_factors), + this function returns None. + + Args: + placements: Tuple of Placement objects that may contain _StridedShard. + mesh: DeviceMesh containing the size information for each mesh dimension. + + Returns: + ShardOrder if conversion is possible, None otherwise. For placements without + _StridedShard, returns the default shard order. + + Algorithm: + 1. If no _StridedShard in placements, return default shard order + 2. Create an empty list for each tensor dimension to represent mesh dim ordering + 3. 
Iterate through placements in reverse order (right to left): + - For each Shard/_StridedShard on a tensor dimension: + - Extract its split_factor (1 for Shard, split_factor for _StridedShard) + - Find the position in mesh_dims_order where accumulated_sf equals split_factor + - accumulated_sf is the product of mesh sizes of mesh dimensions that appear + earlier in mesh_dims_order (lower indices) + - Insert mesh_dim at the found position + 4. If no valid position found for any split_factor, return None (unable to convert) + 5. Construct ShardOrderEntry for each tensor dimension from mesh_dims_order + + Example: + >>> # xdoctest: +SKIP("Requires DeviceMesh") + >>> # mesh = DeviceMesh([4, 3, 2]) # sizes: mesh[0]=4, mesh[1]=3, mesh[2]=2 + >>> # placements = (_StridedShard(0, sf=2), _StridedShard(0, sf=2), Shard(0)) + >>> # Process tensor_dim=0 from right to left: + >>> # - mesh_dim=2: Shard(0) with sf=1 + >>> # Try position 0: accumulated_sf=1, matches! Insert at position 0 + >>> # Current mesh_dims_order order: [2] + >>> # - mesh_dim=1: _StridedShard(0, sf=2) with sf=2 + >>> # Try position 0: accumulated_sf=1, no match + >>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1 + >>> # Current mesh_dims_order order: [2, 1] + >>> # - mesh_dim=0: _StridedShard(0, sf=2) with sf=2 + >>> # Try position 0: accumulated_sf=1, no match + >>> # Try position 1: accumulated_sf=1*mesh.size(2)=2, matches! Insert at position 1 + >>> # Final mesh_dims_order order: [2, 0, 1] + >>> # Result: ShardOrder((ShardOrderEntry(tensor_dim=0, mesh_dims=(2, 0, 1)),)) + >>> # This means: first shard on mesh_dim=2, then mesh_dim=0, then mesh_dim=1 + + Note: + This function validates that _StridedShard can be represented as a ShardOrder. + Not all _StridedShard configurations are valid - the split_factor must match + the product of mesh sizes in some execution order. + """ + if not any(isinstance(p, _StridedShard) for p in placements): + return DTensorSpec.compute_default_shard_order(placements) + max_tensor_dim = ( + max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1 + ) + shard_order = [] + + tensor_dim_to_mesh_dims_order: list[list[int]] = [ + [] for i in range(max_tensor_dim) + ] + for mesh_dim in reversed(range(len(placements))): + cur_placement = placements[mesh_dim] + # _StridedShard may not be a subclass of Shard in the future, so write in this way: + if isinstance(cur_placement, Shard | _StridedShard): + tensor_dim = cur_placement.dim + mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim] + cur_sf = 1 + if isinstance(cur_placement, _StridedShard): + cur_sf = cur_placement.split_factor + accumulated_sf = 1 + find_order = False + for i in range(len(mesh_dims_order) + 1): + if accumulated_sf == cur_sf: + mesh_dims_order.insert(i, mesh_dim) + find_order = True + break + if i < len(mesh_dims_order): + accumulated_sf *= mesh.size(mesh_dims_order[i]) + if not find_order: + # _StridedShard is not convertible to ShardOrder + return None + else: + if not isinstance(cur_placement, Replicate | Partial | MaskPartial): + raise ValueError( + f"Unsupported placement type {type(cur_placement)} encountered in " + f"{placements}; expected Replicate, Partial, or MaskPartial." 
+ ) + for tensor_dim in range(max_tensor_dim): + if len(tensor_dim_to_mesh_dims_order[tensor_dim]) > 0: + shard_order.append( + ShardOrderEntry( + tensor_dim=tensor_dim, + mesh_dims=tuple(tensor_dim_to_mesh_dims_order[tensor_dim]), + ) + ) + return tuple(shard_order) + def _verify_shard_order(self, shard_order: ShardOrder) -> None: """Verify that the shard_order is valid and matches the placements.""" total_shard = 0 diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index f4afca4bd1803..6ce7d4b2ca507 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -3,6 +3,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import contextlib +import copy import functools import itertools import sys @@ -32,6 +33,8 @@ Replicate, Shard, ) +from torch.distributed.tensor._dtensor_spec import ShardOrderEntry +from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -818,3 +821,125 @@ def map_local_for_rank(rank, func): def reduce_local_int(val, func): return func(val.node._local_ints) + + +def _convert_shard_order_dict_to_ShardOrder(shard_order): + """Convert shard_order dict to ShardOrder""" + return tuple( + ShardOrderEntry(tensor_dim=tensor_dim, mesh_dims=tuple(mesh_dims)) + for tensor_dim, mesh_dims in shard_order.items() + ) + + +# TODO(zpcore): remove once the native redistribute supports shard_order arg +def redistribute( + dtensor_input, + device_mesh, + placements, + shard_order, + use_graph_based_transform=True, +): + """ + wrapper function to support shard_order for redistribution + This is a simpler version of Redistribute, only considers the forward. 
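+
+    A minimal usage sketch (illustrative only; assumes an existing DTensor
+    ``dt`` on a 2-D ``mesh``):
+
+        # shard tensor dim 0 on mesh dim 1 first, then on mesh dim 0
+        order = (ShardOrderEntry(tensor_dim=0, mesh_dims=(1, 0)),)
+        sharded = redistribute(dt, mesh, placements=None, shard_order=order)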
+ """ + if placements is None: + placements = shard_order_to_placement(shard_order, device_mesh) + placements = tuple(placements) + old_spec = dtensor_input._spec + new_spec = copy.deepcopy(old_spec) + new_spec.placements = placements + if shard_order is not None: + new_spec.shard_order = shard_order + else: + new_spec.shard_order = () + if old_spec == new_spec: + return dtensor_input + dtensor_input = DTensor.from_local( + redistribute_local_tensor( + dtensor_input.to_local(), + old_spec, + new_spec, + use_graph_based_transform=use_graph_based_transform, + ), + device_mesh, + ) + dtensor_input._spec = copy.deepcopy(new_spec) + return dtensor_input # returns DTensor + + +# TODO(zpcore): remove once the native distribute_tensor supports +# shard_order arg +def patched_distribute_tensor( + input_tensor, + device_mesh, + placements, + shard_order, + use_graph_based_transform=True, +): + """wrapper function to support shard_order for tensor distribution""" + if placements is None: + placements = shard_order_to_placement(shard_order, device_mesh) + placements = tuple(placements) + tensor_dt = distribute_tensor(input_tensor, device_mesh, placements) + # fix the shard order + return redistribute( + tensor_dt, device_mesh, placements, shard_order, use_graph_based_transform + ) + + +# TODO(zpcore): remove once the native redistribute supports shard_order arg +def make_full_tensor(dtensor_input): + """wrapper function to support DTensor.full_tensor""" + return redistribute( + dtensor_input, dtensor_input.device_mesh, placements=None, shard_order=() + ).to_local() + + +def shard_order_to_placement(shard_order, mesh): + """convert shard_order to placement with only Replicate() and Shard()""" + placements: list[Any] = [Replicate() for _ in range(mesh.ndim)] + if shard_order is not None: + for entry in shard_order: + tensor_dim = entry.tensor_dim + mesh_dims = entry.mesh_dims + for mesh_dim in mesh_dims: + placements[mesh_dim] = Shard(tensor_dim) + return tuple(placements) + + +def generate_shard_orders(mesh, tensor_rank): + # Generate all possible sharding placement of tensor with rank + # `tensor_rank` over mesh. 
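+    # A small sample of what this yields (not the full enumeration): for a
+    # 2-D mesh and a rank-2 tensor, the generator produces orders such as
+    #   (ShardOrderEntry(tensor_dim=0, mesh_dims=(0, 1)),)
+    # and
+    #   (ShardOrderEntry(tensor_dim=0, mesh_dims=(1,)),
+    #    ShardOrderEntry(tensor_dim=1, mesh_dims=(0,)))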
+ def _split_list(lst: list, N: int): + def compositions(n: int, k: int): + # yields lists of length k, positive ints summing to n + for cuts in itertools.combinations(range(1, n), k - 1): + # add 0 and n as sentinels, then take consecutive differences + yield [b - a for a, b in itertools.pairwise((0, *cuts, n))] + + length = len(lst) + for comp in compositions(length, N): + result = [] + start = 0 + for size in comp: + result.append(lst[start : start + size]) + start += size + yield result + + all_mesh = list(range(mesh.ndim)) + all_device_order = list(itertools.permutations(all_mesh)) + for device_order in all_device_order: + # split on device orders, and assign each device order segment to a tensor dim + for num_split in range(1, mesh.ndim + 1): + for splitted_list in _split_list(list(range(mesh.ndim)), num_split): + for tensor_dims in itertools.combinations( + range(tensor_rank), len(splitted_list) + ): + shard_order = {} + assert len(tensor_dims) == len(splitted_list) + for tensor_dim, mesh_dims in zip(tensor_dims, splitted_list): + shard_order[tensor_dim] = device_order[ + mesh_dims[0] : mesh_dims[-1] + 1 + ] + yield _convert_shard_order_dict_to_ShardOrder(shard_order) From 74aec83841385e244dd3de5a9fad875aec64f003 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 10 Nov 2025 12:02:58 +0000 Subject: [PATCH 280/651] [xla hash update] update the pinned xla hash (#167452) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned xla hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167452 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/xla.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 01f0673fcf802..191c21631f662 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9 +e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a From c28475db7c6d1de5384309f02c508099c3389d2b Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 10 Nov 2025 12:39:23 +0000 Subject: [PATCH 281/651] Update slow tests (#166844) This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml). Update the list of slow tests. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166844 Approved by: https://github.com/pytorchbot --- test/slow_tests.json | 477 +++++++++++++++++++++---------------------- 1 file changed, 234 insertions(+), 243 deletions(-) diff --git a/test/slow_tests.json b/test/slow_tests.json index fe23e854cc8e8..c027d3d1d0901 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,245 +1,236 @@ { - "EndToEndLSTM (__main__.RNNTest)": 207.89400227864584, - "MultiheadAttention (__main__.ModulesTest)": 141.1396687825521, - "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 214.02366638183594, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 77.26125049591064, - "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.37000020345052, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 69.25722334120009, - "test_after_aot_gpu_runtime_error (__main__.MinifierIsolateTests)": 65.84466807047527, - "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 178.41399637858072, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.55014337812151, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 122.18047623407273, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 192.6405719575428, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.27904801141648, - "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.906999588012695, - "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.244998931884766, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 150.04100036621094, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 191.85050201416016, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 111.9276631673177, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.31450271606445, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 125.24066416422527, - "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.47783279418945, - "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.46250025431316, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1031.0534973144531, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 239.67400105794272, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 495.0447726779514, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 490.18524169921875, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 144.06477737426758, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 342.20416259765625, - "test_avg_pool3d_backward_cpu_halide 
(__main__.HalideCpuTests)": 62.01366678873698, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.07200050354004, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 73.9221674601237, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.0122528076172, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 144.97249857584634, - "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 303.20537185668945, - "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 386.0518798828125, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 291.2442270914714, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 95.87866719563802, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 98.38716634114583, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 69.08016649881999, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.88233311971028, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 104.17599995930989, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 97.41800308227539, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 474.6719970703125, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 440.4375, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 293.3983332316081, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 238.7328338623047, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1218.4906717936199, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.73516782124837, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1156.0123494466145, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.13916714986165, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.90450032552083, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.42100016276042, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.98883310953777, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.34433364868164, - "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.38016573588053, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.52783330281575, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 111.06333287556966, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.19833374023438, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 113.10083134969075, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.23766644795736, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 70.18666712443034, - "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.61399841308594, - "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 67.7816670735677, - 
"test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 121.6183344523112, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.30266698201497, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 130.8143310546875, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 127.27633412679036, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 303.55183664957684, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 234.41216532389322, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 85.3436673482259, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.9688326517741, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.55149968465169, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.37966791788737, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 129.88233184814453, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 129.4015007019043, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1282.3826497395833, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1270.64599609375, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1297.9046630859375, - "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 545.2034962972006, - "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 572.5616760253906, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 64.40316645304362, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 64.68383344014485, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.48333422342936, - "test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 61.959999084472656, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 105.79100036621094, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 122.34666570027669, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 68.7205015818278, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.2183329264323, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.86883227030437, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 77.48183314005534, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 79.1564998626709, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 160.41250228881836, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 79.10633341471355, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 60.106833140055336, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 221.3586196899414, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 504.3203754425049, - "test_conv2d_binary_broadcast_shapes_cpu 
(__main__.TestPatternMatcherGenericCPU)": 78.03233337402344, - "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 152.302001953125, - "test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 152.99433390299478, - "test_conv_bn_fuse_cpu (__main__.CpuTests)": 96.25399971008301, - "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 75.70275068283081, - "test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 139.14399747674665, - "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 72.7847490310669, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 91.59966786702473, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 87.57833353678386, - "test_count_nonzero_all (__main__.TestBool)": 664.9986343383789, - "test_cp_flex_attention_document_mask (__main__.CPFlexAttentionTest)": 78.31500244140625, - "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 385.24249792099, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.70466740926106, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 685.0679931640625, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 86.26266733805339, - "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 292.93699645996094, - "test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.84199905395508, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 69.56212568283081, - "test_fail_creation_ops.py (__main__.TestTyping)": 69.80560022989908, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.36666552225749, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 90.40366744995117, - "test_fuse_large_params_cpu (__main__.CpuTests)": 132.73199844360352, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 150.16662406921387, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 159.28499794006348, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 165.19283294677734, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 151.12366739908853, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.61699930826823, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.00600179036458, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.3759994506836, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.89249674479166, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 149.6598358154297, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 146.07766723632812, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 532.8139902750651, - "test_graph_partition_refcount_cuda (__main__.GPUTests)": 69.78400001525878, - "test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 267.04988850487604, - "test_graph_partition_refcount_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 273.54955800374347, - "test_grid_sampler_2d_cpu_halide 
(__main__.HalideCpuTests)": 195.84733072916666, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 326.0143330891927, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.96037435531616, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 77.44933319091797, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 126.81488884819879, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 118.70199839274089, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 129.20266723632812, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 97.18800099690755, - "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 130.3183339436849, - "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 140.43233235677084, - "test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 293.122774971856, - "test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 63.835832277933754, - "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 106.77049922943115, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 100.89649963378906, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 140.07424926757812, - "test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 72.90299733479817, - "test_max_autotune_addmm_search_space_EXHAUSTIVE_dynamic_True (__main__.TestMaxAutotuneSubproc)": 82.62433369954427, - "test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 87.51499938964844, - "test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_True_use_aoti_True (__main__.TestCKBackend)": 71.22416591644287, - "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 424.50966389973956, - "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.14600626627603, - "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 358.88099161783856, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.58866712782118, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.68674945831299, - "test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 65.85794713936355, - "test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 103.6923344930013, - "test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 187.6953328450521, - "test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 370.27442932128906, - "test_proper_exit (__main__.TestDataLoader)": 227.83111148410373, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 227.1901126437717, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.52099990844727, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.50249862670898, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 92.52400207519531, - 
"test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.75499725341797, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 107.40500259399414, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 83.80450057983398, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.46599833170573, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.65650177001953, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 83.4114990234375, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.47100067138672, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 108.55533345540364, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 89.23666381835938, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 105.13900375366211, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 100.14550018310547, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 107.33649826049805, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.08150100708008, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 97.59600067138672, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.82933553059895, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 114.43099721272786, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.40333302815755, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 567.2765197753906, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1032.5083312988281, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 852.7170003255209, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1361.954854329427, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 77.385498046875, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 265.0193354288737, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 115.31749725341797, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 245.27666727701822, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 71.75300216674805, - 
"test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 141.8895009358724, - "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 71.15749994913737, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 90.59066772460938, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.73916625976562, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.65066655476888, - "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 99.21799850463867, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 90.86299896240234, - "test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.57050196329753, - "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 69.65149958928426, - "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 78.13350168863933, - "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 76.85255601671007, - "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 333.04866282145184, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 146.96599833170572, - "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 160.4881100124783, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.10055626763238, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 117.38410907321506, - "test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 710.2327779134115, - "test_sort_stable_cpu (__main__.CpuTritonTests)": 1324.4399820963542, - "test_sort_stable_cuda (__main__.GPUTests)": 76.83109970092774, - "test_split_cumsum_cpu (__main__.CpuTritonTests)": 88.58433532714844, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 160.1271684964498, - "test_tensor_split (__main__.TestVmapOperators)": 79.18955569393519, - "test_terminate_handler_on_crash (__main__.TestTorch)": 111.30388899644215, - "test_terminate_signal (__main__.ForkTest)": 132.3458870516883, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 132.2043343567186, - "test_terminate_signal (__main__.SpawnTest)": 136.1005539894104, - "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.20899939537048, - "test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 63.82099969046457, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 61.925000508626304, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 60.89849980672201, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 66.88233375549316, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.9854990641276, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 144.4044977823893, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.19166437784831, - "test_unary_ops (__main__.TestTEFuserDynamic)": 96.32655514611139, - "test_unary_ops (__main__.TestTEFuserStatic)": 105.33362591266632, - "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.8336664835612, - 
"test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 82.86566925048828, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 68.26500002543132, - "test_views1_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 97.1120007832845, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 88.24766794840495, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 65.41266759236653, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 74.75533294677734, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 73.52500089009602, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.85466639200847, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 98.39650090535481, - "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 61.39695285615467, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.88249842325847, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 73.0695006052653, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.86250114440918, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 98.63116455078125, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 94.85683314005534, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 173.00183614095053 + "EndToEndLSTM (__main__.RNNTest)": 190.48799641927084, + "MultiheadAttention (__main__.ModulesTest)": 141.2663370768229, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 82.87333234151204, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 70.6538565499442, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.34033711751302, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 171.25450134277344, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 119.71899922688802, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.35733322870163, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.64533233642578, + "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.672952016194664, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 138.04000091552734, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 172.1344985961914, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 114.02050018310547, + "test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.25642830984933, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.3350003560384, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 
(__main__.TestEagerFusionModuleInfoCPU)": 120.95249938964844, + "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 86.97774887084961, + "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 100.90774917602539, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 1144.3935089111328, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 222.58500061035156, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.10033162434894, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 517.1875050862631, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 113.88125228881836, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 235.77350616455078, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 74.6155014038086, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.63325119018555, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 216.2968317667643, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 153.0915012359619, + "test_cat_2k_args (__main__.TestTEFuserDynamic)": 108.80471753561869, + "test_cat_2k_args (__main__.TestTEFuserStatic)": 102.20949847949669, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 311.7026621500651, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 395.0001729329427, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 348.6218566894531, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 98.71574974060059, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.68499946594238, + "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 65.0557508468628, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 65.86899948120117, + "test_comprehensive_gradient_cuda_complex64 (__main__.TestDecompCUDA)": 97.15880012512207, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 103.20700073242188, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 102.74033610026042, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 460.4286702473958, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 435.62066650390625, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 287.3090057373047, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 265.1860008239746, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1235.7365112304688, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 68.20825004577637, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1281.2615051269531, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.90750026702881, + "test_comprehensive_linalg_householder_product_cuda_complex64 (__main__.TestDecompCUDA)": 79.04633331298828, + "test_comprehensive_linalg_lu_factor_ex_cuda_complex128 (__main__.TestDecompCUDA)": 68.10879821777344, + 
"test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.43025207519531, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.94575023651123, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.93649864196777, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 72.46275043487549, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 64.10650062561035, + "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 67.03124904632568, + "test_comprehensive_linalg_svd_cuda_float64 (__main__.TestDecompCUDA)": 64.32800025939942, + "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.41353665865384, + "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 100.17661388103778, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 110.95025062561035, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 108.06550025939941, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 104.24150085449219, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.453749656677246, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 61.739999771118164, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 69.96549987792969, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 113.65749931335449, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 106.57500076293945, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.54049682617188, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 116.19766489664714, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 272.48475646972656, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 248.12175369262695, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 79.66900062561035, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 81.52649879455566, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 79.29400062561035, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.40349960327148, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 128.42924880981445, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 125.03675079345703, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1264.9732360839844, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1250.7332458496094, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1255.0684814453125, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 574.4627532958984, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 581.7282485961914, + 
"test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 65.052001953125, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 61.19200134277344, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 63.16874885559082, + "test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 62.39250183105469, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 113.32574844360352, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 113.91499900817871, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 74.42549800872803, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 76.1560001373291, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 66.76750087738037, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 70.69724941253662, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.87625026702881, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 80.2542495727539, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 69.0419979095459, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 117.03342655726841, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 289.50213841029574, + "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 67.38800048828125, + "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 145.27399444580078, + "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 66.9245999654134, + "test_conv3d_cuda (__main__.AOTInductorTestABICompatibleGpu)": 151.91099548339844, + "test_conv_bn_fuse_cpu (__main__.CpuTests)": 92.79549789428711, + "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.60149955749512, + "test_conv_transpose_with_output_size_and_no_batch_dim_ConvTranspose3d_cuda (__main__.TestConvolutionNNDeviceTypeCUDA)": 69.27724676392972, + "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 76.24971498761859, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 81.93449974060059, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 78.87700080871582, + "test_count_nonzero_all (__main__.TestBool)": 631.2585144042969, + "test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.042999267578125, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 84.49850082397461, + "test_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 93.03299713134766, + "test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 228.46711820714614, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 286.29998779296875, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 68.43842806134906, + "test_fail_random.py (__main__.TestTyping)": 74.83523060725285, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 72.84900093078613, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 75.86675071716309, + "test_fuse_large_params_cpu (__main__.CpuTests)": 151.4199981689453, + "test_fuse_large_params_cuda 
(__main__.GPUTests)": 60.351999282836914, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 158.3622828892299, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 149.6796646118164, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 139.97800064086914, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 114.8385009765625, + "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.69736822027909, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 84.62700080871582, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.197998046875, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 96.46900177001953, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 187.83824920654297, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 110.49449920654297, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 124.90424919128418, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 518.4157485961914, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 304.6440022786458, + "test_inductor_dynamic_shapes_broadcasting_dynamic_shapes (__main__.DynamicShapesReproTests)": 143.82052836698645, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 77.4985705784389, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 76.06225109100342, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 138.9222858973912, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 120.62233225504558, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 148.1219940185547, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 109.34200286865234, + "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 119.36233266194661, + "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 127.95700073242188, + "test_list_clearing_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 61.64850175380707, + "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 105.3174296787807, + "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 585.9210001627604, + "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 504.3250020345052, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 86.21566645304362, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 129.277715410505, + "test_max_autotune_addmm_max_autotune_gemm_backends_CK_x_shape2 (__main__.TestCKBackend)": 64.24800109863281, + "test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_False (__main__.TestCKBackend)": 77.23899841308594, + "test_max_autotune_precompile_matmul_max_autotune_gemm_backends_CKTILE_autotune_in_subproc_False_use_aoti_True (__main__.TestCKBackend)": 65.15649795532227, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.579833984375, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.6555004119873, + 
"test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 142.21566772460938, + "test_proper_exit (__main__.TestDataLoader)": 267.74214717320035, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 266.6539971487863, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 101.97100067138672, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.3346659342448, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 81.50300216674805, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.61333465576172, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.41133371988933, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 73.37100219726562, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.30900065104167, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 96.61750030517578, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 79.33600234985352, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 101.2393315633138, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 103.18400192260742, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 75.4114990234375, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 96.52833302815755, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 99.72700119018555, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 100.61966705322266, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.2750015258789, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 95.17449951171875, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.96749877929688, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 106.44049835205078, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7173334757487, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 531.5236612955729, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1077.4210205078125, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 
812.0880126953125, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1347.9365234375, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 88.93533070882161, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 269.01949310302734, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 131.99799601236978, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 232.36275100708008, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 69.80400085449219, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 134.3415012359619, + "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 67.51749992370605, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 91.21066792805989, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 170.97775268554688, + "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 61.608266321818036, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 110.62575149536133, + "test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 63.59499969482422, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 88.68299865722656, + "test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 91.50320053100586, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 66.10774898529053, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 66.20533180236816, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 243.1092529296875, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 105.01200103759766, + "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.93685695103237, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 142.38899993896484, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 119.90166600545247, + "test_sort_bool_cpu (__main__.CpuTritonTests)": 346.2856750488281, + "test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 423.09974098205566, + "test_sort_stable_cuda (__main__.GPUTests)": 117.61659927368164, + "test_sort_transpose_cpu (__main__.CpuTritonTests)": 378.31200154622394, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 222.822007894516, + "test_terminate_handler_on_crash (__main__.TestTorch)": 143.31728431156702, + "test_terminate_signal (__main__.ForkTest)": 168.20485967184817, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 168.19242484867573, + "test_terminate_signal (__main__.SpawnTest)": 172.16428443363733, + "test_thnn_conv_strided_padded_dilated (__main__.TestConvolutionNN)": 93.30639710426331, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 163.89743041992188, + "test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 60.47671399797712, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.39550018310547, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 173.53924942016602, + "test_triton_bsr_softmax_cuda_float16 
(__main__.TestSparseCompressedTritonKernelsCUDA)": 175.3212537765503, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 122.20649909973145, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 99.9885025024414, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 71.64024829864502, + "test_view_ops (__main__.TestViewOpsWithLocalTensor)": 73.45887422561646, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.75249862670898, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 61.858001708984375, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 65.11023766653878, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 66.35274982452393, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 61.196499824523926, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.75380906604585, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 73.64649868011475, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 75.09799966358003, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.51450157165527, + "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21433276221866, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 73.20024871826172, + "test_vmapvjpvjp_linalg_lstsq_cuda_float32 (__main__.TestOperatorsCUDA)": 88.1349983215332, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 76.89924907684326, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 77.32975196838379, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 120.09600067138672 } \ No newline at end of file From 59307ca1bc256ad03140c3c55698763c75487c63 Mon Sep 17 00:00:00 2001 From: Angel Li Date: Mon, 10 Nov 2025 14:46:42 +0000 Subject: [PATCH 282/651] [BE] adding documentation (#167334) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `torch.ao.quantization` and `torch.fx.experimental` Screenshot 2025-11-07 at 3 20 54 PM Screenshot 2025-11-07 at 3 20 45 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/167334 Approved by: https://github.com/janeyx99 --- docs/source/conf.py | 121 ---------------------------- docs/source/fx.experimental.md | 92 +++++++++++++++++++++ docs/source/fx.md | 2 - docs/source/quantization-support.md | 41 ++++++++++ torch/ao/quantization/fx/utils.py | 10 ++- 5 files changed, 139 insertions(+), 127 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 9a06c0e2036d2..99ce1e0b8db5d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -382,20 +382,6 @@ # torch.ao.quantization.backend_config.tensorrt "get_tensorrt_backend_config", "get_tensorrt_backend_config_dict", - # torch.ao.quantization.backend_config.utils - "entry_to_pretty_str", - "get_fused_module_classes", - "get_fuser_method_mapping", - "get_fusion_pattern_to_extra_inputs_getter", - "get_fusion_pattern_to_root_node_getter", - "get_module_to_qat_module", - "get_pattern_to_dtype_configs", - "get_pattern_to_input_type_to_index", - "get_qat_module_classes", - "get_root_module_to_quantized_reference_module", - 
"pattern_to_human_readable", - "remove_boolean_dispatch_from_name", - # torch.ao.quantization.backend_config.x86 "get_x86_backend_config", # torch.ao.quantization.fuse_modules "fuse_known_modules", @@ -426,25 +412,6 @@ "insert_observers_for_model", "prepare", "propagate_dtypes_for_known_nodes", - # torch.ao.quantization.fx.utils - "all_node_args_except_first", - "all_node_args_have_no_tensors", - "assert_and_get_unique_device", - "collect_producer_nodes", - "create_getattr_from_value", - "create_node_from_old_node_preserve_meta", - "get_custom_module_class_keys", - "get_linear_prepack_op_for_dtype", - "get_new_attr_name_with_prefix", - "get_non_observable_arg_indexes_and_types", - "get_qconv_prepack_op", - "get_skipped_module_name_and_classes", - "graph_module_from_producer_nodes", - "maybe_get_next_module", - "node_arg_is_bias", - "node_arg_is_weight", - "return_arg_list", - # torch.ao.quantization.pt2e.graph_utils "bfs_trace_with_node_process", "find_sequential_partitions", "get_equivalent_types", @@ -860,80 +827,10 @@ "get_latency_of_one_partition", "get_latency_of_partitioned_graph", "get_partition_to_latency_mapping", - # torch.fx.experimental.proxy_tensor - "decompose", - "disable_autocast_cache", - "disable_proxy_modes_tracing", - "dispatch_trace", - "extract_val", - "fake_signature", - "fetch_sym_proxy", - "fetch_object_proxy", - "get_innermost_proxy_mode", - "get_isolated_graphmodule", - "get_proxy_slot", - "get_torch_dispatch_modes", - "has_proxy_slot", - "is_sym_node", - "maybe_handle_decomp", - "proxy_call", - "set_meta", - "set_original_aten_op", - "set_proxy_slot", - "snapshot_fake", - "thunkify", - "track_tensor", - "track_tensor_tree", - "wrap_key", - "wrapper_and_args_for_make_fx", - # torch.fx.experimental.recording "record_shapeenv_event", "replay_shape_env_events", "shape_env_check_state_equal", - # torch.fx.experimental.sym_node - "ceil_impl", - "floor_ceil_helper", - "floor_impl", - "method_to_operator", - "sympy_is_channels_last_contiguous_2d", - "sympy_is_channels_last_contiguous_3d", - "sympy_is_channels_last_strides_2d", - "sympy_is_channels_last_strides_3d", - "sympy_is_channels_last_strides_generic", - "sympy_is_contiguous", - "sympy_is_contiguous_generic", - "to_node", - "wrap_node", "sym_sqrt", - # torch.fx.experimental.symbolic_shapes - "bind_symbols", - "cast_symbool_to_symint_guardless", - "create_contiguous", - "error", - "eval_guards", - "eval_is_non_overlapping_and_dense", - "expect_true", - "find_symbol_binding_fx_nodes", - "free_symbols", - "free_unbacked_symbols", - "fx_placeholder_targets", - "fx_placeholder_vals", - "guard_bool", - "guard_float", - "guard_int", - "guard_scalar", - "has_hint", - "has_symbolic_sizes_strides", - "is_channels_last_contiguous_2d", - "is_channels_last_contiguous_3d", - "is_channels_last_strides_2d", - "is_channels_last_strides_3d", - "is_contiguous", - "is_non_overlapping_and_dense_indicator", - "is_nested_int", - "is_symbol_binding_fx_node", - "is_symbolic", - # torch.fx.experimental.unification.core "reify", # torch.fx.experimental.unification.match "edge", @@ -971,24 +868,6 @@ "reverse_dict", # torch.fx.experimental.unification.multipledispatch.variadic "isvariadic", - # torch.fx.experimental.unification.unification_tools - "assoc", - "assoc_in", - "dissoc", - "first", - "get_in", - "getter", - "groupby", - "itemfilter", - "itemmap", - "keyfilter", - "keymap", - "merge", - "merge_with", - "update_in", - "valfilter", - "valmap", - # torch.fx.experimental.unification.utils "freeze", "hashable", "raises", diff --git 
a/docs/source/fx.experimental.md b/docs/source/fx.experimental.md index cba695b5e1c55..79cfaff7d0f2d 100644 --- a/docs/source/fx.experimental.md +++ b/docs/source/fx.experimental.md @@ -12,6 +12,37 @@ These APIs are experimental and subject to change without notice. .. autoclass:: torch.fx.experimental.sym_node.DynamicInt ``` +## torch.fx.experimental.sym_node + +```{eval-rst} +.. currentmodule:: torch.fx.experimental.sym_node +``` + +```{eval-rst} +.. automodule:: torch.fx.experimental.sym_node +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + is_channels_last_contiguous_2d + is_channels_last_contiguous_3d + is_channels_last_strides_2d + is_channels_last_strides_3d + is_contiguous + is_non_overlapping_and_dense_indicator + method_to_operator + sympy_is_channels_last_contiguous_2d + sympy_is_channels_last_contiguous_3d + sympy_is_channels_last_strides_2d + sympy_is_channels_last_strides_3d + sympy_is_channels_last_strides_generic + sympy_is_contiguous + sympy_is_contiguous_generic +``` + ## torch.fx.experimental.symbolic_shapes ```{eval-rst} @@ -69,6 +100,25 @@ These APIs are experimental and subject to change without notice. rebind_unbacked resolve_unbacked_bindings is_accessor_node + cast_symbool_to_symint_guardless + create_contiguous + error + eval_guards + eval_is_non_overlapping_and_dense + find_symbol_binding_fx_nodes + free_symbols + free_unbacked_symbols + fx_placeholder_targets + fx_placeholder_vals + guard_bool + guard_float + guard_int + guard_scalar + has_hint + has_symbolic_sizes_strides + is_nested_int + is_symbol_binding_fx_node + is_symbolic ``` ## torch.fx.experimental.proxy_tensor @@ -91,4 +141,46 @@ These APIs are experimental and subject to change without notice. get_proxy_mode maybe_enable_thunkify maybe_disable_thunkify + decompose + disable_autocast_cache + disable_proxy_modes_tracing + extract_val + fake_signature + fetch_object_proxy + fetch_sym_proxy + has_proxy_slot + is_sym_node + maybe_handle_decomp + proxy_call + set_meta + set_original_aten_op + set_proxy_slot + snapshot_fake ``` + +## torch.fx.experimental.unification.unification_tools + +```{eval-rst} +.. currentmodule:: torch.fx.experimental.unification.unification_tools +``` + +```{eval-rst} +.. automodule:: torch.fx.experimental.unification.unification_tools +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + assoc + assoc_in + dissoc + first + keyfilter + keymap + merge + merge_with + update_in + valfilter + valmap diff --git a/docs/source/fx.md b/docs/source/fx.md index c9c235382893e..b8447b378d3f9 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -1134,7 +1134,6 @@ The set of leaf modules can be customized by overriding .. py:module:: torch.fx.experimental.refinement_types .. py:module:: torch.fx.experimental.rewriter .. py:module:: torch.fx.experimental.schema_type_annotation -.. py:module:: torch.fx.experimental.sym_node .. py:module:: torch.fx.experimental.unification.core .. py:module:: torch.fx.experimental.unification.dispatch .. py:module:: torch.fx.experimental.unification.match @@ -1144,7 +1143,6 @@ The set of leaf modules can be customized by overriding .. py:module:: torch.fx.experimental.unification.multipledispatch.dispatcher .. py:module:: torch.fx.experimental.unification.multipledispatch.utils .. py:module:: torch.fx.experimental.unification.multipledispatch.variadic -.. py:module:: torch.fx.experimental.unification.unification_tools .. py:module:: torch.fx.experimental.unification.utils .. 
py:module:: torch.fx.experimental.unification.variable .. py:module:: torch.fx.experimental.unify_refinements diff --git a/docs/source/quantization-support.md b/docs/source/quantization-support.md index 3bb5c45face69..0b5d338d6f2bb 100644 --- a/docs/source/quantization-support.md +++ b/docs/source/quantization-support.md @@ -134,6 +134,23 @@ Quantization to work with this as well. ObservationType ``` +## torch.ao.quantization.backend_config.utils +```{eval-rst} +.. currentmodule:: torch.ao.quantization.backend_config.utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + entry_to_pretty_str + pattern_to_human_readable + remove_boolean_dispatch_from_name + +``` + ## torch.ao.quantization.fx.custom_config This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization @@ -154,6 +171,30 @@ This module contains a few CustomConfig classes that's used in both eager mode a StandaloneModuleConfigEntry ``` +## torch.ao.quantization.fx.utils + +```{eval-rst} +.. currentmodule:: torch.ao.quantization.fx.utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + all_node_args_except_first + all_node_args_have_no_tensors + collect_producer_nodes + create_getattr_from_value + create_node_from_old_node_preserve_meta + graph_module_from_producer_nodes + maybe_get_next_module + node_arg_is_bias + node_arg_is_weight + return_arg_list +``` + ## torch.ao.quantization.quantizer ```{eval-rst} diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 9f76f2a328df1..f173135013d7b 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -195,10 +195,12 @@ def get_attr_name(i: int): def collect_producer_nodes(node: Node) -> Optional[list[Node]]: r"""Starting from a target node, trace back until we hit input or getattr node. This is used to extract the chain of operators - starting from getattr to the target node, for example - def forward(self, x): - observed = self.observer(self.weight) - return F.linear(x, observed) + starting from getattr to the target node, for example:: + + def forward(self, x): + observed = self.observer(self.weight) + return F.linear(x, observed) + collect_producer_nodes(observed) will either return a list of nodes that produces the observed node or None if we can't extract a self contained graph without free variables(inputs of the forward function). From 31ccd8f13e9694850cd46b4bfdaa98555f7f9acd Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 7 Nov 2025 12:49:00 -0800 Subject: [PATCH 283/651] [AOTI] Fix a mixed-device bug for scatter_add (#167341) Summary: Fix https://github.com/pytorch/pytorch/issues/166841. AOTI incorrectly generates a call to aoti_torch_cuda_scatter_reduce_two_out while the op should actually run on CPU. Fix by using the correct device when calling _generate_scatter_fallback in the wrapper codegen. 
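For illustration, a rough sketch of the mixed-device pattern this fix targets (names and the export call are illustrative, and it assumes a CUDA build); it mirrors the regression test added below:

```
import torch

class MixedDevice(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # CPU buffers drive a CPU scatter_add ...
        self.register_buffer("index", torch.tensor([1, 4, 1, 7], dtype=torch.int64))
        self.register_buffer("src", torch.ones(4, dtype=torch.int64))

    def forward(self, vector):
        z = torch.zeros(vector.shape[0], dtype=torch.int64)   # CPU tensor
        scattered = z.scatter_add(0, self.index, self.src)    # should lower to the CPU shim
        # ... while the rest of the graph stays on CUDA
        return vector + scattered.to(vector.dtype).to("cuda")

ep = torch.export.export(MixedDevice(), (torch.randn(10, device="cuda"),))
torch._inductor.aoti_compile_and_package(ep)  # previously the wrapper picked the cuda scatter shim here
```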
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167341 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 32 +++++++++++++++++++ test/inductor/test_aot_inductor_utils.py | 3 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 8 +++-- .../codegen/cpp_wrapper_cpu_array_ref.py | 5 ++- torch/_inductor/codegen/wrapper.py | 3 ++ 5 files changed, 47 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 4f7eb86e8ce47..7322f1e78dcb2 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -7522,6 +7522,38 @@ def forward(self, x, y, a, b): eager_outputs = model(*example_inputs) torch.testing.assert_close(eager_outputs, compiled_outputs) + @requires_gpu + def test_mixed_device_1(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("Mixed-device test requires GPU") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + # Buffers are on CPU + self.register_buffer( + "index", torch.tensor([1, 4, 1, 7], device="cpu", dtype=torch.int64) + ) + self.register_buffer( + "src", torch.ones(4, device="cpu", dtype=torch.int64) + ) + + def forward(self, matrix, vector): + # Inputs are on CUDA + # 1. Operation on CPU tensors + z = torch.zeros((vector.shape[0],), device="cpu", dtype=torch.int64) + scatter_result = z.scatter_add(0, self.index, self.src) + + # 2. Move result to CUDA and continue on CUDA + v = vector + scatter_result.to(vector.dtype).to(GPU_TYPE) + return torch.matmul(matrix, v) + + example_inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, device=self.device), + ) + self.check_model(Model(), example_inputs, move_model_to_device=False) + class AOTInductorLoggingTest(LoggingTestCase): @make_logging_test(dynamic=logging.DEBUG) diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 2a9f593c5a6c4..cb16f46a752b8 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -218,6 +218,7 @@ def check_model( dynamic_shapes=None, atol=None, rtol=None, + move_model_to_device=True, ): with ( torch.no_grad(), @@ -229,7 +230,7 @@ def check_model( ), ): torch.manual_seed(0) - if not isinstance(model, types.FunctionType): + if not isinstance(model, types.FunctionType) and move_model_to_device: model = model.to(self.device) # For non mixed device inputs with default "cpu",set the device manually. 
diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index be87044a74e1c..61a97fd740cbc 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -221,7 +221,9 @@ def write_header(self): """ ) - self.add_device_include(self.device) + for device in V.graph.device_types: + if device != "meta": + self.add_device_include(device) if V.graph.aot_mode: if config.aot_inductor.dynamic_linkage: @@ -1423,11 +1425,13 @@ def _generate_scatter_fallback( src_is_tensor, reduce, kwargs, + device, ): reduce = self._get_scatter_reduce_enum(reduce) # call the ABI shim function instead of the ATen one - cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device) + self.add_device_include(device) + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device) # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" inputs_wrapped = [str(x) for x in inputs] diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index 11e74b9ddf8b8..c0c9aef609ba4 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -708,11 +708,14 @@ def _generate_scatter_fallback( src_is_tensor, reduce, kwargs, + device, ): reduce = self._get_scatter_reduce_enum(reduce) # call the ABI shim function instead of the ATen one - cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device) + self.add_device_include(device) + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, device) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 947166cf216cd..50b23ece7ffa3 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -971,6 +971,7 @@ def codegen(self, code: IndentedBuffer) -> None: else: (x, index) = (t.codegen_reference() for t in node.inputs) src = node.constant_args[1] + device = d.type if (d := node.get_device()) else V.graph.device_type self.wrapper._generate_scatter_fallback( x, [x, node.constant_args[0], index, src], @@ -979,6 +980,7 @@ def codegen(self, code: IndentedBuffer) -> None: node.src_is_tensor, node.kwargs["reduce"], node.codegen_kwargs(), + device, ) def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: @@ -1632,6 +1634,7 @@ def _generate_scatter_fallback( src_is_tensor, reduce, kwargs, + device, ): line = f"{python_kernel_name}({','.join(map(str, inputs))}" if python_kernel_name.startswith("aten.scatter_reduce"): From 2fcf41dd8ef53b1bab3e8dd5171479019e470039 Mon Sep 17 00:00:00 2001 From: albanD Date: Mon, 10 Nov 2025 17:10:12 +0000 Subject: [PATCH 284/651] Add the ruff rule and skip everything for now (#167360) Part of https://github.com/pytorch/pytorch/issues/164878 We can start narrowing the skips and remove them as PRs keep landing. 
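For reference, S101 is the flake8-bandit "use of `assert`" rule; a hypothetical example of the pattern it flags (and the explicit-exception rewrite it nudges toward):

```
def scale(x, factor):
    assert factor != 0, "factor must be non-zero"  # S101: stripped under `python -O`
    return x / factor

def scale_checked(x, factor):
    # explicit exception survives `python -O`
    if factor == 0:
        raise ValueError("factor must be non-zero")
    return x / factor
```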
This PR is just to setup the scaffolding, fix will be in follow up Pull Request resolved: https://github.com/pytorch/pytorch/pull/167360 Approved by: https://github.com/janeyx99 --- pyproject.toml | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b01ba623cc814..b4d7a06d3f40f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -260,6 +260,7 @@ select = [ "TRY401", # verbose-log-message "UP", "YTT", + "S101", ] [tool.ruff.lint.pyupgrade] @@ -339,6 +340,39 @@ keep-runtime-typing = true "tools/linter/**" = [ "LOG015" # please fix ] +"benchmarks/**" = [ + "S101" +] +"test/**" = [ + "S101" +] +"torchgen/**" = [ + "S101" +] +"torch/**" = [ + "S101" +] +"tools/**" = [ + "S101" +] +"setup.py" = [ + "S101" +] +"functorch/**" = [ + "S101" +] +"docs/**" = [ + "S101" +] +"android/**" = [ + "S101" +] +".github/**" = [ + "S101" +] +".ci/**" = [ + "S101" +] [tool.codespell] ignore-words = "tools/linter/dictionary.txt" From f6a79b2a4ac92a9077fc03cb124e3bc0a0a28730 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 9 Nov 2025 20:35:53 -0800 Subject: [PATCH 285/651] [inductor] Wrap pallas_call in jax.jit (#167441) My understanding is this is needed for performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167441 Approved by: https://github.com/oulgen --- test/inductor/test_pallas.py | 28 ++++++++++++++++++ torch/_inductor/codegen/pallas.py | 49 ++++++++++++++++--------------- 2 files changed, 54 insertions(+), 23 deletions(-) diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index 8571321e872e6..ed9a8edd79a87 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: pt2"] import functools +import re import sys import unittest @@ -230,6 +231,33 @@ def pallas_fn(a, b): self.assertIn("import jax.numpy as jnp", code) self.assertIn("from jax.experimental import pallas as pl", code) + def test_jax_jit_wrapper_is_emitted(self): + """Ensure generated Pallas code wraps pl.pallas_call in jax.jit.""" + + key = "cuda_backend" if self.DEVICE == "cuda" else "cpu_backend" + + @torch.compile(backend="inductor", options={key: "pallas"}) + def pallas_fn(a, b): + return a + b + + _, (code,) = run_and_get_code( + pallas_fn, + torch.randn(32, device=self.DEVICE), + torch.randn(32, device=self.DEVICE), + ) + + kernel_match = re.search(r"def (pallas_[A-Za-z0-9_]+)_kernel", code) + self.assertIsNotNone(kernel_match) + kernel_name = kernel_match.group(1) + wrapper_name = f"{kernel_name}_jit_wrapper" + self.assertIn(wrapper_name, code) + start = code.index(f"def {wrapper_name}") + end = code.index(f"def {kernel_name}_main", start) + wrapper_block = code[start:end] + + self.assertIn("jax.jit", code) + self.assertNotIn("torch.", wrapper_block) + def test_2d_tensor(self): """Test with 2D tensors (though current implementation flattens).""" diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index e5bf1fa17cdca..eb69bb842d3dd 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -287,6 +287,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove code = IndentedBuffer() code.splice( """ + import functools import torch import jax import jax.numpy as jnp @@ -301,6 +302,9 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove kernel_params = [a.name for a in arg_defs] kernel_name = name or "" + interpret_literal = ( + "True" if 
V.graph.get_current_device_or_throw().type == "cpu" else "False" + ) code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):") with code.indent(): # Emit compute (CSE) and store lines; they reference *_ptr[...] directly @@ -309,16 +313,22 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove for line in self.stores._lines: code.writeline(str(line)) + jit_wrapper_name = f"{kernel_name}_jit_wrapper" + code.writeline("@functools.partial(jax.jit, static_argnums=(0, 1))") + code.writeline(f"def {jit_wrapper_name}(out_shape, out_dtype, *kernel_refs):") + with code.indent(): + code.writeline("out_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)") + code.writeline("return pl.pallas_call(") + code.writeline(f" {kernel_name}_kernel,") + code.writeline(" out_shape=out_spec,") + code.writeline(f" interpret={interpret_literal},") + code.writeline(" grid=(1,),") + code.writeline(")(*kernel_refs)") + # Host entry: convert torch tensors <-> jax, call pallas_call and copy back main_name = f"{kernel_name}_main" code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):") with code.indent(): - # Determine interpret statically based on codegen device - interpret_literal = ( - "True" - if V.graph.get_current_device_or_throw().type == "cpu" - else "False" - ) # Identify inputs (in_ptr*) and output (out_ptr*) input_params = [ p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr")) @@ -337,9 +347,9 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove for inp in input_params: code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})") - # Get output spec from PyTorch tensor - code.writeline("# Prepare output spec from PyTorch tensor") - code.writeline("# Map PyTorch dtype to JAX dtype string") + # Get output metadata from PyTorch tensor + code.writeline("# Prepare output metadata from PyTorch tensor") + code.writeline("# Map PyTorch dtype to JAX dtype") code.writeline("_torch_dtype_to_jax = {") code.writeline( " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16," @@ -349,21 +359,14 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[ove ) code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,") code.writeline("}") - code.writeline( - f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])" - ) + code.writeline(f"out_shape = tuple({output_param}.shape)") + code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]") - # Call pallas - # Pass interpret=True on CPU, False otherwise (single call, no duplication) - code.writeline("compiled = pl.pallas_call(") - code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),") - code.writeline(" out_shape=out_spec,") - code.writeline(f" interpret={interpret_literal},") - code.writeline(" grid=(1,),") - code.writeline(")") - - jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params]) - code.writeline(f"res = compiled({jax_input_args})") + call_args = ["out_shape", "out_dtype"] + [ + f"{inp}_jax" for inp in input_params + ] + call_arg_str = ", ".join(call_args) + code.writeline(f"res = {jit_wrapper_name}({call_arg_str})") # Copy result back code.writeline("# Copy result back into the provided torch output tensor") From 3966b5ad05e842467120efb7aeb41967d2e88c47 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 9 Nov 2025 21:27:00 -0800 Subject: [PATCH 286/651] [BE] Fix out-of-bounds index_put in test_mps.py (#167444) Discovered while 
enabling assertions on out-of-bounds accesses. Otherwise test fails with ``` ERROR: test_sdpa_mask_fp16_L6_S17_NH23_HS121 (__main__.TestSDPA.test_sdpa_mask_fp16_L6_S17_NH23_HS121) ---------------------------------------------------------------------- Traceback (most recent call last): File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3334, in wrapper method(*args, **kwargs) ~~~~~~^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/build/../test/test_mps.py", line 9494, in test_sdpa_mask_fp16_L6_S17_NH23_HS121 self._test_sdpa_mask(torch.float16, 7, 17, 23, 121) ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/Users/malfet/git/pytorch/pytorch/build/../test/test_mps.py", line 9478, in _test_sdpa_mask y_ref = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), attn_mask=mask.cpu(), dropout_p=0.0, is_causal=False) ~~~~~^^ torch.AcceleratorError: index out of range ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167444 Approved by: https://github.com/Skylion007, https://github.com/manuelcandales --- test/test_mps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_mps.py b/test/test_mps.py index ca95839e7a7fb..76991f48e7cdc 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9465,7 +9465,7 @@ def _test_sdpa_mask(self, dtype: torch.dtype, L: int = 1, S: int = 72, NH: int = torch.manual_seed(1729) causal_mask = torch.tril(torch.ones(S, S, dtype=torch.bool, device='mps')) with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): - i = 42 + i = 42 if S > 42 else S // 2 q = torch.randn([1, NH, L, HS], dtype=dtype, device="mps") k = torch.randn([1, NH, S, HS], dtype=q.dtype, device="mps") From 3ea829a3374322c4072a6e58ea7625d6b022989f Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Mon, 10 Nov 2025 18:19:35 +0000 Subject: [PATCH 287/651] Fix torch.cond HOP device in inductor (#167354) Fixes #166918 The output device may not be on the same device as the predicate device. 
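A minimal sketch of the failing pattern (CPU predicate, branches that produce GPU tensors), assuming a CUDA build; it mirrors the new `test_output_on_different_device` test added below:

```
import torch

def fn(pred):
    # pred lives on CPU, but both branch outputs live on CUDA
    out = torch.cond(
        pred,
        lambda: torch.ones(5, device="cuda"),
        lambda: torch.zeros(5, device="cuda"),
    )
    return out + 1

compiled = torch.compile(fn)
compiled(torch.tensor(True))  # previously the cond outputs inherited the predicate's device
```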
``` python test/inductor/test_control_flow.py -k test_output_on_different_device ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167354 Approved by: https://github.com/ydwu4, https://github.com/zou3519 --- test/inductor/test_control_flow.py | 34 +++++++++++++++++++++++++----- torch/_inductor/ir.py | 4 +++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index a3c81bdfd15b0..b3d1a5f2529a1 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -20,9 +20,11 @@ from torch.testing._internal.triton_utils import requires_gpu -def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1): +def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1, device=None): result = [] - device = inputs[0].device + if len(inputs) != 0: + device = inputs[0].device + assert device # iterate over the cartesian product of predicate values for values in itertools.product(*([possible_values] * num_to_prepend)): prepended = [torch.tensor(v, device=device) for v in values] @@ -30,8 +32,8 @@ def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1): return result -def prepend_predicates(inputs, num_predicates=1): - return _prepend_product_of_values(inputs, [False, True], num_predicates) +def prepend_predicates(inputs, num_predicates=1, device=None): + return _prepend_product_of_values(inputs, [False, True], num_predicates, device) def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)): @@ -308,7 +310,9 @@ def _run_test( torch._dynamo.mark_dynamic(inp, 0) for inputs in input_sets: - for inputs_with_predicates in prepend_predicates(inputs, num_predicates): + for inputs_with_predicates in prepend_predicates( + inputs, num_predicates, device=device + ): cloned_inputs = [inp.clone() for inp in inputs_with_predicates] result = model(*inputs_with_predicates) result_compiled = compiled_model(*inputs_with_predicates) @@ -768,6 +772,26 @@ def test_cond_select_with_input_idx(self, device, dynamic): dynamic=dynamic, ) + @requires_gpu + def test_output_on_different_device(self): + class FactoryBranches(torch.nn.Module): + def forward(self, pred): + tensor = torch.cond( + pred, + lambda: torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).to( + GPU_TYPE + ), + lambda: torch.zeros(5, dtype=torch.float32).to(GPU_TYPE), + ) + return tensor + 1 + + self._run_test( + model=FactoryBranches(), + inputs=(), + device="cpu", # device for predicate + dynamic=True, + ) + class WhileLoopModels: class Simple(torch.nn.Module): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 53c12d0726044..43952a11f2da4 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -8845,7 +8845,9 @@ def create( outputs = [ MultiOutput( FixedLayout( - device=device, + device=output.get_device() + if output.get_device() is not None + else device, # type: ignore[arg-type] dtype=output.get_dtype(), size=[Conditional._maybe_expr(sz) for sz in merged_output.size()], stride=[ From a4437d76f0fad9f9bdbac0e7c2e365a002fe26fa Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Mon, 10 Nov 2025 18:38:42 +0000 Subject: [PATCH 288/651] Add some labeler rules that used to be in the autolabel bot (#167330) See https://github.com/pytorch/test-infra/pull/7446 for the paths Pull Request resolved: https://github.com/pytorch/pytorch/pull/167330 Approved by: https://github.com/huydhn --- .github/labeler.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git 
a/.github/labeler.yml b/.github/labeler.yml index 246ddd8614396..e8d3c223af317 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -165,3 +165,16 @@ - torch/_inductor/kernel/mm.py - test/inductor/test_max_autotune.py - third_party/fbgemm + +"ciflow/mps": +- aten/src/ATen/mps/** +- aten/src/ATen/native/mps/** +- torch/_inductor/codegen/mps.py +- test/test_mps.py +- test/inductor/test_mps_basic.py + +"ciflow/h100-symm-mem": +- torch/csrc/distributed/c10d/symm_mem/** +- torch/distributed/_symmetric_memory/** +- test/distributed/**/*mem* +- test/distributed/**/*mem*/** From 04a85b4c218b9dc9fd65b148f4b15114925ca4af Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Sun, 9 Nov 2025 14:04:43 -0800 Subject: [PATCH 289/651] [compile-on-one-rank] Step 1: DeviceId (#166680) Add a "--virtual-local-rank" mode to torchrun. When used instead of passing the local rank in LOCAL_RANK it uses a LOCAL_RANK of "0" and adjusts CUDA_VISIBLE_DEVICES to reflect the desired GPU index. Testing: (tweaked run_train.sh to use `--log-dir`) ``` export NGPU=8 export CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name compiler_toolkit.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 ``` And then comparing ranks: Without --virtual-local-rank gives a lot of differences like: ``` [rank#]: mul_1: "f32[8, 512, 256]" = torch.ops.aten.mul.Tensor(mul, view_9); mul = None -[rank#]: _to_copy_3: "bf16[8, 512, 256]" = torch.ops.aten._to_copy.default(mul_1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0)); mul_1 = None +[rank#]: _to_copy_3: "bf16[8, 512, 256]" = torch.ops.aten._to_copy.default(mul_1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=1)); mul_1 = None [rank#]: detach: "f32[8, 512, 1]" = torch.ops.aten.detach.default(rsqrt); rsqrt = None ``` With --virtual-local-rank makes those differences go away. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166680 Approved by: https://github.com/ezyang --- test/distributed/launcher/script_deviceid.py | 44 +++++++++ test/distributed/launcher/test_run.py | 93 ++++++++++++++++++- torch/distributed/elastic/agent/server/api.py | 8 +- .../agent/server/local_elastic_agent.py | 42 ++++++++- torch/distributed/launcher/api.py | 6 ++ torch/distributed/run.py | 10 ++ torch/testing/_internal/common_distributed.py | 2 +- 7 files changed, 201 insertions(+), 4 deletions(-) create mode 100644 test/distributed/launcher/script_deviceid.py diff --git a/test/distributed/launcher/script_deviceid.py b/test/distributed/launcher/script_deviceid.py new file mode 100644 index 0000000000000..1a09cdc102633 --- /dev/null +++ b/test/distributed/launcher/script_deviceid.py @@ -0,0 +1,44 @@ +# Owner(s): ["oncall: r2p"] + +# This is a helper script for +# test_run.py::ElasticLaunchTest::test_virtual_local_rank. It prints out the +# generated inductor output for a simple function. 
+ +import os +from unittest.mock import patch + +import torch +import torch.distributed as dist +from torch._inductor import codecache + + +@torch.compile +def myfn(x: torch.Tensor) -> torch.Tensor: + return x + x + + +dist.init_process_group(backend="nccl") + +local_rank = int(os.environ.get("LOCAL_RANK", "cuda:0")) +torch.cuda.set_device(local_rank) + + +def print_output_code(original_fn): + def wrapper(msg, *args, **kwargs): + # Check if this is the "Output code:" message + if args and "Output code:" in msg: + print(args[0]) + + return wrapper + + +x = torch.rand(2, 2, device="cuda") + +with patch.object( + codecache.output_code_log, + "debug", + side_effect=print_output_code(codecache.output_code_log.debug), +): + y = myfn(x) + +dist.destroy_process_group() diff --git a/test/distributed/launcher/test_run.py b/test/distributed/launcher/test_run.py index 50e2d53928c04..484a975051d4f 100644 --- a/test/distributed/launcher/test_run.py +++ b/test/distributed/launcher/test_run.py @@ -16,7 +16,7 @@ import tempfile import uuid from contextlib import closing, redirect_stderr, redirect_stdout -from unittest import mock +from unittest import mock, skipIf from unittest.mock import MagicMock, Mock, patch import torch.distributed.run as launch @@ -28,6 +28,7 @@ from torch.testing._internal.common_utils import ( run_tests, skip_but_pass_in_sandcastle_if, + TEST_CUDA, TEST_WITH_DEV_DBG_ASAN, TestCase, ) @@ -677,6 +678,96 @@ def test_capture_logs_using_default_logs_specs(self): for i in range(nproc_per_node): self.assertTrue(f"[rank{i}]: creating " in captured_out.getvalue()) + @skip_but_pass_in_sandcastle_if( + TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan" + ) + @skipIf(not TEST_CUDA, "requires CUDA") + def test_virtual_local_rank(self): + """ + Test that virtual-local-rank ensures consistent device IDs across ranks. + Without it, ranks may compile to different devices, leading to different code. 
+ """ + run_id = str(uuid.uuid4().int) + nnodes = 1 + nproc_per_node = 2 + + # Helper function to run and capture output + def run_test(use_virtual_local_rank): + args = [ + f"--nnodes={nnodes}", + f"--nproc-per-node={nproc_per_node}", + f"--rdzv-id={run_id}", + "--monitor-interval=1", + "--start-method=spawn", + "--redirect=3", + "--tee=3", + ] + if use_virtual_local_rank: + args.append("--virtual-local-rank") + + args.append(path("script_deviceid.py")) + + captured_out = io.StringIO() + captured_err = io.StringIO() + with redirect_stdout(captured_out), redirect_stderr(captured_err): + launch.main(args) + + return captured_out.getvalue() + + def split_ranks(output): + default0 = [] + default1 = [] + for line in output.splitlines(): + if "cuda:" not in line: + continue + if line.startswith("[default0]:"): + default0.append(line[11:]) + elif line.startswith("[default1]:"): + default1.append(line[11:]) + return default0, default1 + + # First, run WITHOUT virtual-local-rank - outputs should differ + output = run_test(use_virtual_local_rank=False) + rank0, rank1 = split_ranks(output) + + # Verify we actually captured compiled code from both ranks + self.assertGreater( + len(rank0), 0, "Expected to capture compiled code from rank 0" + ) + self.assertGreater( + len(rank1), 0, "Expected to capture compiled code from rank 1" + ) + + # Without virtual-local-rank, the ranks should have DIFFERENT compiled code + # because they see different device IDs (cuda:0 vs cuda:1) + self.assertNotEqual( + rank0, + rank1, + "Expected different compiled code without --virtual-local-rank", + ) + + # Now run WITH virtual-local-rank - outputs should be identical + output = run_test(use_virtual_local_rank=True) + rank0, rank1 = split_ranks(output) + + # Verify we actually captured compiled code from both ranks + self.assertGreater( + len(rank0), + 0, + "Expected to capture compiled code from rank 0 with --virtual-local-rank", + ) + self.assertGreater( + len(rank1), + 0, + "Expected to capture compiled code from rank 1 with --virtual-local-rank", + ) + + # With virtual-local-rank, both ranks should have IDENTICAL compiled code + # because they both see cuda:0 during compilation + self.assertEqual( + rank0, rank1, "Expected identical compiled code with --virtual-local-rank" + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index d56d61e7eaac2..1122913ed95db 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -48,7 +48,8 @@ @dataclass class WorkerSpec: - """Blueprint information about a particular type of worker. + """ + Blueprint information about a particular type of worker. For a given role, there must only exist a single worker spec. Worker spec is expected to be homogeneous across all nodes (machine), @@ -79,6 +80,10 @@ class WorkerSpec: that match _any_ of the filter strings. duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines that match _any_ of the filter strings. + virtual_local_rank: Enable virtual local rank mode for workers (defaults to False). + When enabled, LOCAL_RANK is set to 0 for all workers and + CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its + assigned GPU at device index 0. 
""" role: str @@ -97,6 +102,7 @@ class WorkerSpec: numa_options: Optional[NumaOptions] = None duplicate_stdout_filters: Optional[list[str]] = None duplicate_stderr_filters: Optional[list[str]] = None + virtual_local_rank: bool = False def __post_init__(self): assert self.local_world_size > 0 diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index f643de5f9b25d..5fd3b7d3526db 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -303,7 +303,6 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: for worker in worker_group.workers: local_rank = worker.local_rank worker_env = { - "LOCAL_RANK": str(local_rank), "RANK": str(worker.global_rank), "GROUP_RANK": str(worker_group.group_rank), "ROLE_RANK": str(worker.role_rank), @@ -322,6 +321,7 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1) ), } + self._set_local_rank_env(worker_env, local_rank, spec) if "OMP_NUM_THREADS" in os.environ: worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"] @@ -362,6 +362,46 @@ def _start_workers(self, worker_group: WorkerGroup) -> dict[int, Any]: return self._pcontext.pids() + def _set_local_rank_env( + self, worker_env: dict[str, str | None], local_rank: int, spec: WorkerSpec + ) -> None: + # Set CUDA_VISIBLE_DEVICES and LOCAL_RANK based on virtual_local_rank mode. + # Virtual mode: Each worker sees only its assigned GPU as device 0, LOCAL_RANK=0 + # Traditional mode: Workers see all GPUs, LOCAL_RANK matches actual local rank + + if spec.virtual_local_rank: + # Set LOCAL_RANK=0 and use CUDA_VISIBLE_DEVICES to control the actual GPU access. + + worker_env["LOCAL_RANK"] = "0" + + # Map local_rank through existing CUDA_VISIBLE_DEVICES + # HIP uses CUDA_VISIBLE_DEVICES as a compatibility hack: + # https://rocm.docs.amd.com/en/latest/conceptual/gpu-isolation.html#cuda-visible-devices + parent_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") + if parent_visible_devices is not None: + # Parse comma-separated list of GPU IDs + available_gpus = parent_visible_devices.split(",") + if local_rank >= len(available_gpus): + raise ValueError( + f"local_rank {local_rank} exceeds available GPUs in " + f"CUDA_VISIBLE_DEVICES={parent_visible_devices}" + ) + + visible_gpu = available_gpus[local_rank].strip() + else: + # No restriction, use local_rank directly + visible_gpu = str(local_rank) + + worker_env["CUDA_VISIBLE_DEVICES"] = visible_gpu + return + + # In traditional mode, don't override CUDA_VISIBLE_DEVICES + # (inherit from parent environment) + worker_env["LOCAL_RANK"] = str(local_rank) + + if "CUDA_VISIBLE_DEVICES" in os.environ: + worker_env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] + def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None: if self._worker_watchdog is not None: self._worker_watchdog.stop() diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index b75db1b11abbc..666fb24463f0d 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -75,6 +75,10 @@ class LaunchConfig: that match _any_ of the filter strings. duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines that match _any_ of the filter strings. + virtual_local_rank: Enable virtual local rank mode for workers (defaults to False). 
+ When enabled, LOCAL_RANK is set to 0 for all workers and + CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its + assigned GPU at device index 0. .. note:: @@ -104,6 +108,7 @@ class LaunchConfig: signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT" duplicate_stdout_filters: Optional[list[str]] = None duplicate_stderr_filters: Optional[list[str]] = None + virtual_local_rank: bool = False def __post_init__(self): default_timeout = 900 @@ -288,6 +293,7 @@ def launch_agent( numa_options=config.numa_options, duplicate_stdout_filters=config.duplicate_stdout_filters, duplicate_stderr_filters=config.duplicate_stderr_filters, + virtual_local_rank=config.virtual_local_rank, ) agent = LocalElasticAgent( diff --git a/torch/distributed/run.py b/torch/distributed/run.py index cd9820e0e10ea..2343f7bb9b74c 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -688,6 +688,15 @@ def comma_separated_list(value): "Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).", ) + parser.add_argument( + "--virtual-local-rank", + "--virtual_local_rank", + action=check_env, + help="Enable virtual local rank mode for workers. When enabled, LOCAL_RANK is set to 0 " + "for all workers and CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its " + "assigned GPU at device index 0.", + ) + # # Positional arguments. # @@ -907,6 +916,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str signals_to_handle=args.signals_to_handle, duplicate_stdout_filters=args.duplicate_stdout_filters, duplicate_stderr_filters=args.duplicate_stderr_filters, + virtual_local_rank=args.virtual_local_rank, ) with_python = not args.no_python diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index e93c346a6645d..c2b4dd57055a6 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -1711,7 +1711,7 @@ def opts(cls, high_priority_stream=False): @classmethod def _init_pg(cls, rank, world_size, rdvz_file): assert rdvz_file is not None - # rank should be local_rank for tests running on <= 8gpus which is how all these tests are designed + # rank should be local_rank for tests running on <= 8 gpus which is how all these tests are designed # and we expect LOCAL_RANK set by torchrun. Setting it lets init_device_mesh set the device without # issuing a warning os.environ["LOCAL_RANK"] = str(rank) From 9491830c7926ab5a58afdb76ab55f9ab9884c8b3 Mon Sep 17 00:00:00 2001 From: Jazlyn Li Date: Mon, 10 Nov 2025 19:29:50 +0000 Subject: [PATCH 290/651] move subgraph_has_impure_ops from `node.is_impure` into const_fold to unblock production (#167443) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: https://github.com/pytorch/pytorch/pull/166609 updates `node.is_impure` to consider a submodule as impure if submodule contains impure node. This in turn changes `graph.eliminate_dead_code()` function behavior, which does not eliminate nodes with side effects, see [pytorch documentation](https://docs.pytorch.org/docs/stable/fx.html#torch.fx.Graph.eliminate_dead_code) > Remove all dead code from the graph, based on each node’s number of users, and whether the nodes have any side effects. 
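To make that concrete, a hypothetical sketch (illustrative names; `torch._assert` is one of FX's registered side-effectful ops) of a GraphModule submodule containing an impure op whose unused call is the dead-code candidate in question:

```
import torch
import torch.fx

def checked(x):
    torch._assert(x.sum() >= 0, "expected non-negative input")  # impure node
    return x + 1

inner = torch.fx.symbolic_trace(checked)  # GraphModule containing an impure op

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        _ = self.inner(x)  # result unused -> dead-code candidate
        return x * 2

class KeepGraphModules(torch.fx.Tracer):
    # keep `inner` as a single call_module node instead of inlining it
    def is_leaf_module(self, m, qualified_name):
        return isinstance(m, torch.fx.GraphModule) or super().is_leaf_module(m, qualified_name)

outer = Outer()
gm = torch.fx.GraphModule(outer, KeepGraphModules().trace(outer))
gm.graph.eliminate_dead_code()  # with #166609 the call_module node is kept; before it, it was removed
```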
While this is correct that a submodule containing side-effectful ops is side-effectful and should not be dead code eliminated, some customers rely on the dead code elimination to eliminate submodules that contain impure ops which is the behavior before #166609 fix. Due to production environment constraints, we have to revert https://github.com/pytorch/pytorch/pull/166609 and move the side-effectful submodule check logic to `const_fold.py`, which will correctly **not** const-fold a submodule that contains impure ops. NOTE other call sites that use `node.is_impure()` to make decisions are still incorrectly eliminating side-effectful submodules, but we can't safely change that today. ## This pr - move `_subgraph_has_impure_op` into `fx/experimental/const_fold.py`, check and prevent const-folding an impure submodule - added a note in `node.is_impure` to highlight the incorrect behavior and context in case people go looking in the future. Test Plan: run test_fx_const_fold and all tests pass Differential Revision: D86641994 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167443 Approved by: https://github.com/jfix71 --- torch/fx/experimental/const_fold.py | 31 +++++++++++++++++++++++++++++ torch/fx/node.py | 29 +++++---------------------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 0f6460302dfbb..f494f11593410 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -177,6 +177,26 @@ def split_const_subgraphs( else: mod_traced = module + def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: + """ + Return True if a GraphModule type subgraph contains any impure op, else False. + """ + assert isinstance(module, torch.fx.GraphModule), ( + "caller should only pass GraphModule to subgraph_has_impure_ops check" + ) + for node in module.graph.nodes: + if node.op == "call_function" and node.is_impure(): + return True + if ( + # pyrefly: ignore [invalid-argument] + node.op == "call_module" + # pyrefly: ignore [not-callable] + and (submodule := module.get_submodule(node.target)) + and isinstance(submodule, torch.fx.GraphModule) + ): + return _subgraph_has_impure_ops(submodule) + return False + # Build up a list of const_nodes, defined as nodes that are themselves # get_attrs, or have all get_attr or other constant node inputs. const_nodes: set[torch.fx.Node] = set() @@ -206,6 +226,17 @@ def split_const_subgraphs( if isinstance(node.kwargs.get("fill_value", None), sympy.Expr): continue + # Skip folding submodules that have impure ops + if ( + # pyrefly: ignore [invalid-argument] + node.op == "call_module" + # pyrefly: ignore [not-callable] + and (target_mod := mod_traced.get_submodule(node.target)) + and isinstance(target_mod, torch.fx.GraphModule) + and _subgraph_has_impure_ops(target_mod) + ): + continue + # Must be a constant foldable node at this point. const_nodes.add(node) if node.op != "get_attr": diff --git a/torch/fx/node.py b/torch/fx/node.py index 272676a4e3a94..5afabe40ec341 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -754,26 +754,6 @@ def is_impure(self, impure_random: bool = True) -> bool: return self.target in _side_effectful_functions - def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: - """ - Return True if a GraphModule type subgraph contains any impure op, else False. 
- """ - assert isinstance(module, torch.fx.GraphModule), ( - "caller should only pass GraphModule to subgraph_has_impure_ops check" - ) - for node in module.graph.nodes: - if node.op == "call_function" and node.is_impure(impure_random): - return True - if ( - # pyrefly: ignore [invalid-argument] - node.op == "call_module" - # pyrefly: ignore [not-callable] - and (submodule := module.get_submodule(node.target)) - and isinstance(submodule, torch.fx.GraphModule) - ): - return subgraph_has_impure_ops(submodule) - return False - # Check if an impure module. if self.op == "call_module": assert self.graph.owning_module is not None, ( @@ -783,10 +763,11 @@ def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: assert target_mod is not None, ( f"Did not find expected submodule target {self.target}" ) - if isinstance(target_mod, torch.fx.GraphModule): - return subgraph_has_impure_ops(target_mod) - else: - return getattr(target_mod, "_is_impure", False) + # NOTE: here we can end up considering GraphModule submodules pure, + # even if they contain impure ops. It may not be safe to change + # because this function is used by graph.eliminate_dead_code, + # and some users depend on current elimination behavior. + return getattr(target_mod, "_is_impure", False) return False From 86130aa2caeffc762120b323775899d21a0e2a01 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Mon, 10 Nov 2025 19:51:38 +0000 Subject: [PATCH 291/651] Fix flaky memory profiler test [2] (#167268) Fixes #167037 Move the module definition outside of the unit test so when we run the unit test multiple times, the module is not re-compiled. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167268 Approved by: https://github.com/angelayi --- test/test_cuda.py | 169 ++++++++++++++++++++++------------------------ 1 file changed, 81 insertions(+), 88 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 5842b0eda7422..5712187775ef6 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -7455,6 +7455,34 @@ def test_graph_external_wait_and_record(self): class TestFXMemoryProfiler(TestCase): """Tests for memory profiler augmentation with original stack traces.""" + class MLPModule(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + a = self.net1(x) + b = self.relu(a) + c = self.net2(b) + return c + + class MLPModule2(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + d = self.net1(x) + e = self.relu(d) + f = self.net2(e) + return f + def collect_frames( self, augmented_snapshot, collect_device_traces=True, collect_segments=True ): @@ -7490,99 +7518,64 @@ def collect_frames( def test_fx_memory_profiler_augmentation(self): """Test that memory snapshots are augmented with FX debug information.""" - # Create a simple model - class MLPModule(nn.Module): - def __init__(self, device): - super().__init__() - torch.manual_seed(5) - self.net1 = nn.Linear(10, 16, bias=True, device=device) - self.relu = nn.ReLU() - self.net2 = nn.Linear(16, 10, bias=True, device=device) - - def forward(self, x): - a = self.net1(x) - b = self.relu(a) - c = self.net2(b) - return c - device = "cuda" - mod = MLPModule(device) - with 
tempfile.TemporaryDirectory() as tmpdir: - # reset cache to start fresh - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history() - compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) - result = compiled(torch.randn(10, 10, device=device)) - augmented_snapshot = torch.cuda.memory._snapshot( - augment_with_fx_traces=True - ) - torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) - torch.cuda.empty_cache() + mod = self.MLPModule(device) + # reset cache to start fresh + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + torch.cuda.empty_cache() - fx_frames = self.collect_frames(augmented_snapshot) - self.assertGreater(len(fx_frames), 2) - - for frame in fx_frames: - # Every FX frame should have both node_op and node_name - self.assertIn("fx_node_op", frame) - self.assertIn("fx_node_name", frame) - self.assertIn("fx_node_target", frame) - self.assertIn("fx_original_trace", frame) - - self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) - fx_node_name = frame["fx_node_name"] - if fx_node_name == "addmm": - self.assertIn("a = self.net1(x)", frame["fx_original_trace"]) - elif fx_node_name == "addmm_1": - self.assertIn("c = self.net2(b)", frame["fx_original_trace"]) - elif fx_node_name == "relu": - self.assertIn("b = self.relu(a)", frame["fx_original_trace"]) + fx_frames = self.collect_frames(augmented_snapshot) + self.assertGreater(len(fx_frames), 2) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("a = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("c = self.net2(b)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("b = self.relu(a)", frame["fx_original_trace"]) # Test that when we have two graphs with the same src_code, they're not hashed # to the same metadata - class MLPModule2(nn.Module): - def __init__(self, device): - super().__init__() - torch.manual_seed(5) - self.net1 = nn.Linear(10, 16, bias=True, device=device) - self.relu = nn.ReLU() - self.net2 = nn.Linear(16, 10, bias=True, device=device) - - def forward(self, x): - d = self.net1(x) - e = self.relu(d) - f = self.net2(e) - return f - - mod = MLPModule2(device) - with tempfile.TemporaryDirectory() as tmpdir: - torch.cuda.memory._record_memory_history() - compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) - result = compiled(torch.randn(10, 10, device=device)) - augmented_snapshot = torch.cuda.memory._snapshot( - augment_with_fx_traces=True - ) - torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) - - # avoid collecting segments from previous run for unit test purpose - fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False) - self.assertGreater(len(fx_frames), 0) - - for frame in fx_frames: - # Every FX frame should have both node_op and node_name - self.assertIn("fx_node_op", 
frame) - self.assertIn("fx_node_name", frame) - self.assertIn("fx_node_target", frame) - self.assertIn("fx_original_trace", frame) - - self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) - fx_node_name = frame["fx_node_name"] - if fx_node_name == "addmm": - self.assertIn("d = self.net1(x)", frame["fx_original_trace"]) - elif fx_node_name == "addmm_1": - self.assertIn("f = self.net2(e)", frame["fx_original_trace"]) - elif fx_node_name == "relu": - self.assertIn("e = self.relu(d)", frame["fx_original_trace"]) + mod = self.MLPModule2(device) + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot(augment_with_fx_traces=True) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + + # avoid collecting segments from previous run for unit test purpose + fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False) + self.assertGreater(len(fx_frames), 0) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("d = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("f = self.net2(e)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("e = self.relu(d)", frame["fx_original_trace"]) instantiate_parametrized_tests(TestCuda) From cdc8460f2c76f98ba30556e3f9358e857a2f22f0 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Mon, 10 Nov 2025 20:20:51 +0000 Subject: [PATCH 292/651] Use c7i.2xlarge for H100 build (#167466) The build system maybe oversized for what is necessary. Reduce the size to optimize costs. The default workflow runner is linux.c7i.2xlarge so we are just removing the runner definition in the workflow so that it uses the default. Relates to pytorch/test-infra#7175. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167466 Approved by: https://github.com/seemethere --- .github/workflows/h100-distributed.yml | 1 - .github/workflows/test-h100.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/h100-distributed.yml b/.github/workflows/h100-distributed.yml index be19b8f961f4d..c05b61e30a635 100644 --- a/.github/workflows/h100-distributed.yml +++ b/.github/workflows/h100-distributed.yml @@ -37,7 +37,6 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: "linux.c7i.12xlarge" build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '9.0' diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index ec99f4473bb0b..510473d5306ad 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -41,7 +41,6 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.12xlarge.memory build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '9.0' From 6cf21fa331b0b47e7dee9d128887b0021633b429 Mon Sep 17 00:00:00 2001 From: Robert Hardwick Date: Wed, 5 Nov 2025 09:56:55 +0000 Subject: [PATCH 293/651] Fix -ffunction-sections, -fdata-sections not being added on aarch64. (#166407) Preferred solution to #166380 Changes: - Moved summary print to bottom of CMakeLists.txt - Fix the problem 'add_compile_options' should be called before targets defined, so opted for `append_cxx_flag_if_supported` and `append_c_flag_if_supported` ( new ). - Added extra verbosity so it can be seen when linker script added. ( unfortunately linker script has to be added per-target rather than globally due to ninja/cmake depdendency tracking ). 
Also move summary print to bottom of CMakeLists.txt and improve logging Pull Request resolved: https://github.com/pytorch/pytorch/pull/166407 Approved by: https://github.com/Aidyn-A, https://github.com/atalman --- CMakeLists.txt | 96 +++++++++++++++++----------------------- cmake/public/utils.cmake | 19 ++++++++ 2 files changed, 59 insertions(+), 56 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 86f43f58817ba..f1d391ab6dbf9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -736,6 +736,44 @@ if(NOT DEFINED USE_BLAS) set(USE_BLAS ON) endif() +# Prioritized Text Linker Optimization +if(USE_PRIORITIZED_TEXT_FOR_LD) + + set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt") + set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld") + + execute_process( + COMMAND ${Python_EXECUTABLE} + ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py + --filein "${LINKER_SCRIPT_FILE_IN}" + --fout "${LINKER_SCRIPT_FILE_OUT}" + RESULT_VARIABLE _gen_result + OUTPUT_VARIABLE _gen_output + ERROR_VARIABLE _gen_error + ) + + if(NOT _gen_result EQUAL 0) + message(FATAL_ERROR + "Failed to generate linker script:\n${_gen_output}\n${_gen_error}") + endif() + + append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS) + append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS) + append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS) + append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS) + + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}") + set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}") + +else() + if(LINUX AND CPU_AARCH64) + message(WARNING [[ + It is strongly recommend to enable linker script optimization for all AArch64 Linux builds. 
+ To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1 + ]]) + endif() +endif() + # Build libtorch mobile library, which contains ATen/TH ops and native support # for TorchScript model, but doesn't contain not-yet-unified caffe2 ops; if(INTERN_BUILD_MOBILE) @@ -1402,9 +1440,6 @@ if(BUILD_JNI) add_subdirectory(android/pytorch_android) endif() -include(cmake/Summary.cmake) -caffe2_print_configuration_summary() - # Parse custom debug info if(DEFINED USE_CUSTOM_DEBINFO) string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}") @@ -1444,56 +1479,5 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA) DESTINATION "${CMAKE_INSTALL_BINDIR}") endif() -if(USE_PRIORITIZED_TEXT_FOR_LD) - add_compile_options( - $<$:-ffunction-sections> - $<$:-fdata-sections> - ) - set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld") - set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt") - - add_custom_command( - OUTPUT "${LINKER_SCRIPT_FILE_OUT}" - COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}" - DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}" - COMMENT "Generating prioritized text linker files" - VERBATIM - ) - - add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}") - - if(BUILD_PYTHON) - set(LINKER_OPT_TARGETS torch_python) - endif() - - if(NOT BUILD_LIBTORCHLESS) - list(APPEND LINKER_OPT_TARGETS torch_cpu c10) - if(USE_CUDA) - list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda) - endif() - if(USE_XPU) - list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu) - endif() - if(USE_ROCM) - list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip) - endif() - endif() - - foreach(tgt IN LISTS LINKER_OPT_TARGETS) - if(TARGET ${tgt}) - add_dependencies("${tgt}" generate_linker_script) - target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}") - set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}") - else() - message(WARNING "Requested target '${tgt}' for linker script optimization was not found.") - endif() - endforeach() - -else() - if(LINUX AND CPU_AARCH64) - message(WARNING [[ - It is strongly recommend to enable linker script optimization for all AArch64 Linux builds. - To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1 - ]]) - endif() -endif() +include(cmake/Summary.cmake) +caffe2_print_configuration_summary() diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index efc39f2bc1481..2cea7da5af3f0 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -478,6 +478,7 @@ function(torch_update_find_cuda_flags) endfunction() include(CheckCXXCompilerFlag) +include(CheckCCompilerFlag) include(CheckLinkerFlag) ############################################################################## @@ -501,6 +502,24 @@ function(append_cxx_flag_if_supported flag outputvar) endif() endfunction() +function(append_c_flag_if_supported flag outputvar) + string(TOUPPER "HAS${flag}" _FLAG_NAME) + string(REGEX REPLACE "[=-]" "_" _FLAG_NAME "${_FLAG_NAME}") + + # GCC silences unknown -Wno-XXX flags, so test the corresponding -WXXX. 
+ if(CMAKE_C_COMPILER_ID STREQUAL "GNU") + string(REGEX REPLACE "^Wno-" "W" new_flag "${flag}") + else() + set(new_flag "${flag}") + endif() + + check_c_compiler_flag("${new_flag}" ${_FLAG_NAME}) + if(${_FLAG_NAME}) + string(APPEND ${outputvar} " ${flag}") + set(${outputvar} "${${outputvar}}" PARENT_SCOPE) + endif() +endfunction() + function(target_compile_options_if_supported target flag) set(_compile_options "") append_cxx_flag_if_supported("${flag}" _compile_options) From 0c2f206dedc3341c672bcebc2032ed213de8b564 Mon Sep 17 00:00:00 2001 From: Sean McGovern Date: Mon, 10 Nov 2025 20:35:39 +0000 Subject: [PATCH 294/651] Typo fix - baddbmm_strategy (#166963) This is called by registration with decorator, so function not called directly. For clarity, add the "b" for "batch" in function name. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166963 Approved by: https://github.com/janeyx99 --- torch/distributed/tensor/_ops/_matrix_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index cd6ba48d9832b..49152a1bee13a 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -256,7 +256,7 @@ def bmm_strategy(op_schema: OpSchema) -> OpStrategy: @register_op_strategy(aten.baddbmm.default) -def baddmm_strategy(op_schema: OpSchema) -> OpStrategy: +def baddbmm_strategy(op_schema: OpSchema) -> OpStrategy: mesh = op_schema.get_mesh_from_args() return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema) From 3e4faca13034f8468e42dda7000289c840a62ded Mon Sep 17 00:00:00 2001 From: Malay Bag Date: Mon, 10 Nov 2025 20:44:18 +0000 Subject: [PATCH 295/651] [torch.export] Refactor placeholder_naming_pass to reduce CCN (#166600) Summary: Reduced CCN from 37 to 28 of placeholder_naming_pass method Test Plan: Existing tests Differential Revision: D85820388 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166600 Approved by: https://github.com/angelayi --- torch/_export/utils.py | 61 +++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 74230e4a5ed55..3828dc97ac9bc 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -974,6 +974,41 @@ def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: subgraph.recompile() +def _assign_new_node_names( + gm: torch.fx.GraphModule, + name_map: dict[str, str], + custom_meta: dict[str, Any], +) -> None: + """ + Assign new names to all nodes, in the graph module, from name map. + """ + for node in gm.graph.nodes: + if node.op == "placeholder": + assert node.name in name_map + node.name = node.target = name_map[node.name] + if node.name in custom_meta: + if node.meta.get("custom") is None: + node.meta["custom"] = {} + else: + # Assert if any existing key has different value + for k, v in node.meta["custom"].items(): + if ( + k in custom_meta[node.name] + and v != custom_meta[node.name][k] + ): + raise AssertionError( + f"Mismatch in custom metadata for key {k}. Value in " + f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}." 
+ ) + node.meta["custom"].update(custom_meta[node.name]) + # if the constant obj is an input, we also need to update meta["val"] + # because this is created before the placeholder naming pass + if isinstance(node.meta["val"], CustomObjArgument): + node.meta["val"].name = node.name + elif node.name in name_map: + node.name = name_map[node.name] + + def placeholder_naming_pass( gm: torch.fx.GraphModule, export_graph_signature: "ExportGraphSignature", @@ -1091,31 +1126,7 @@ def _extract_pytree_key(x): ) # assign new node names - for node in gm.graph.nodes: - if node.op == "placeholder": - assert node.name in name_map - node.name = node.target = name_map[node.name] - if node.name in custom_meta: - if node.meta.get("custom") is None: - node.meta["custom"] = {} - else: - # Assert if any existing key has different value - for k, v in node.meta["custom"].items(): - if ( - k in custom_meta[node.name] - and v != custom_meta[node.name][k] - ): - raise AssertionError( - f"Mismatch in custom metadata for key {k}. Value in " - f"node.meta is {v} and value in custom_meta is {custom_meta[node.name][k]}." - ) - node.meta["custom"].update(custom_meta[node.name]) - # if the constant obj is an input, we also need to update meta["val"] - # because this is created before the placeholder naming pass - if isinstance(node.meta["val"], CustomObjArgument): - node.meta["val"].name = node.name - elif node.name in name_map: - node.name = name_map[node.name] + _assign_new_node_names(gm, name_map, custom_meta) # propagate names to higher order op subgraphs _name_hoo_subgraph_placeholders(gm) From 5320ca3725c4ccf2811c211b48af1ddebb2b471f Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 18:34:27 -0800 Subject: [PATCH 296/651] [inductor, 3.14] fix itertools.product pickle error in test_cpu_repro (#167382) `inductor/test_cpu_cpp_wrapper` was failing since it was attempting to pickle`itertools.product`, and that is no longer picklable in 3.14. We work around by eagerly generating a list. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167382 Approved by: https://github.com/atalman, https://github.com/malfet --- test/inductor/test_cpu_repro.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index cf4900c8536bf..cb9aa263a0c9a 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -726,8 +726,7 @@ def test_lstm_packed( seq_len, ) - @parametrize( - "unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len", + _test_lstm_packed_change_input_sizes_cpu_params = list( itertools.product( *[ [False], @@ -741,7 +740,12 @@ def test_lstm_packed( [2], [3], ] - ), + ) + ) + + @parametrize( + "unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len", + _test_lstm_packed_change_input_sizes_cpu_params, ) def test_lstm_packed_change_input_sizes_cpu( self, From ad7db3617ec5cc3aa384bd4408fcfbc2acac1a98 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 18:34:28 -0800 Subject: [PATCH 297/651] [inductor, 3.14] catch pickle.PicklingError exceptions (#167383) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167383 Approved by: https://github.com/aorenste ghstack dependencies: #167382 --- test/test_serialization.py | 2 +- torch/_inductor/codecache.py | 2 +- torch/_inductor/compile_fx_ext.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 2755ae29a7ffa..39f8b7735663f 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -4506,7 +4506,7 @@ def fn(t): exc = pickle.PicklingError if sys.version_info >= (3, 14) else AttributeError with self.assertRaisesRegex( exc, - "Can't (get|pickle) local object (.remove" + r"Can't (get|pickle) local object (\.remove" ): with skip_data(), BytesIOContext() as f: torch.save(ft, f) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index b0bea9d2d6bb9..5545115840953 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -624,7 +624,7 @@ def dumps(self, obj: Any) -> bytes: try: self.dump(obj) return self._stream.getvalue() - except (TypeError, AttributeError) as e: + except (TypeError, AttributeError, pickle.PicklingError) as e: # Some configs options may not pickle. 
log.warning("Failed to pickle cache key", exc_info=True) raise BypassFxGraphCache("Failed to pickle cache key") from e diff --git a/torch/_inductor/compile_fx_ext.py b/torch/_inductor/compile_fx_ext.py index f02939225c462..24048ccdda12c 100644 --- a/torch/_inductor/compile_fx_ext.py +++ b/torch/_inductor/compile_fx_ext.py @@ -468,6 +468,8 @@ def serialize_compile( fake_mode = _current_fake_mode() fake_tensor_mode = _FakeTensorModeSerializer(fake_mode) + from pickle import PicklingError + try: input = _WireProtocolInput( gm, @@ -483,7 +485,7 @@ def serialize_compile( fake_tensor_mode, ).serialize() return (input, constants) - except (AttributeError, BypassFxGraphCache): + except (AttributeError, BypassFxGraphCache, PicklingError): # For example: AttributeError: Can't pickle local object # 'make_opaque_unary_fn..OpaqueUnaryFn' From 17e70ae459c45d85ef77afa4d19efe5f8b44f573 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 18:34:28 -0800 Subject: [PATCH 298/651] [dynamo, 3.14] enable dynamo in 3.14 (#167384) dynamo tests are passing in the CI PR above - so we could probably just enable dynamo right now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167384 Approved by: https://github.com/Skylion007, https://github.com/mlazos ghstack dependencies: #167382, #167383 --- tools/dynamo/verify_dynamo.py | 4 ++-- torch/__init__.py | 4 ++-- torch/_dynamo/eval_frame.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tools/dynamo/verify_dynamo.py b/tools/dynamo/verify_dynamo.py index a8ce085e864ea..5b3444a79f112 100644 --- a/tools/dynamo/verify_dynamo.py +++ b/tools/dynamo/verify_dynamo.py @@ -216,8 +216,8 @@ def main() -> None: f"ROCM version: {rocm_ver}\n" ) for args in _SANITY_CHECK_ARGS: - if sys.version_info >= (3, 14): - warnings.warn("Dynamo not yet supported in Python 3.14. 
Skipping check.") + if sys.version_info >= (3, 15): + warnings.warn("Dynamo not yet supported in Python 3.15.") check_dynamo(*args) print("All required checks passed") diff --git a/torch/__init__.py b/torch/__init__.py index 6ce2549964abb..ba9d84b29bce2 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2620,8 +2620,8 @@ def foo(x): import sysconfig _C._log_api_usage_once("torch.compile") - if sys.version_info >= (3, 14): - raise RuntimeError("torch.compile is not supported on Python 3.14+") + if sys.version_info >= (3, 15): + raise RuntimeError("torch.compile is not supported on Python 3.15+") elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( 3, 13, diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 9b9572620db14..0956facde2559 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1316,8 +1316,8 @@ def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec: def check_if_dynamo_supported() -> None: - if sys.version_info >= (3, 14): - raise RuntimeError("Python 3.14+ not yet supported for torch.compile") + if sys.version_info >= (3, 15): + raise RuntimeError("Python 3.15+ not yet supported for torch.compile") elif sysconfig.get_config_var("Py_GIL_DISABLED") == 1 and sys.version_info < ( 3, 13, From cf63b212e330836c2be92bef903f5a5d0dc2c7e9 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 18:34:29 -0800 Subject: [PATCH 299/651] [3.14, dataloader] handle forkserver default mp start method in 3.14 (#167387) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167387 Approved by: https://github.com/malfet ghstack dependencies: #167382, #167383, #167384 --- test/test_dataloader.py | 21 +++++++++++++++++++++ torch/utils/data/dataloader.py | 15 ++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index ba3fe63ed1f1b..5aeb7222e8895 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -2015,6 +2015,22 @@ def test_worker_init_fn(self): self.assertEqual(12345, batch[0]) self.assertEqual(12345, batch[1]) + def test_worker_init_fn_forkserver(self): + def local_init_fn(worker_id): + torch.manual_seed(12345) + + import multiprocessing as py_mp + + py_mp.set_start_method("forkserver", force=True) + + dataset = SeedDataset(4) + dataloader = self._get_data_loader( + dataset, batch_size=2, num_workers=2, worker_init_fn=local_init_fn + ) + with self.assertWarnsRegex(UserWarning, "Got pickle error when"): + with self.assertRaises(Exception): + next(iter(dataloader)) + def test_get_worker_info(self): p = ErrorTrackingProcess(target=_test_get_worker_info) p.start() @@ -3524,6 +3540,11 @@ def worker_set_affinity(_): dataset = SetAffinityDataset() + if not IS_WINDOWS and not IS_MACOS: + import multiprocessing as py_mp + + py_mp.set_start_method("fork", force=True) + dataloader = torch.utils.data.DataLoader( dataset, num_workers=2, worker_init_fn=worker_set_affinity ) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 1f8f0d70c9c2f..e01422708f791 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1185,7 +1185,20 @@ def __init__(self, loader) -> None: # it started, so that we do not call .join() if program dies # before it starts, and __del__ tries to join but will get: # AssertionError: can only join a started process. 
- w.start() + from pickle import PicklingError + + try: + w.start() + except (TypeError, AttributeError, PicklingError): + warnings.warn( + "Got pickle error when attempting to start a worker Process. " + "This might be because the worker Process arguments are not picklable. " + "Python 3.14+ changed the multiprocessing start method in non-Mac POSIX platforms " + "to 'forkserver', which requires the worker Process arguments to be picklable. " + "You can also try multiprocessing.set_start_method('fork').", + stacklevel=2, + ) + raise self._index_queues.append(index_queue) self._workers.append(w) From fe0bb7cf6001532b14bba14d686baa1ff0b98de0 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 7 Nov 2025 18:34:30 -0800 Subject: [PATCH 300/651] [export, 3.14] handle patching methods with functools.partial correctly in non-strict export (#167396) Note: dynamo is not affected by this since patching class methods are not supported right now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167396 Approved by: https://github.com/angelayi ghstack dependencies: #167382, #167383, #167384, #167387 --- test/export/test_export.py | 12 ++++++++++-- torch/_export/non_strict_utils.py | 9 +++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index a2fd76e0e0ccc..bf84827289ab2 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -11,6 +11,7 @@ import operator import os import re +import sys import traceback import unittest import warnings @@ -12236,8 +12237,15 @@ class Foo(torch.nn.Module): def forward(self, x): return x + 2 - def fancy_forward(x, y): - return x + 2 + y + if sys.version_info >= (3, 14): + # functools.partial is now a method descriptor: + # https://docs.python.org/3/whatsnew/3.14.html#changes-in-the-python-api + def fancy_forward(self, x, y): + return x + 2 + y + else: + + def fancy_forward(x, y): + return x + 2 + y Foo.forward = functools.partial(fancy_forward, y=torch.randn(4, 4)) x = torch.randn(4, 4) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index ef510480347c8..602c8fa61df73 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -5,6 +5,7 @@ import inspect import logging import math +import sys from collections import defaultdict from collections.abc import Callable, Sequence from contextlib import contextmanager @@ -421,6 +422,14 @@ def make_fake_inputs( if isinstance(nn_module.forward, functools.partial): # functools handles nesting by itself, no need to recurse code = nn_module.forward.func.__code__ + elif ( + sys.version_info >= (3, 14) + and (fwd := getattr(nn_module.forward, "__func__", None)) + and isinstance(fwd, functools.partial) + ): + # functools.partial is now a method descriptor: + # https://docs.python.org/3/whatsnew/3.14.html#changes-in-the-python-api + code = fwd.func.__code__ else: code = nn_module.forward.__code__ co_fields = { From 2751b1d3c32c89d909363ab8a3bd43ad05ebe97e Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Mon, 10 Nov 2025 07:34:49 -0800 Subject: [PATCH 301/651] Support repr on user defined objects (#167372) Fixes: https://github.com/pytorch/pytorch/issues/167369 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167372 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 14 ++++++++++++++ torch/_dynamo/variables/builtin.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/test/dynamo/test_misc.py 
b/test/dynamo/test_misc.py index 0db7043b02c21..ca146760584a8 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -5788,6 +5788,20 @@ def test_cross_entropy_loss_simple_ctor(self): self.assertTrue(torch.allclose(dynamo_output, output)) + def test_repr(self): + class Config: + def __repr__(self): + return "Config()" + + def forward(x, config): + return x * len(repr(config)) + + config = Config() + x = torch.randn(2, 2) + + compiled = torch.compile(forward, fullgraph=True) + compiled(x, config) + def test_nn_functional_reduction(self): def fn(loss, reduction): reduction_enum = F._Reduction.get_enum(reduction) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index f1d43b6d48995..ffffb6f48b17b 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1608,6 +1608,21 @@ def call_bool( # TODO handle more cases and merge this with this with `generic_jump`. return None + def call_repr(self, tx: "InstructionTranslator", arg): + """Handle repr() on user defined objects.""" + if isinstance(arg, variables.UserDefinedObjectVariable): + repr_method = arg.value.__repr__ + + if type(arg.value).__repr__ is object.__repr__: + # Default repr - build and trace it + fn_vt = VariableTracker.build(tx, repr_method) + return fn_vt.call_function(tx, [], {}) + else: + # Custom repr - inline the method for tracing + bound_method = repr_method.__func__ + fn_vt = VariableTracker.build(tx, bound_method) + return fn_vt.call_function(tx, [arg], {}) + def call_str( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker | None: From bb3748346484d49ace45dcc92b72c12b2ba30d98 Mon Sep 17 00:00:00 2001 From: Thanh Ha Date: Mon, 10 Nov 2025 21:45:42 +0000 Subject: [PATCH 302/651] Use c7i.2xlarge for B200 build (#167078) The build system is oversized for what is necessary. Reduce the size to optimize costs. The default workflow runner is `linux.c7i.2xlarge` so we are just removing the runner definition in the workflow so that it uses the default. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167078
Approved by: https://github.com/nWEIdia, https://github.com/seemethere
---
 .github/workflows/b200-distributed.yml | 1 -
 .github/workflows/b200-symm-mem.yml    | 1 -
 .github/workflows/test-b200.yml        | 3 +--
 3 files changed, 1 insertion(+), 4 deletions(-)

diff --git a/.github/workflows/b200-distributed.yml b/.github/workflows/b200-distributed.yml
index 596a31431e61b..899df8107ff35 100644
--- a/.github/workflows/b200-distributed.yml
+++ b/.github/workflows/b200-distributed.yml
@@ -37,7 +37,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '10.0'
diff --git a/.github/workflows/b200-symm-mem.yml b/.github/workflows/b200-symm-mem.yml
index 7fa8a8a730447..f0d20c270ed3d 100644
--- a/.github/workflows/b200-symm-mem.yml
+++ b/.github/workflows/b200-symm-mem.yml
@@ -37,7 +37,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '10.0'
diff --git a/.github/workflows/test-b200.yml b/.github/workflows/test-b200.yml
index ef7f75bc4b2b4..6ba403b9e12ac 100644
--- a/.github/workflows/test-b200.yml
+++ b/.github/workflows/test-b200.yml
@@ -52,7 +52,6 @@ jobs:
     needs: get-label-type
     with:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      runner: linux.12xlarge.memory
       build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
       docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
       cuda-arch-list: '10.0'
@@ -73,4 +72,4 @@ jobs:
       docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }}
       aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
-  secrets: inherit
\ No newline at end of file
+  secrets: inherit

From 6ca8cc6edf30b5ca882d4871af617e674b6cdd47 Mon Sep 17 00:00:00 2001
From: Sam Gross
Date: Mon, 10 Nov 2025 21:47:49 +0000
Subject: [PATCH 303/651] Rework PyObject preservation (#166342)

Make the PyObject preservation scheme thread-safe with free threaded (nogil) Python.

The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object

This is implemented in `intrusive_ptr`:

* The topmost bit (`kHasPyObject`) of the weakref count is now used to indicate if the `intrusive_ptr_target` has an associated PyObject. So `kHasPyObject` is one bit, the weakref count is now 31 bits and the strong refcount remains 32 bits.
* When the reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the reference count decreases from two to one (i.e., there are no C++ references to the `intrusive_ptr_target` other than from the Python object), we decref the associated Python object to break the cycle.
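
To make the intended semantics concrete, here is a minimal Python sketch of the observable contract this scheme is meant to guarantee (an illustration only, not code taken from this patch; it assumes a standard PyTorch build):

```python
import weakref

import torch

t = torch.randn(3, requires_grad=True)
(t * 2).sum().backward()

# The C++ autograd metadata of `t` holds a reference to the gradient tensor,
# so the Python wrapper created on first access is preserved: repeated
# accesses return the same object and a weakref taken on it stays alive.
grad_ref = weakref.ref(t.grad)
assert t.grad is t.grad
assert grad_ref() is t.grad

# Likewise, `t` keeps a C++ reference to its StorageImpl, so the storage
# wrapper is kept alive rather than being rebuilt on every access.
s = t.untyped_storage()
assert s is t.untyped_storage()
```

The added `test_grad_thread_safety` and the updated `test_storage_use_count` in the diff below exercise this preservation under concurrent access and when counting storage references.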
Other benefits: * We can delete a lot of the copypasta from Python internal `subtype_dealloc` * This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should just work as expected now. Risks: * Extra branch for reference count operations on `intrusive_ptr`, `intrusive_ptr`, and the generic `intrusive_ptr` even when we're not using Python. * It's a big change Pull Request resolved: https://github.com/pytorch/pytorch/pull/166342 Approved by: https://github.com/albanD --- aten/src/ATen/core/TensorBase.h | 3 + aten/tools/valgrind.sup | 7 + c10/core/SafePyObject.h | 4 +- c10/core/StorageImpl.cpp | 24 + c10/core/StorageImpl.h | 16 + c10/core/TensorImpl.cpp | 25 +- c10/core/TensorImpl.h | 17 + c10/core/impl/PyInterpreter.cpp | 11 +- c10/core/impl/PyInterpreter.h | 12 +- c10/core/impl/PyObjectSlot.cpp | 56 -- c10/core/impl/PyObjectSlot.h | 143 ++-- c10/util/intrusive_ptr.h | 143 +++- test/test_autograd.py | 28 + test/test_torch.py | 56 +- torch/csrc/Module.cpp | 42 +- torch/csrc/PyInterpreter.cpp | 77 +- torch/csrc/Storage.cpp | 327 ++------ torch/csrc/Storage.h | 8 +- torch/csrc/StorageMethods.cpp | 5 +- .../csrc/autograd/functions/accumulate_grad.h | 6 +- torch/csrc/autograd/input_buffer.cpp | 4 +- torch/csrc/autograd/python_variable.cpp | 769 ++++-------------- torch/csrc/autograd/python_variable.h | 5 +- .../autograd/utils/grad_layout_contract.h | 4 +- torch/csrc/autograd/utils/wrap_outputs.h | 4 + torch/csrc/autograd/variable.h | 18 +- torch/csrc/utils/pyobject_preservation.cpp | 76 +- torch/csrc/utils/pyobject_preservation.h | 26 +- 28 files changed, 724 insertions(+), 1192 deletions(-) delete mode 100644 c10/core/impl/PyObjectSlot.cpp diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 2b9558197bdcb..2d7ca10433d6a 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -245,6 +245,9 @@ class TORCH_API TensorBase { size_t weak_use_count() const noexcept { return impl_.weak_use_count(); } + bool is_uniquely_owned() const noexcept { + return impl_.is_uniquely_owned(); + } std::string toString() const; diff --git a/aten/tools/valgrind.sup b/aten/tools/valgrind.sup index ad5f66e0b0531..585487c4d2be2 100644 --- a/aten/tools/valgrind.sup +++ b/aten/tools/valgrind.sup @@ -10,6 +10,13 @@ ... } +{ + ignore_empty_generic_uninitialised_conditional_jump + Memcheck:Cond + fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE + ... 
+} + { Cond_cuda Memcheck:Cond diff --git a/c10/core/SafePyObject.h b/c10/core/SafePyObject.h index 1ec0cdb6751e9..bcace0ac358b4 100644 --- a/c10/core/SafePyObject.h +++ b/c10/core/SafePyObject.h @@ -44,7 +44,7 @@ struct C10_API SafePyObject { (*other.pyinterpreter_)->incref(other.data_); } if (data_ != nullptr) { - (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false); + (*pyinterpreter_)->decref(data_); } data_ = other.data_; pyinterpreter_ = other.pyinterpreter_; @@ -53,7 +53,7 @@ struct C10_API SafePyObject { ~SafePyObject() { if (data_ != nullptr) { - (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false); + (*pyinterpreter_)->decref(data_); } } diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index a614fc9234c94..00fc03bbd0fcf 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() { TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); } +void StorageImpl::incref_pyobject() const { + // Because intrusive_ptr incref uses relaxed memory order, we need to + // do an acquire fence to ensure that the kHasPyObject bit was + // observed before the load of the PyObject* below. + // NB: This is a no-op on x86/x86-64 + std::atomic_thread_fence(std::memory_order_acquire); + + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->incref(obj); +} + +void StorageImpl::decref_pyobject() const { + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->decref(obj); +} + +bool StorageImpl::try_incref_pyobject() const { + c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); + if (C10_UNLIKELY(!interp)) { + return false; + } + return (*interp)->try_incref(pyobj_slot_); +} + void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) { // Allowlist verification. // Only if the devicetype is in the allowlist, diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index f34a1baed7a48..c471992ac1bb9 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { data_ptr_.clear(); } + void incref_pyobject() const override final; + + void decref_pyobject() const override final; + + bool try_incref_pyobject() const override final; + size_t nbytes() const { // OK to do this instead of maybe_as_int as nbytes is guaranteed positive TORCH_CHECK(!size_bytes_is_heap_allocated_); @@ -370,4 +376,14 @@ C10_API c10::intrusive_ptr make_storage_impl( bool resizable, std::optional device_opt); +namespace detail { +template +struct TargetTraits< + T, + std::enable_if_t< + std::is_base_of_v>>> { + static constexpr bool can_have_pyobject = true; +}; +} // namespace detail + } // namespace c10 diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index c59524a0932c2..94a7375cc32fb 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -277,7 +277,6 @@ void TensorImpl::release_resources() { if (storage_) { storage_ = {}; } - pyobj_slot_.maybe_destroy_pyobj(); } #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY @@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { } } +void TensorImpl::incref_pyobject() const { + // Because intrusive_ptr incref uses relaxed memory order, we need to + // do an acquire fence to ensure that the kHasPyObject bit was + // observed before the load of the PyObject* below. 
+ // NB: This is a no-op on x86/x86-64 + std::atomic_thread_fence(std::memory_order_acquire); + + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->incref(obj); +} + +void TensorImpl::decref_pyobject() const { + PyObject* obj = pyobj_slot_.load_pyobj(); + (*pyobj_slot_.pyobj_interpreter())->decref(obj); +} + +bool TensorImpl::try_incref_pyobject() const { + c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); + if (C10_UNLIKELY(!interp)) { + return false; + } + return (*interp)->try_incref(pyobj_slot_); +} + namespace impl { namespace { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 66893b86c8469..4b1df95213849 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -2176,6 +2176,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return &pyobj_slot_; } + void incref_pyobject() const override final; + + void decref_pyobject() const override final; + + bool try_incref_pyobject() const override final; + private: // See NOTE [std::optional operator usage in CUDA] // We probably don't want to expose this publicly until @@ -3077,6 +3083,17 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { friend class C10_TensorImpl_Size_Check_Dummy_Class; }; +namespace detail { + +template +struct TargetTraits< + T, + std::enable_if_t>>> { + static constexpr bool can_have_pyobject = true; +}; + +} // namespace detail + // Note [TensorImpl size constraints] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Changed the size of TensorImpl? If the size went down, good for diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index 8676f0aaf8e0e..52d263fad36c5 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { void incref(PyObject* pyobj) const override {} // do nothing - void decref(PyObject* pyobj, bool has_pyobj_slot) const override { - } // do nothing + void decref(PyObject* pyobj) const override {} // do nothing + + bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override { + return false; + } #define PANIC(m) \ TORCH_INTERNAL_ASSERT( \ @@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { "attempted to call " #m \ " on a Tensor with nontrivial PyObject after corresponding interpreter died") + size_t refcnt(PyObject* pyobj) const override { + PANIC(refcnt); + } + c10::intrusive_ptr detach(const TensorImpl* self) const override { PANIC(detach); } diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index def708c24b802..463b1e520b36e 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -18,6 +18,9 @@ namespace c10 { struct IValue; class OperatorHandle; struct TensorImpl; +namespace impl { +struct PyObjectSlot; +} // namespace impl } // namespace c10 namespace torch::jit { @@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable { // Run Py_INCREF on a PyObject. virtual void incref(PyObject* pyobj) const = 0; - // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call - // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg] - virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0; + // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call. + virtual void decref(PyObject* pyobj) const = 0; + // Run PyUnstable_TryIncRef on a PyObject if it's not NULL. 
+ virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0; + // Run Py_REFCNT on a PyObject. + virtual size_t refcnt(PyObject* pyobj) const = 0; // Perform a detach by deferring to the __torch_dispatch__ implementation of // detach, which will also arrange for the PyObject to get copied in this diff --git a/c10/core/impl/PyObjectSlot.cpp b/c10/core/impl/PyObjectSlot.cpp deleted file mode 100644 index 0f1bfb2110747..0000000000000 --- a/c10/core/impl/PyObjectSlot.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include - -namespace c10::impl { - -PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} - -PyObjectSlot::~PyObjectSlot() { - maybe_destroy_pyobj(); -} - -void PyObjectSlot::maybe_destroy_pyobj() { - if (owns_pyobj()) { - TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr); - TORCH_INTERNAL_ASSERT(pyobj_ != nullptr); - (*pyobj_interpreter_.load(std::memory_order_acquire)) - ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true); - // NB: this destructor can only be entered when there are no - // references to this C++ object (obviously), NOR any references - // to the PyObject (if there are references to the PyObject, - // then the PyObject holds an owning reference to the tensor). - // So it is OK to clear pyobj_ here as it is impossible for it to - // be used again (modulo weak reference races) - pyobj_ = nullptr; // for safety - } -} - -PyInterpreter* PyObjectSlot::pyobj_interpreter() { - return pyobj_interpreter_.load(std::memory_order_acquire); -} - -PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - return reinterpret_cast( - reinterpret_cast(pyobj_) & ~0x1ULL); -} - -PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const { - auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); - if (interpreter) { - return *interpreter; - } - TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set"); -} - -bool PyObjectSlot::owns_pyobj() { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - return reinterpret_cast(pyobj_) & 1; -} - -void PyObjectSlot::set_owns_pyobj(bool b) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - pyobj_ = reinterpret_cast( - reinterpret_cast(_unchecked_untagged_pyobj()) | b); -} - -} // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 58b2490eba001..2d333b0fe503f 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -8,117 +8,70 @@ #include +namespace torch::utils { +class PyObjectPreservation; +} + namespace c10::impl { struct C10_API PyObjectSlot { public: - PyObjectSlot(); - - ~PyObjectSlot(); + PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {} - void maybe_destroy_pyobj(); + // Query the PyObject interpreter. This may return null if there is no + // interpreter. + PyInterpreter* pyobj_interpreter() const { + return pyobj_interpreter_.load(std::memory_order_acquire); + } - // Associate the TensorImpl with the specified PyObject, and, if necessary, - // also tag the interpreter. - // - // NB: This lives in a header so that we can inline away the switch on status - // - // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after - // PyObject if necessary! 
- void init_pyobj(PyObject* pyobj) { - pyobj_interpreter_.store( - getGlobalPyInterpreter(), std::memory_order_relaxed); - pyobj_ = pyobj; + PyInterpreter& load_pyobj_interpreter() const { + auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire); + TORCH_INTERNAL_ASSERT( + interpreter, "cannot access PyObject for Tensor - no interpreter set"); + return *interpreter; } - // Query the PyObject interpreter. This may return null if there is no - // interpreter. This is racy! - PyInterpreter* pyobj_interpreter(); - - PyObject* _unchecked_untagged_pyobj() const; - - // Test the interpreter tag. If tagged for the current interpreter, return - // a non-nullopt (but possibly null) PyObject. If (possibly) untagged, - // returns a nullopt. If it is definitely invalid, raises an error. - // - // If `ignore_hermetic_tls` is false and this function is called from a - // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then - // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic - // context is ignored, allowing you to check the interpreter tag of a - // nonhermetic PyObject from within a hermetic context. This is necessary - // because there are some cases where the deallocator function of a - // nonhermetic PyObject is called from within a hermetic context, so it must - // be properly treated as a nonhermetic PyObject. - // - // NB: this lives in header so that we can avoid actually creating the - // std::optional - - // @todo alban: I'm not too sure what's going on here, we can probably delete - // it but it's worthwhile making sure - std::optional check_pyobj(bool ignore_hermetic_tls = false) const { - impl::PyInterpreter* interpreter = - pyobj_interpreter_.load(std::memory_order_acquire); - if (interpreter == nullptr) { - return std::nullopt; - } - - if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) { - return std::nullopt; - } else { - return _unchecked_untagged_pyobj(); - } + PyObject* load_pyobj() const { + return pyobj_.load(std::memory_order_acquire); } - PyInterpreter& load_pyobj_interpreter() const; + bool has_unique_reference() const { + PyObject* pyobj = load_pyobj(); + return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1; + } - bool owns_pyobj(); + void clear() { + pyobj_.store(nullptr, std::memory_order_relaxed); + pyobj_interpreter_.store(nullptr, std::memory_order_relaxed); + } - void set_owns_pyobj(bool b); + // Non thread-safe swap + void swap(PyObjectSlot& other) noexcept { + PyInterpreter* tmp_interpreter = + pyobj_interpreter_.load(std::memory_order_relaxed); + pyobj_interpreter_.store( + other.pyobj_interpreter_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + other.pyobj_interpreter_.store(tmp_interpreter, std::memory_order_relaxed); + + PyObject* tmp_pyobj = pyobj_.load(std::memory_order_relaxed); + pyobj_.store( + other.pyobj_.load(std::memory_order_relaxed), + std::memory_order_relaxed); + other.pyobj_.store(tmp_pyobj, std::memory_order_relaxed); + } private: - // This field contains the interpreter tag for this object. See - // Note [Python interpreter tag] for general context - // - // Note [Memory ordering on Python interpreter tag] - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // What memory_order do we need when accessing this atomic? 
We don't - // need a single total modification order (as provided by - // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only - // transition from -1 to some positive integer and never changes afterwards. - // Because there is only one modification, it trivially already has a total - // modification order (e.g., we don't need fences or locked instructions on - // x86) - // - // In fact, one could make a reasonable argument that relaxed reads are OK, - // due to the presence of external locking (GIL) to ensure that interactions - // with other data structures are still correctly synchronized, so that - // we fall in the "Single-Location Data Structures" case as described in - // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf - // However, on x86, it doesn't matter if I use acquire or relaxed on the load - // as I get the same assembly in both cases. So I just use the more - // conservative acquire (which will impede compiler optimizations but I don't - // care) + // This is now always the global interpreter if the PyObject is set. + // Maybe we can remove this field some day... std::atomic pyobj_interpreter_; - // This field contains a reference to a PyObject representing this Tensor. - // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new - // PyObject for it and set this field. This field does not have to be - // protected by an atomic as it is only allowed to be accessed when you hold - // the GIL, or during destruction of the tensor. - // - // When a PyObject dies, you are obligated to clear this field - // (otherwise, you will try to use-after-free the pyobj); this currently - // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp - // - // NB: Ordinarily, this should not be a strong reference, as if the - // PyObject owns the Tensor, this would create a reference cycle. - // However, sometimes this ownership flips. To track who owns - // who, this has a single pointer tag indicating whether or not the - // C++ object owns the PyObject (the common case, zero, means PyObject - // owns the C++ object); see _unchecked_untagged_pyobj for raw access - // or check_pyobj for checked access. See references to PyObject - // resurrection in torch/csrc/autograd/python_variable.cpp - PyObject* pyobj_; + // The PyObject representing this Tensor or nullptr. Ownership is managed + // by intrusive_ptr. By the time the PyObjectSlot is destroyed, this + // reference is already dead. + std::atomic pyobj_; + + friend class torch::utils::PyObjectPreservation; }; } // namespace c10::impl diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index 3d5478be90e60..3a3a63d7c5090 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -12,6 +12,10 @@ template class class_; } +namespace torch::utils { +class PyObjectPreservation; +} + namespace c10 { class intrusive_ptr_target; namespace raw { @@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount = constexpr uint64_t kReferenceCountOne = 1; constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32); constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne); +// Indicates whether the object has a PyObject wrapper. 
+constexpr uint64_t kHasPyObject = (uint64_t(1) << 63); template struct intrusive_target_default_null_type final { @@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) { } inline uint32_t weakcount(uint64_t combined_refcount) { - return static_cast(combined_refcount >> 32); + return static_cast((combined_refcount & ~kHasPyObject) >> 32); +} + +inline bool has_pyobject(uint64_t combined_refcount) { + return (combined_refcount & kHasPyObject) != 0; } // The only requirement for refcount increment is that it happens-before @@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment( return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc; } -inline uint32_t atomic_refcount_increment( - std::atomic& combined_refcount) { - return detail::refcount(atomic_combined_refcount_increment( - combined_refcount, kReferenceCountOne)); -} - inline uint32_t atomic_weakcount_increment( std::atomic& combined_refcount) { return detail::weakcount(atomic_combined_refcount_increment( @@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement( combined_refcount, kWeakReferenceCountOne)); } +template +struct TargetTraits { + static constexpr bool can_have_pyobject = false; +}; + } // namespace detail /** @@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target { // we can atomically operate on both at the same time for performance // and defined behaviors. // + // Note [PyObject preservation for Tensor and Storages] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // intrusive_ptr has special support for preserving PyObject wrappers + // for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of + // the combined_refcount_ is used to indicate whether the object has a + // PyObject wrapper. + // + // - The PyObject, if it exists, holds a strong reference to the + // intrusive_ptr_target. + // + // - When the refcount goes from 1 to 2, we incref the PyObject. + // + // - When the refcount goes from 2 to 1, we decref the PyObject. + // + // In other words, the intrusive_ptr keeps the PyObject alive as long as there + // are other C++ references to the intrusive_ptr_target. + mutable std::atomic combined_refcount_; static_assert(sizeof(std::atomic) == 8); static_assert(alignof(std::atomic) == 8); @@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target { template friend struct ExclusivelyOwnedTensorTraits; + friend class torch::utils::PyObjectPreservation; + protected: // protected destructor. We never want to destruct intrusive_ptr_target* // directly. @@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target { */ virtual void release_resources() {} + /** + * These two methods are called when the refcount transitions between one + * and two and the object has a PyObject wrapper. + */ + virtual void incref_pyobject() const {} + virtual void decref_pyobject() const {} + virtual bool try_incref_pyobject() const { + return false; + } + uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const { return detail::refcount(combined_refcount_.load(order)); } @@ -265,6 +303,15 @@ class C10_API intrusive_ptr_target { } }; +namespace detail { +template <> +struct TargetTraits { + // A generic intrusive_ptr may actually be a TensorImpl + // or StorageImpl, so we have to allow for PyObject support. 
+ static constexpr bool can_have_pyobject = true; +}; +} // namespace detail + template class weak_intrusive_ptr; @@ -314,18 +361,34 @@ class intrusive_ptr final { void retain_() { if (target_ != NullType::singleton()) { - uint32_t new_refcount = - detail::atomic_refcount_increment(target_->combined_refcount_); + uint64_t combined = detail::atomic_combined_refcount_increment( + target_->combined_refcount_, detail::kReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( new_refcount != 1, "intrusive_ptr: Cannot increase refcount after it reached zero."); + + if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 1 to 2, we need to incref the + // PyObject. In other words, we need to ensure that the PyObject stays + // alive now that we have a C++ reference to this object in addition to + // the PyObject itself. + if (C10_UNLIKELY( + detail::has_pyobject(combined) && + detail::refcount(combined) == 2)) { + target_->incref_pyobject(); + } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !detail::has_pyobject(combined), + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); + } } } void reset_() noexcept { if (target_ != NullType::singleton()) { - if (target_->combined_refcount_.load(std::memory_order_acquire) == - detail::kUniqueRef) { + if (is_uniquely_owned()) { // Both counts are 1, so there are no weak references and // we are releasing the last strong reference. No other // threads can observe the effects of this target_ deletion @@ -337,9 +400,10 @@ class intrusive_ptr final { auto combined_refcount = detail::atomic_combined_refcount_decrement( target_->combined_refcount_, detail::kReferenceCountOne); - if (detail::refcount(combined_refcount) == 0) { - bool should_delete = - (combined_refcount == detail::kWeakReferenceCountOne); + uint32_t new_refcount = detail::refcount(combined_refcount); + bool has_pyobject = detail::has_pyobject(combined_refcount); + if (new_refcount == 0) { + bool should_delete = detail::weakcount(combined_refcount) == 1; // See comment above about weakcount. As long as refcount>0, // weakcount is one larger than the actual number of weak references. // So we need to decrement it here. @@ -356,6 +420,18 @@ class intrusive_ptr final { if (should_delete) { delete target_; } + } else if constexpr (detail::TargetTraits::can_have_pyobject) { + // If the refcount transitioned from 2 to 1, we need to decref the + // PyObject. In other words, we don't want to keep the PyObject alive if + // there are no C++ references to this object other than the PyObject + // itself. + if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) { + target_->decref_pyobject(); + } + } else { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !has_pyobject, + "TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set."); } } } @@ -522,6 +598,16 @@ class intrusive_ptr final { return use_count() == 1; } + /** + * Stronger than unique() in that it must not have any weakrefs as well. + */ + bool is_uniquely_owned() const noexcept { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton()); + uint64_t combined = + target_->combined_refcount_.load(std::memory_order_acquire); + return (combined & ~detail::kHasPyObject) == detail::kUniqueRef; + } + /** * Returns an owning (!) pointer to the underlying object and makes the * intrusive_ptr instance invalid. That means the refcount is not decreased. 
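The bit layout of combined_refcount_ and the incref_pyobject()/decref_pyobject() transitions described in the hunks above can be illustrated with a minimal standalone sketch. The names below (kRefOne, kWeakOne, kHasPyObjectBit and the free functions) are local stand-ins that only roughly mirror the c10::detail helpers added by this patch; in particular, the real code performs the increment and decrement as single atomic operations on combined_refcount_, which the sketch elides.

#include <cassert>
#include <cstdint>
#include <iostream>

// Local stand-ins for the packed-refcount helpers (illustrative only).
constexpr uint64_t kRefOne         = 1;                  // strong count: bits 0..31
constexpr uint64_t kWeakOne        = kRefOne << 32;      // weak count:   bits 32..62
constexpr uint64_t kHasPyObjectBit = uint64_t(1) << 63;  // PyObject flag: bit 63

uint32_t refcount(uint64_t c)  { return static_cast<uint32_t>(c); }
uint32_t weakcount(uint64_t c) { return static_cast<uint32_t>((c & ~kHasPyObjectBit) >> 32); }
bool has_pyobject(uint64_t c)  { return (c & kHasPyObjectBit) != 0; }

int main() {
  // Freshly constructed target: refcount == 1, weakcount == 1, no PyObject yet.
  uint64_t combined = kRefOne | kWeakOne;
  assert(refcount(combined) == 1 && weakcount(combined) == 1);

  // A Python wrapper is attached, so the flag bit is set.
  combined |= kHasPyObjectBit;

  // retain_(): strong count goes 1 -> 2 while the flag is set,
  // which is the point where incref_pyobject() would be called.
  combined += kRefOne;
  if (has_pyobject(combined) && refcount(combined) == 2) {
    std::cout << "would call incref_pyobject()\n";
  }

  // reset_(): strong count goes 2 -> 1 while the flag is set,
  // which is the point where decref_pyobject() would be called.
  combined -= kRefOne;
  if (has_pyobject(combined) && refcount(combined) == 1) {
    std::cout << "would call decref_pyobject()\n";
  }

  // is_uniquely_owned(): mask out the PyObject bit before comparing
  // against the unique pattern (refcount == 1, weakcount == 1).
  assert((combined & ~kHasPyObjectBit) == (kRefOne | kWeakOne));
  return 0;
}

Packing both counts and the flag into one 64-bit word is what lets a single atomic fetch_add report, in the same returned value, whether the strong count just crossed the 1/2 boundary while a PyObject wrapper exists, so no extra load or lock is needed on the hot path.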
@@ -932,6 +1018,7 @@ class weak_intrusive_ptr final { if (target_ == NullType::singleton()) { return intrusive_ptr(); } else { + bool increfed = false; auto combined_refcount = target_->combined_refcount_.load(std::memory_order_relaxed); do { @@ -940,12 +1027,31 @@ class weak_intrusive_ptr final { // Return nullptr. return intrusive_ptr(); } + if constexpr (detail::TargetTraits::can_have_pyobject) { + if (detail::has_pyobject(combined_refcount) && + detail::refcount(combined_refcount) == 1 && !increfed) { + // Object has a python wrapper with no other C++ references. + // We need to to incref the Python object before we acquire a + // strong reference to the C++ object to avoid a situation + // where the Python object is deallocated concurrently. + if (!target_->try_incref_pyobject()) { + return intrusive_ptr(); + } + increfed = true; + } + } } while (!target_->combined_refcount_.compare_exchange_weak( combined_refcount, combined_refcount + detail::kReferenceCountOne, std::memory_order_acquire, std::memory_order_relaxed)); + if constexpr (detail::TargetTraits::can_have_pyobject) { + if (increfed && detail::refcount(combined_refcount) != 1) { + target_->decref_pyobject(); + } + } + return intrusive_ptr( target_, raw::DontIncreaseRefcount{}); } @@ -1060,7 +1166,14 @@ namespace intrusive_ptr { // NullType::singleton to this function inline void incref(intrusive_ptr_target* self) { if (self) { - detail::atomic_refcount_increment(self->combined_refcount_); + uint64_t combined = detail::atomic_combined_refcount_increment( + self->combined_refcount_, detail::kReferenceCountOne); + + if (C10_UNLIKELY( + detail::has_pyobject(combined) && + detail::refcount(combined) == 2)) { + self->incref_pyobject(); + } } } diff --git a/test/test_autograd.py b/test/test_autograd.py index 4926697d1d1be..2bd6367748dbb 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -10895,6 +10895,34 @@ def func(inp): self.assertTrue(gradcheck(func, x, fast_mode=True)) + def test_grad_thread_safety(self): + import threading + from concurrent.futures import ThreadPoolExecutor + + NUM_ITERS = 10 + NUM_THREADS = 4 + + # Concurrent calls to tensor.untyped_storage() + def access_grad(tensor, barrier): + barrier.wait() + return weakref.ref(tensor.grad) + + for i in range(NUM_ITERS): + tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + (tensor**2).sum().backward() + + barrier = threading.Barrier(NUM_THREADS) + with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = [ + executor.submit(access_grad, tensor, barrier) + for _ in range(NUM_THREADS) + ] + + # Check that all the grad tensors returned were the same + for future in futures: + self.assertEqual(future.result()(), tensor.grad) + self.assertIsNotNone(tensor.grad) + def index_perm_variable(shape, max_indices): if not isinstance(shape, tuple): diff --git a/test/test_torch.py b/test/test_torch.py index b54ae93baa647..6aac21f3d0682 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -259,7 +259,8 @@ def test_storage_setitem(self, device, dtype): def test_storage_use_count(self, device): a = torch.randn(10, device=device) prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata) - self.assertEqual(prev_cf, 1) + # Two references: 'a' and the wrapper returned by untyped_storage() + self.assertEqual(prev_cf, 2) b = a.view(2, 5) self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1) @@ -9316,7 +9317,7 @@ class BadSubTensor: member_var = object() err_msg = "Creating a Tensor subclass from a class that 
does not inherit from Tensor" - with self.assertRaisesRegex(RuntimeError, err_msg): + with self.assertRaisesRegex(TypeError, err_msg): s0 = t0.as_subclass(BadSubTensor) # FIXME: Port to a test suite that better fits slicing @@ -10316,20 +10317,21 @@ def test_backward_hooks_traverse(self): @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") def test_tensor_dead_weak_ref(self): - x = torch.empty(2) + x = torch.ones(2) w_x = weakref.ref(x) - y = torch.empty(2) + y = torch.ones(2) y.grad = x del x x = w_x() - # Ideally, x would keep the tensor live. But CPython doesn't - # provide enough hooks to do this. So it will go dead and x - # will transmute into an undefined tensor. Not great, but the - # best we can do. + # x should keep the tensor live. This didn't happen in earlier PyTorch + # versions. del y - self.assertRaises(RuntimeError, lambda: x.sigmoid()) + self.assertEqual(2, x.sum()) + + del x + self.assertIsNone(w_x()) @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993") def test_storage_dead_weak_ref(self): @@ -10337,16 +10339,9 @@ def test_storage_dead_weak_ref(self): w_x = weakref.ref(x) y = torch.tensor(x) del x - - x = w_x() - # Ideally, x would keep the storage live. But CPython doesn't - # provide enough hooks to do this. So it will go dead and x - # will transmute into storage with null StorageImpl. Not great, but the - # best we can do. + self.assertIsNotNone(w_x()) del y - - self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0]) - self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float()) + self.assertIsNone(w_x()) def test_tensor_resurrected_weak_ref(self): x = torch.empty(2) @@ -10407,6 +10402,31 @@ def callback(w): self.assertTrue(called) + def test_storage_thread_safety(self): + import threading + from concurrent.futures import ThreadPoolExecutor + + NUM_ITERS = 10 + NUM_THREADS = 4 + + # Concurrent calls to tensor.untyped_storage() + def access_untyped_storage(tensor, barrier): + barrier.wait() + return weakref.ref(tensor.untyped_storage()) + + for i in range(NUM_ITERS): + tensor = torch.tensor([1.0, 2.0, 3.0]) + barrier = threading.Barrier(NUM_THREADS) + with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = [ + executor.submit(access_untyped_storage, tensor, barrier) + for _ in range(NUM_THREADS) + ] + + # Check that all the storages returned were the same + for future in futures: + self.assertEqual(future.result()(), tensor.untyped_storage()) + # FIXME: move to test_linalg @torch.inference_mode() def test_bmm_multithreaded(self): diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index ad37abe3b560b..0a4698e5d38f4 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -398,36 +398,28 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) { // weak_use_count() adds 1 if use_count is non-zero TORCH_CHECK( - a->cdata->weak_use_count() == 1, + a->cdata.weak_use_count() == 1, "Expected no weakrefs to t1's Tensor object but got ", - a->cdata->weak_use_count() - 1); + a->cdata.weak_use_count() - 1); TORCH_CHECK( - b->cdata->weak_use_count() == 1, + b->cdata.weak_use_count() == 1, "Expected no weakrefs to t2's Tensor object but got ", - b->cdata->weak_use_count() - 1); + b->cdata.weak_use_count() - 1); + + // NB: Creating local copies of *both* Tensors here ensures that they each + // hold a strong reference to their PyObject. This avoids having to fix up + // reference counts when we swap the PyObject slots below. 
+ at::Tensor tmp_a = a->cdata; + at::Tensor tmp_b = b->cdata; // Swap the Tensor Impl - c10::MaybeOwned tmp = a->cdata; - - // The TensorImpls contain PyObjectSlots that have a reference to the PyObject - // associated with the TensorImpl. Swap this field as well. - std::optional mb_obj_a = - a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - std::optional mb_obj_b = - b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - TORCH_INTERNAL_ASSERT( - mb_obj_a.has_value() && mb_obj_b.has_value(), - "Both tensors should have PyObjects tagged by the current python interpreter"); - TORCH_CHECK(mb_obj_a.value() == a_); - TORCH_CHECK(mb_obj_b.value() == b_); - - a->cdata = b->cdata; - b->cdata = tmp; - - a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_); - b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_); + a->cdata = tmp_b; + b->cdata = tmp_a; + + // Swap the PyObjects associated with each TensorImpl + auto& a_slot = *a->cdata.unsafeGetTensorImpl()->pyobj_slot(); + auto& b_slot = *b->cdata.unsafeGetTensorImpl()->pyobj_slot(); + a_slot.swap(b_slot); Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index 6b23752124228..8a2e0d533ff0c 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -45,7 +45,9 @@ struct ConcretePyInterpreterVTable final std::string name() const override; void incref(PyObject* pyobj) const override; - void decref(PyObject* pyobj, bool has_pyobj_slot) const override; + void decref(PyObject* pyobj) const override; + bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override; + size_t refcnt(PyObject* pyobj) const override; // TODO: Need to make this work for StorageImpl too. I imagine I'll want to // operate upon a PyObjectSlot rather than a TensorImpl @@ -235,53 +237,13 @@ py::object torchDispatchFromTensorImpl( TorchFunctionName::TorchDispatch)); } -// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg] -// Before calling PyInterpreter::decref, we must statically know if the -// pyobj has a PyObjectSlot or not. -// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection -// - If it does not have a PyObjectSlot, we can freely decref -// One alternative to this is using PyObject_IsInstance -// to get at this information. However, we don't want to risk an incorrect -// `__instancecheck__` changing the semantics here. -void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot) - const { +void ConcretePyInterpreterVTable::decref(PyObject* pyobj) const { // Leak the pyobj if not initialized. This can happen if we are running // exit handlers that are destructing tensors with residual (owned) // PyObjects stored in them. if (!Py_IsInitialized()) return; - pybind11::gil_scoped_acquire gil; - // Two possibilities: - // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or - // Storage. Then we must be careful about PyObject resurrection (see - // THPVariable_clear). - // 2. We are decref-ing some other Python object. We don't do - // PyObject resurrection on non-Tensors, so we just carry on as usual - if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) { - if (THPVariable_Check(pyobj)) { - // It's still alive! This can happen if a weak ref resurrected - // the PyObject without flipping ownership. At this point it is - // too late to rescue the object, so just stub out the PyObject - // so that it fails on subsequent uses. 
Don't raise an error here; - // you're probably in a destructor. - TORCH_WARN( - "Deallocating Tensor that still has live PyObject references. " - "This probably happened because you took out a weak reference to " - "Tensor and didn't call _fix_weakref() after dereferencing it. " - "Subsequent accesses to this tensor via the PyObject will now fail."); - (reinterpret_cast(pyobj))->cdata = - c10::MaybeOwned(); - } else if (THPStorage_Check(pyobj)) { - TORCH_WARN( - "Deallocating UntypedStorage that still has live PyObject references. " - "This probably happened because you took out a weak reference to " - "UntypedStorage and didn't call _fix_weakref() after dereferencing it. " - "Subsequent accesses to this storage via the PyObject will now fail."); - (reinterpret_cast(pyobj))->cdata = - c10::MaybeOwned(); - } - } Py_DECREF(pyobj); } @@ -292,6 +254,25 @@ void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const { Py_INCREF(pyobj); } +bool ConcretePyInterpreterVTable::try_incref( + const c10::impl::PyObjectSlot& pyobj_slot) const { + if (!Py_IsInitialized()) + return false; + pybind11::gil_scoped_acquire gil; + PyObject* pyobj = pyobj_slot.load_pyobj(); + if (!pyobj) { + return false; + } + return PyUnstable_TryIncRef(pyobj); +} + +size_t ConcretePyInterpreterVTable::refcnt(PyObject* pyobj) const { + if (!Py_IsInitialized() || pyobj == nullptr) + return 0; + pybind11::gil_scoped_acquire gil; + return Py_REFCNT(pyobj); +} + bool isPythonTensor(const at::Tensor& tensor) { return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); } @@ -618,11 +599,7 @@ static void set_tensor_attr_with_capsule( const c10::TensorImpl* tensor, py::capsule& capsule, const char* attr_name) { - std::optional mb_obj = tensor->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - TORCH_CHECK( - mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); - auto obj = mb_obj.value(); + PyObject* obj = tensor->pyobj_slot()->load_pyobj(); py::handle(obj).attr(attr_name) = capsule; } @@ -646,11 +623,7 @@ static c10::ArrayRef get_set_cached_attr( const c10::TensorImpl* tensor, const char* base_attr_name, const py::object& obj) { - std::optional mb_obj = - tensor->pyobj_slot()->check_pyobj(getPyInterpreter()); - TORCH_CHECK( - mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); - auto tensor_obj = mb_obj.value(); + PyObject* tensor_obj = tensor->pyobj_slot()->load_pyobj(); auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len"); bool is_buffer_allocated = false; diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 02558cbdf8968..671c28adef3e3 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -23,6 +23,8 @@ #include #include +using torch::utils::PyObjectPreservation; + template <> void THPPointer::free() { if (ptr) { @@ -32,238 +34,72 @@ void THPPointer::free() { PyTypeObject* THPStorageClass = nullptr; -PyObject* THPStorage_NewWithStorage( - PyTypeObject* type, - c10::Storage _storage, - bool allow_preexisting_pyobj) { - TORCH_CHECK( - PyType_IsSubtype(type, &THPStorageType), - "Creating a Storage subclass from a class that does not inherit from ", - "Storage is not possible. 
Make sure your class inherits from Storage."); - - auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (maybe_pyobj.has_value() && maybe_pyobj.value()) { - TORCH_CHECK( - allow_preexisting_pyobj, - "Creating a new Storage subclass ", - type->tp_name, - " but the raw Storage object is already associated to a python object ", - "of type ", - maybe_pyobj.value()->ob_type->tp_name); - PyObject* obj = *maybe_pyobj; - PyTypeObject* obj_type = Py_TYPE(obj); - TORCH_CHECK( - obj_type == type || PyType_IsSubtype(obj_type, type), - "Creating a new Storage subclass ", - type->tp_name, - " but the raw Storage object is already associated to a python object ", - "of type ", - maybe_pyobj.value()->ob_type->tp_name, - " which is not a subclass of the " - "requested type"); - return THPStorage_Wrap(std::move(_storage)); - } - +// Create a new Python Storage object, but don't set the pyobj slot on the +// c10::Storage object. +static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) { PyObject* obj = type->tp_alloc(type, 0); TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); - auto s = reinterpret_cast(obj); + // Ensure that PyUnstable_TryIncref calls don't fail spuriously in + // free-threaded Python. + PyUnstable_EnableTryIncRef(obj); - new (&s->cdata) c10::MaybeOwned(); - - s->cdata = c10::MaybeOwned::owned(std::move(_storage)); + auto s = (THPStorage*)obj; + new (&s->cdata) c10::Storage(std::move(_storage)); + return obj; +} - if (!c10::impl::HermeticPyObjectTLS::get_state()) { - s->is_hermetic = false; - const auto& storage = THPStorage_Unpack(s); - storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); - } else { - s->is_hermetic = true; - } +// Create a new Python Storage object for a new c10::Storage, and set the +// pyobj slot. The c10::Storage must not already have a pyobj set. +PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) { + TORCH_CHECK( + type == THPStorageClass || PyType_IsSubtype(type, &THPStorageType), + "Creating a Storage subclass from a class that does not inherit from ", + "Storage is not possible. Make sure your class inherits from Storage."); + TORCH_INTERNAL_ASSERT(_storage.use_count() == 1); + c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl(); + PyObject* obj = THPStorage_New(type, std::move(_storage)); + PyObjectPreservation::init_fresh_nonatomic( + storage_impl, storage_impl->pyobj_slot(), obj); return obj; } -// Wraps the c10::Storage with a storage PyObject +// Returns a PyObject wrapper for the c10::Storage object. The existing +// wrapper is returned if it already exists. 
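THPStorage_Wrap, defined next, either returns the wrapper already stored in the pyobj_slot or allocates a fresh one and tries to publish it, dropping its own object if another thread wins the race. The sketch below isolates that publish-once shape. Wrapper and publish_once are invented names used only for illustration, and treating PyObjectPreservation::init_once as essentially a compare-and-swap that returns the winning object is an assumption drawn from the "Another thread beat us to it" handling here, not a description of its actual implementation.

#include <atomic>
#include <cassert>
#include <string>

struct Wrapper {
  std::string name;
};

// First caller to install its freshly built wrapper wins; any loser discards
// its own object and adopts the one that was already published.
Wrapper* publish_once(std::atomic<Wrapper*>& slot, Wrapper* fresh) {
  Wrapper* expected = nullptr;
  if (slot.compare_exchange_strong(expected, fresh, std::memory_order_acq_rel)) {
    return fresh;    // we won the race; our wrapper is now the canonical one
  }
  delete fresh;      // another thread beat us to it; drop ours
  return expected;   // and use the published wrapper instead
}

int main() {
  std::atomic<Wrapper*> slot{nullptr};
  Wrapper* a = publish_once(slot, new Wrapper{"first"});
  Wrapper* b = publish_once(slot, new Wrapper{"second"});  // loses the race
  assert(a == b && a->name == "first");
  delete a;
  return 0;
}

The Tensor path further down (THPVariable_WrapWithType) follows the same win-or-adopt shape: the loser Py_DECREFs its freshly allocated object and returns the published wrapper.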
PyObject* THPStorage_Wrap(c10::Storage storage) { - c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); + return THPStorage_New(THPStorageClass, std::move(storage)); } - c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); - std::optional maybe_pyobj = pyobj_slot->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (maybe_pyobj.has_value()) { - auto obj = *maybe_pyobj; - if (obj) { - TORCH_CHECK( - THPStorage_Check(obj), - "Expected a storage type, but got ", - Py_TYPE(obj)->tp_name); - - if (pyobj_slot->owns_pyobj()) { - pyobj_slot->set_owns_pyobj(false); - reinterpret_cast(obj)->cdata = - c10::MaybeOwned::owned(std::move(storage)); - return obj; - } else { - Py_INCREF(obj); - return obj; - } - } - } - return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); -} - -static bool THPStorage_isPreservable(THPStorage* self) { - if (self->cdata.unsafeIsBorrowed()) { - return false; - } - auto const& storage = THPStorage_Unpack(self); - - if (self->is_hermetic) { - return false; - } + c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); + c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); - if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/true) != reinterpret_cast(self)) { - return false; - } - if (storage.use_count() <= 1) { - return false; + PyObject* obj = pyobj_slot->load_pyobj(); + if (obj) { + return Py_NewRef(obj); } - return true; -} -static bool THPStorage_tryPreserve(THPStorage* self) { - if (!THPStorage_isPreservable(self)) { - return false; + obj = THPStorage_New(THPStorageClass, std::move(storage)); + PyObject* wrapper = + PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj); + if (wrapper != obj) { + // Another thread beat us to it + Py_DECREF(obj); + return Py_NewRef(wrapper); } - - const auto& storage = THPStorage_Unpack(self); - c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); - - auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/true); - // NOTE: It is possible to just set the PyObjectSlot here, but the point is - // that we should have already set PyObjectSlot when the storage PyObject - // was created. 
- TORCH_INTERNAL_ASSERT( - maybe_pyobj.has_value(), - "Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject"); - - PyObject* pyobj = *maybe_pyobj; - - TORCH_CHECK( - THPStorage_Check(pyobj), - "Expected a storage type, but got ", - Py_TYPE(pyobj)->tp_name); - - TORCH_INTERNAL_ASSERT( - (void*)pyobj == (void*)self, - "Python storage and the PyObject in the internal PyObjectSlot are not at the same address"); - - TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj()); - - storage_impl->pyobj_slot()->set_owns_pyobj(true); - // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to - // ensure the PyObject is in a valid state - _Py_NewReference(reinterpret_cast(self)); - - self->cdata = c10::MaybeOwned::borrowed(storage); - return true; + return obj; } -static void THPStorage_subclass_dealloc(PyObject* self) { +static void THPStorage_dealloc(PyObject* self) { THPStorage* _self = reinterpret_cast(self); - - if (THPStorage_tryPreserve(_self)) { - return; - } - - // Some subclass of StorageBase could be GC-tracked objects even - // though the base class is not - auto* type = Py_TYPE(self); - if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) { - PyObject_GC_UnTrack(self); - } - - bool has_finalizer = type->tp_finalize || type->tp_del; - - if (type->tp_finalize) { - PyObject_GC_Track(self); - if (PyObject_CallFinalizerFromDealloc(self) < 0) { - // The finalizer has resurrected the PyObject and there is a new Python - // reference to it, so we can just stop deallocating. Read about - // resurrection from `__del__` here: - // https://docs.python.org/3/reference/datamodel.html#object.__del__ - return; - } - PyObject_GC_UnTrack(self); - } - - // base test is unnecessary as THPStorae does not set this - if (type->tp_weaklistoffset) { - PyObject_ClearWeakRefs(self); + auto pyobj_slot = _self->cdata.unsafeGetStorageImpl()->pyobj_slot(); + if (pyobj_slot->load_pyobj() == self) { + TORCH_INTERNAL_ASSERT(_self->cdata.use_count() == 1); + pyobj_slot->clear(); } - - if (type->tp_del) { - PyObject_GC_Track(self); - type->tp_del(self); - if (Py_REFCNT(self) > 0) { - // Resurrected (see above comment about resurrection from `__del__`) - return; - } - PyObject_GC_UnTrack(self); - } - - if (has_finalizer) { - /* New weakrefs could be created during the finalizer call. - If this occurs, clear them out without calling their - finalizers since they might rely on part of the object - being finalized that has already been destroyed. 
*/ - if (type->tp_weaklistoffset) { - /* Modeled after GET_WEAKREFS_LISTPTR() */ - PyWeakReference** list = reinterpret_cast( - PyObject_GET_WEAKREFS_LISTPTR(self)); - while (*list) - _PyWeakref_ClearRef(*list); - } - } - - // Clear slots - { - PyTypeObject* base = type; - while (base != &THPStorageType) { - if (Py_SIZE(base)) { - clear_slots(base, self); - } - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - } - - // Clear __dict__ - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr(self); - if (dictptr != nullptr) { - PyObject* dict = *dictptr; - if (dict != nullptr) { - Py_DECREF(dict); - *dictptr = nullptr; - } - } - } - - TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type); - - _self->cdata.~MaybeOwned(); + _self->cdata.~Storage(); Py_TYPE(_self)->tp_free(self); - - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - Py_DECREF(type); } static PyObject* THPStorage_pynew( @@ -553,64 +389,13 @@ static PyMappingMethods THPStorage_mappingmethods = { reinterpret_cast(THPStorage_get), reinterpret_cast(THPStorage_set)}; -struct THPStorageMeta { - PyHeapTypeObject base; -}; - -static int THPStorageMetaType_init( - PyObject* cls, - PyObject* args, - PyObject* kwargs); - -static PyTypeObject THPStorageMetaType = { - PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0) - "torch._C._StorageMeta", /* tp_name */ - sizeof(THPStorageMeta), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - // NOLINTNEXTLINE(misc-redundant-expression) - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - nullptr, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* tp_getset */ - DEFERRED_ADDRESS(&PyType_Type), /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - THPStorageMetaType_init, /* tp_init */ - nullptr, /* tp_alloc */ - nullptr, /* tp_new */ -}; - // TODO: implement equality PyTypeObject THPStorageType = { - PyVarObject_HEAD_INIT(&THPStorageMetaType, 0) + PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0) "torch._C.StorageBase", /* tp_name */ sizeof(THPStorage), /* tp_basicsize */ 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ + THPStorage_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ @@ -649,15 +434,6 @@ PyTypeObject THPStorageType = { THPStorage_pynew, /* tp_new */ }; -int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { - if (PyType_Type.tp_init(cls, args, kwargs) < 0) { - return -1; - } - (reinterpret_cast(cls))->tp_dealloc = - static_cast(THPStorage_subclass_dealloc); - return 0; -} - static PyObject* THPStorage_device(THPStorage* self, void* unused) { HANDLE_TH_ERRORS THPStorage_assertNotNull(self); @@ -692,13 +468,6 @@ bool THPStorage_init(PyObject* module) { THPUtils_addPyMethodDefs(methods, THPStorage_getMethods()); THPUtils_addPyMethodDefs(methods, 
THPStorage_getSharingMethods()); - THPStorageMetaType.tp_base = &PyType_Type; - if (PyType_Ready(&THPStorageMetaType) < 0) - return false; - Py_INCREF(&THPStorageMetaType); - PyModule_AddObject( - module, "_StorageMeta", reinterpret_cast(&THPStorageMetaType)); - THPStorageType.tp_methods = methods.data(); THPStorageType.tp_getset = THPStorage_properties; if (PyType_Ready(&THPStorageType) < 0) diff --git a/torch/csrc/Storage.h b/torch/csrc/Storage.h index 698cd80548efa..89e853181f3da 100644 --- a/torch/csrc/Storage.h +++ b/torch/csrc/Storage.h @@ -11,15 +11,13 @@ struct THPStorage { PyObject_HEAD - c10::MaybeOwned cdata; - bool is_hermetic; + c10::Storage cdata; }; TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage); TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage( PyTypeObject* type, - c10::Storage _storage, - bool allow_preexisting_pyobj = false); + c10::Storage _storage); TORCH_PYTHON_API extern PyTypeObject* THPStorageClass; inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) { @@ -49,7 +47,7 @@ TORCH_PYTHON_API void THPStorage_assertNotNull(PyObject* obj); TORCH_PYTHON_API extern PyTypeObject THPStorageType; inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) { - return *storage->cdata; + return storage->cdata; } inline const c10::Storage& THPStorage_Unpack(PyObject* obj) { diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index 68c06f7c88c1c..178f735802fb7 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -529,9 +529,8 @@ static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) { THPUtils_typename(new_cdata)); c10::StorageImpl* ptr = static_cast(PyLong_AsVoidPtr(new_cdata)); - self->cdata.~MaybeOwned(); - self->cdata = c10::MaybeOwned::owned( - c10::Storage(c10::intrusive_ptr::reclaim_copy(ptr))); + self->cdata = + c10::Storage(c10::intrusive_ptr::reclaim_copy(ptr)); Py_INCREF(self); return reinterpret_cast(self); END_HANDLE_TH_ERRORS diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h index 97e689d36050c..8f55f22ae4ad4 100644 --- a/torch/csrc/autograd/functions/accumulate_grad.h +++ b/torch/csrc/autograd/functions/accumulate_grad.h @@ -180,7 +180,9 @@ struct TORCH_API AccumulateGrad : public Node { if (!GradMode::is_enabled() && !new_grad.is_sparse() && !new_grad.is_sparse_csr() && !(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) && - at::caching::adjusted_use_count(new_grad) <= num_expected_refs && + impl::is_tensor_stealable( + new_grad, + num_expected_refs + at::caching::is_cached_tensor(new_grad)) && (new_grad.is_mkldnn() || utils::obeys_layout_contract(new_grad, variable))) { // See Case 1.1: Stealable dense new_grad @@ -193,7 +195,7 @@ struct TORCH_API AccumulateGrad : public Node { // SparseTensor should be the only one holding a reference to these. 
new_grad._indices().use_count() <= 1 && new_grad._values().use_count() <= 1 && - new_grad.use_count() <= num_expected_refs) { + impl::is_tensor_stealable(new_grad, num_expected_refs)) { // Case 1.2: Stealable sparse new_grad // No scenario where we expect this to be true currently TORCH_INTERNAL_ASSERT_DEBUG_ONLY( diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 62770ef946592..a477bf4c3e507 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -86,8 +86,8 @@ bool can_accumulate_inplace(const Variable& v) { v.is_non_overlapping_and_dense() && // and we hold the last reference - at::caching::adjusted_use_count(v) == 1 && v.has_storage() && - v.storage().use_count() == 1); + impl::is_tensor_stealable(v, 1 + at::caching::is_cached_tensor(v)) && + v.has_storage() && v.storage().use_count() == 1); } } // anonymous namespace diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 837ba93d1cc28..e109acdd1e302 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -50,6 +50,7 @@ using namespace at; using namespace torch; using namespace torch::autograd; +using torch::utils::PyObjectPreservation; namespace { class OperatorArgsKwargsView { @@ -317,20 +318,15 @@ PyObject* THPVariableClass = nullptr; PyObject* ParameterClass = nullptr; -static PyObject* THPVariable_NewWithVar( - PyTypeObject* type, - const at::TensorBase& _var, - bool allow_preexisting_pyobj = false, - std::optional has_torch_dispatch_if_known = std::nullopt); - // clang-tidy gets confused by static const static constexpr const char* VOLATILE_WARNING = "volatile was removed and now has no effect. Use " "`with torch.no_grad():` instead."; +static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls); + static bool check_has_torch_dispatch(PyObject* obj) { - PyTypeObject* tp = Py_TYPE(obj); - if (THPVariable_CheckTypeExact(tp)) { + if (THPVariable_CheckExact(obj)) { return false; } py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__"); @@ -366,152 +362,82 @@ void activateGPUTrace() { c10::impl::GPUTrace::set_trace(getPyInterpreter()); } -PyObject* THPVariable_Wrap(const at::TensorBase& var) { +static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) { + TORCH_CHECK( + PyObject_TypeCheck(obj, type), + "Creating a new Tensor subclass ", + type->tp_name, + " but the raw Tensor object is already associated to a python object ", + "of type ", + Py_TYPE(obj)->tp_name, + " which is not a subclass of the requested type"); +} + +// Generic for const Tensor& or Tensor&& +template +static PyObject* THPVariable_WrapWithType( + T&& var, + std::optional desired_type) { if (!var.defined()) { Py_RETURN_NONE; } - if (c10::impl::HermeticPyObjectTLS::get_state()) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); - } - - std::optional mb_obj = - var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - if (mb_obj.has_value()) { - auto obj = *mb_obj; - if (obj) { - if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) { - // C++ owns the Python object; this implies there weren't any other - // owning references to the Python object. 
Since we're making the - // object "live" again on Python side, let's flip back the ownership - // (Python owns C++) as it would now be unsound to deallocate the C++ - // object if all C++ references go to zero - var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false); - reinterpret_cast(obj)->cdata = - MaybeOwned::owned(Variable(var)); - // NB: incref is not necessary, because we are "stealing" the previous - // ownership from the Variable to return it here for the wrap - return obj; - } - Py_INCREF(obj); - return obj; - } - // TODO: a better invariant is that if we tagged, we MUST have a valid - // PyObject. That's PyObject preservation - // (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR - // being a thing, the PyObject field will get cleared when all references - // to the Python object are removed. - } + c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl(); + c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot(); - if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); + PyObject* obj = pyobj_slot->load_pyobj(); + if (obj) { + if (desired_type) { + check_tensor_subclass(obj, *desired_type); + } + return Py_NewRef(obj); } - if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, var); + PyTypeObject* type = reinterpret_cast(THPVariableClass); + if (desired_type) { + type = *desired_type; + } else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) { + if (auto clazz = getPythonTensorClass(var.device())) { + type = reinterpret_cast(clazz); + } } - return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); -} + obj = type->tp_alloc(type, 0); + TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); -static bool isResurrectable(THPVariable* self) { - // We want to divide this check into 2 cases. + // Ensure that PyUnstable_TryIncref calls don't fail spuriously in + // free-threaded Python. + PyUnstable_EnableTryIncRef(obj); - // 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is - // true). You might think that in this case, it is impossible for tp_clear to - // be called: surely the C++ reference to the PyObject is keeping it live? And - // you'd be right! In fact, when C++ owns the PyObject, we have an invariant - // that the refcount on the PyObject should be precisely one (because if you - // take out another reference to the PyObject, we're supposed to flip the - // ownership pointer back). In reality, you can violate this invariant - // temporarily with weak references, so we don't test for it in asserts. + auto v = reinterpret_cast(obj); + new (&v->cdata) Tensor(std::forward(var)); - // 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is - // false). In this case, tp_clear can get called if the PyObject is referenced - // from a dead cycle, and nowhere else. But if resurrection did not occur, - // then the reference to C++ from the PyObject must be the ONLY reference to - // the C++ object. - if (self->cdata.unsafeIsBorrowed()) { - return false; - } - auto const& tensor = THPVariable_Unpack(self); - if (!tensor.defined() || tensor.use_count() <= 1) { - return false; - } - // Check if this is hermetic. If it is, no resurrection. 
- if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false) != (PyObject*)self) { - return false; + if (THPVariable_Unpack(obj).is_uniquely_owned()) { + // We can use a faster non-atomic code path if we have the only reference to + // a fresh Tensor. + PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj); + return obj; } - return true; -} -// returns true if successfully rezzed; if so, cancel the -// rest of deallocation -static bool THPVariable_tryResurrect(THPVariable* self) { - const auto& tensor = THPVariable_Unpack(self); - - if (!isResurrectable(self)) { - return false; + PyObject* wrapper = + PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj); + if (wrapper != obj) { + // Another thread beat us to it + Py_DECREF(obj); + if (desired_type) { + check_tensor_subclass(wrapper, *desired_type); + } + return Py_NewRef(wrapper); } - - // At this point, we are definitely going to resurrect the tensor. So, the - // tensor better be defined :) - TORCH_INTERNAL_ASSERT(tensor.defined()); - - // There are other C++ owners of the tensor. Flip ownership - // so that C++ owns this Python object, and cancel deallocation. - TORCH_INTERNAL_ASSERT( - !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()); - - c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); - auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - - TORCH_INTERNAL_ASSERT( - maybe_pyobj.has_value(), - "Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject"); - - tensor_impl->pyobj_slot()->set_owns_pyobj(true); - - // Resurrect the Python object. This is something CPython does - // internally occasionally, see - // https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259 - // so we just copy the pattern here. Note that we don't have to worry - // about saving and restoring the refcount (as the quoted code does) - // because we actually DO need to reset the refcount to one here, we - // can't assume that some other code has taken care of it. - // NB: this will overreport _Py_RefTotal but based on inspection of object.c - // there is no way to avoid this - - // When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to - // ensure the PyObject is in a valid state - _Py_NewReference((PyObject*)self); - - // Flip THPVariable to be non-owning - // (near use-after-free miss here: fresh MaybeOwned is created breaking - // reference on Tensor in struct BEFORE we overwrite the old one) - TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state()); - self->cdata = MaybeOwned::borrowed(tensor); - - // NB: At this point, tensor *could* be dead (e.g., some other C++ thread - // decrefed it.) At this point, it is probably waiting on the GIL to - // deallocate the Python object and will kill self, BUT NOT YET. 
- - return true; + return obj; } -static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) { - TORCH_INTERNAL_ASSERT( - false, "TensorBase tp_traverse function was not overridden properly"); - return 0; +PyObject* THPVariable_Wrap(at::TensorBase&& var) { + return THPVariable_WrapWithType(std::move(var), std::nullopt); } -static int THPFake_clear(THPVariable* self) { - TORCH_INTERNAL_ASSERT( - false, "TensorBase tp_clear function was not overridden properly"); - return 0; +PyObject* THPVariable_Wrap(const at::TensorBase& var) { + return THPVariable_WrapWithType(var, std::nullopt); } static PyObject* THPVariable_pynew( @@ -673,16 +599,16 @@ static PyObject* THPVariable_as_subclass( ParsedArgs<1> parsed_args{}; auto r = parser.parse(_self, args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); // guard completely turns off torch dispatch modes, doesn't just pop off the // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; c10::impl::DisablePythonDispatcher dpd_g; - return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias()); + PyObject* obj = THPVariable_WrapWithType(self.alias(), (PyTypeObject*)cls); + if (check_has_torch_dispatch(obj)) { + THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); + } + return obj; END_HANDLE_TH_ERRORS } @@ -697,11 +623,7 @@ static PyObject* THPVariable_make_subclass( ParsedArgs<7> parsed_args{}; auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); // guard completely turns off torch dispatch modes, doesn't just pop off the // stack torch_dispatch_mode::StashTorchDispatchStackGuard td_g; @@ -734,7 +656,11 @@ static PyObject* THPVariable_make_subclass( data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6)); } - return THPVariable_NewWithVar((PyTypeObject*)cls, data); + PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls); + if (check_has_torch_dispatch(obj)) { + THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); + } + return obj; END_HANDLE_TH_ERRORS } @@ -831,11 +757,7 @@ static PyObject* THPVariable_make_wrapper_subclass( auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); // This is an important safety check; without it, the default behavior will be // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch @@ -873,6 +795,8 @@ static PyObject* THPVariable_make_wrapper_subclass( /*storage_size=*/r.toSymIntOptional(14), r.toDispatchKeySetOptional(13)); + tensor.unsafeGetTensorImpl()->set_python_dispatch(true); + const auto sizes_strides_policy = r.stringViewOptional(10); if (sizes_strides_policy.has_value()) { tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides( @@ -888,13 +812,7 @@ static PyObject* THPVariable_make_wrapper_subclass( tensor.unsafeGetTensorImpl()->set_python_custom_layout(true); } - return THPVariable_NewWithVar( - (PyTypeObject*)cls, - tensor, - // false is the default - /*allow_preexisting_pyobj=*/false, - // we checked __torch_dispatch__ above; avoid checking again. 
- /*has_torch_dispatch_if_known=*/true); + return THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls); END_HANDLE_TH_ERRORS } @@ -1024,11 +942,7 @@ static PyObject* THPVariable_dtensor_new( auto r = parser.parse(args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); - TORCH_CHECK_TYPE( - PyType_Check(cls), - "cls must be a type (got ", - Py_TYPE(cls)->tp_name, - ")"); + TORCH_CHECK_TENSOR_SUBTYPE(cls); #ifndef NDEBUG // This is specifically for making a DTensor, which we know defines @@ -1081,14 +995,9 @@ static PyObject* THPVariable_dtensor_new( /*storage_size=*/std::nullopt, extra_dispatch_keys); tensor.set_requires_grad(requires_grad); - py::object py_tensor = - py::reinterpret_steal(THPVariable_NewWithVar( - (PyTypeObject*)cls, - tensor, - // false is the default - /*allow_preexisting_pyobj=*/false, - // we know DTensor has __torch_dispatch__; avoid checking again. - /*has_torch_dispatch_if_known=*/true)); + tensor.unsafeGetTensorImpl()->set_python_dispatch(true); + py::object py_tensor = py::reinterpret_steal( + THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls)); py_tensor.attr(dtensor_interned_strings._spec) = spec; py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor; return py_tensor.release().ptr(); @@ -2381,15 +2290,16 @@ static PyTypeObject THPVariableMetaType = { nullptr, /* tp_new */ }; +static void THPVariable_dealloc(PyObject* self); +static int THPVariable_clear(THPVariable* self); +static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg); + static PyTypeObject THPVariableType = { PyVarObject_HEAD_INIT(&THPVariableMetaType, 0) "torch._C.TensorBase", /* tp_name */ sizeof(THPVariable), /* tp_basicsize */ 0, /* tp_itemsize */ - // This is unspecified, because it is illegal to create a THPVariableType - // directly. Subclasses will have their tp_dealloc set appropriately - // by the metaclass - nullptr, /* tp_dealloc */ + THPVariable_dealloc, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ @@ -2408,9 +2318,8 @@ static PyTypeObject THPVariableType = { Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ nullptr, /* tp_doc */ - // Also set by metaclass - (traverseproc)THPFake_traverse, /* tp_traverse */ - (inquiry)THPFake_clear, /* tp_clear */ + (traverseproc)THPVariable_traverse, /* tp_traverse */ + (inquiry)THPVariable_clear, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ @@ -2439,345 +2348,68 @@ PyObject* THPVariable_pynew( type != &THPVariableType, "Cannot directly construct TensorBase; subclass it and then construct that"); jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR); - auto tensor = torch::utils::base_tensor_ctor(args, kwargs); // WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was // given a raw pointer that will refcount bump // NB: base_tensor_ctor can call into dispatched ATen functions (e.g., // alias(), lift_fresh()) which can return Tensor subclasses. We allow // these to be passed on directly. 
- return THPVariable_NewWithVar( - type, - tensor, - /*allow_preexisting_pyobj=*/true); + PyObject* obj = THPVariable_WrapWithType( + torch::utils::base_tensor_ctor(args, kwargs), type); + if (check_has_torch_dispatch(obj)) { + THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true); + } + return obj; END_HANDLE_TH_ERRORS } -static int THPVariable_subclass_clear(THPVariable* self) { - // Is it OK for an object to still be live after running - // tp_clear? Yes. When Python is breaking reference cycles, it can't assume - // that an object will dealloc after it's cleared. The source code explicitly - // handles this case: - // https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025 - - // Note that we don't need to actually resurrect here. There are 2 cases: - // 1. The PyObject is not part of a reference cycle. In this case, we don't - // need to do anything. The GC will move on to try and break the reference - // cycle on another object, which will eventually trigger tp_dealloc (and thus - // resurrection). - - // 2. The PyObject is part of a reference cycle. This case should not actually - // be possible, due to the logic in our tp_traverse - // (THPVariable_subclass_traverse). - - // In fact, resurrecting here breaks the invariant that "C++ owns Python only - // when PyObject's refcount would otherwise be 0". Most immediately, as we're - // merely breaking reference cycles here, there can be other references to the - // PyObject. *However*, if other objects in the refcycle resurrect, then we - // will be in a state where the PyObject has multiple Python references, yet - // C++ owns the PyObject. - - // See https://github.com/pytorch/pytorch/pull/75933 for more discussion. - if (isResurrectable(self)) { - return 0; - } - +static int THPVariable_clear(THPVariable* self) { // First clear Tensor specific things - Py_CLEAR(self->backward_hooks); Py_CLEAR(self->post_accumulate_grad_hooks); - const auto& tensor = THPVariable_Unpack(self); - if (tensor.defined()) { - // Two situations to consider: - // PyObject -owns-> Tensor - // unsafeIsBorrowed() is FALSE. We're obligated to look through - // Tensor to break references. Clearing cdata must induce the - // destruction of the C++ Tensor. If there were other references - // to C++ tensor, the Python object would have been resurrected - // by flipping the ownership. - // Tensor -owns-> PyObject - // unsafeIsBorrowed() is TRUE. We're deallocating the PyObject - // because Tensor asked us to (it's already destructing). - - if (!self->cdata.unsafeIsBorrowed() && - tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false) == (PyObject*)self) { - // TODO: empirically, on OS X this assert appears to be untrue - // In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn - // distributed/rpc/test_process_group_agent.py - // - // libc++abi.dylib: terminating with uncaught exception of type - // c10::Error: - // !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL - // ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171, - // please report a bug to PyTorch. 
Exception raised from - // THPVariable_subclass_clear at - // ../torch/csrc/autograd/python_variable.cpp:171 (most recent call - // first): frame #0: c10::Error::Error(c10::SourceLocation, - // std::__1::basic_string, - // std::__1::allocator >) + 98 (0x1158a0442 in libc10.dylib) frame - // #1: c10::detail::torchCheckFail(char const*, char const*, unsigned - // int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2: - // c10::detail::torchInternalAssertFail(char const*, char const*, - // unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9 - // (0x1141e3f89 in libtorch_python.dylib) frame #3: - // THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in - // libtorch_python.dylib) frame #4: - // THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in - // libtorch_python.dylib) frame #5: (anonymous - // namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*, - // _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6: - // c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in - // libc10.dylib) frame #7: - // c10::MaybeOwned::operator=(c10::MaybeOwned&&) - // + 91 (0x11488c11b in libtorch_python.dylib) frame #8: - // THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in - // libtorch_python.dylib) frame #47: start + 1 - // (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???) - // TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()); - if (auto grad_acc = - torch::autograd::impl::try_get_grad_accumulator(tensor)) { - grad_acc->pre_hooks().clear(); - grad_acc->tensor_pre_hooks().clear(); - grad_acc->retains_grad_hooks().clear(); - } + if (self->cdata.defined()) { + auto pyobj_slot = self->cdata.unsafeGetTensorImpl()->pyobj_slot(); + // Typically the Tensor's pyobj_slot points back to this object. The only + // time that's not the case is if we had a race in THPVariable_Wrap and we + // need to discard the Python object because some other thread beat us to + // setting the pyobj_slot. + if (pyobj_slot->load_pyobj() == (PyObject*)self) { + // A Tensor's Python object should only be destroyed when the Tensor has + // no other references too. + TORCH_INTERNAL_ASSERT(self->cdata.use_count() == 1); + + // Clear the pyobj_slot so that a try_incref() call from + // weak_intrusive_ptr::lock() won't see a freed pointer. + pyobj_slot->clear(); } } - TORCH_INTERNAL_ASSERT(!isResurrectable(self)); { // MapAllocator can take significant time to release large tensors; // release the GIL here to avoid impacting main thread perf. pybind11::gil_scoped_release no_gil; - self->cdata = MaybeOwned(); + self->cdata = Variable(); } - // Since we override the basic subtype_clear from CPython, we need a crappy - // version here just like for traverse and dealloc - - // Clear all slots until we get to the base Tensor class - PyTypeObject* type = Py_TYPE((PyObject*)self); - PyTypeObject* base = type; - while (base != &THPVariableType) { - if (Py_SIZE(base)) - clear_slots(base, (PyObject*)self); - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - - // Assume we never have managed dict for Tensors as we don't set the flag on - // the base class - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr((PyObject*)self); - if (dictptr && *dictptr) - Py_CLEAR(*dictptr); - } - return 0; } -// NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc -// on subclasses. 
It's never valid to construct a THPVariable so it's not -// necessary to implement the dealloc for that case -static void THPVariable_subclass_dealloc(PyObject* self) { - if (THPVariable_tryResurrect((THPVariable*)self)) - return; - - // This is like a crappy version of subtype_dealloc. - // Unfortunately, we cannot directly delegate to - // subtype_dealloc as it will start walking the parent - // chain *starting with* the type of self, which will cause - // us to go back to our custom dealloc. - // - // We have to replicate the subtype_dealloc logic to ensure - // that finalizers are handled correctly - PyTypeObject* type = Py_TYPE(self); - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented"); - +static void THPVariable_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); - // TODO: consider using trash can - - bool has_finalizer = type->tp_finalize || type->tp_del; - - if (type->tp_finalize) { - PyObject_GC_Track(self); - if (PyObject_CallFinalizerFromDealloc(self) < 0) { - /* Resurrected */ - return; - } - PyObject_GC_UnTrack(self); - } - - // base test is unnecessary as THPVariable does not set this - if (type->tp_weaklistoffset) { - PyObject_ClearWeakRefs(self); - } - - if (type->tp_del) { - PyObject_GC_Track(self); - type->tp_del(self); - if (Py_REFCNT(self) > 0) { - /* Resurrected */ - return; - } - PyObject_GC_UnTrack(self); - } - - if (has_finalizer) { - /* New weakrefs could be created during the finalizer call. - If this occurs, clear them out without calling their - finalizers since they might rely on part of the object - being finalized that has already been destroyed. */ - if (type->tp_weaklistoffset) { - /* Modeled after GET_WEAKREFS_LISTPTR() */ - PyWeakReference** list = - (PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self); - while (*list) - _PyWeakref_ClearRef(*list); - } - } - - // Clear all slots until we get to base class THPVariableType - { - PyTypeObject* base = type; - while (base != &THPVariableType) { - if (Py_SIZE(base)) { - clear_slots(base, self); - } - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - } - - // All Python defined classes have __dict__ - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr(self); - if (dictptr != nullptr) { - PyObject* dict = *dictptr; - if (dict != nullptr) { - Py_DECREF(dict); - *dictptr = nullptr; - } - } - } - - // subtype_dealloc allows for this but we don't - TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type); - - // Finally clear out the base THPVariable - THPVariable_subclass_clear((THPVariable*)self); - ((THPVariable*)self)->cdata.~MaybeOwned(); + THPVariable_clear((THPVariable*)self); + ((THPVariable*)self)->cdata.~Variable(); Py_TYPE(self)->tp_free(self); - - // Python defined subclasses should always be on the heap - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - Py_DECREF(type); } -// Creates a new Python object for a Variable. -static PyObject* THPVariable_NewWithVar( - PyTypeObject* type, - const at::TensorBase& _var, - bool allow_preexisting_pyobj, - std::optional has_torch_dispatch_if_known) { - // Make sure that the reinterpret into a THPVariable* will be valid - TORCH_CHECK( - type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType), - "Creating a Tensor subclass from a class ", - "that does not inherit from Tensor is not possible. 
Make sure your class inherits from Tensor."); - - // This function overwrite the Tensor's pyobj field without extra checks - // Make sure it is not set otherwise we would leak memory - auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj( - /*ignore_hermetic_tls=*/false); - - // Under some circumstances, we may attempt to create a new Python - // object for a variable that already has a Python object. The most common - // situation this can occur is if you have a TorchDispatchMode active that - // is returning a subclass from lift_fresh (which is invoked to - // appropriately "wrap" a constant tensor into whatever ambient modes are - // active.) - // - // In general, it is impossible to handle this case compositionally. - // Suppose you have a user call ATensor([1, 2, 3]) when a mode is active - // that is transforming all ops (including the internal lift_fresh call that - // transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output - // BTensor, where ATensor and BTensor are completely unrelated subclasses - // and there is no way to compose them. There is no way to satisfy the user - // request here: in particular, you can't just try to re-invoke the ATensor - // constructor on the returned BTensor, because (1) this could cause an - // infinite loop--we are already in ATensor.__new__ and (2) there isn't any - // guarantee that ATensor.__new__ supports a single element constructor - // anyway. - // - // However, a more common case is a user just called torch.Tensor([1, 2, 3]), - // and a fake tensor mode is active. Really, all you want is to get back - // a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3) - // would have returned a fake tensor (concretely, the way this happens - // is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it - // turns into a FakeTensor when we call lift_fresh on this real tensor). - // This case is compositional because FakeTensor is a subclass of Tensor, so - // it's valid for us to return it in place of a Tensor. So this is what we - // do. - - if (mb_obj.has_value() && mb_obj.value()) { - TORCH_CHECK( - allow_preexisting_pyobj, - "Creating a new Tensor subclass ", - type->tp_name, - " but the raw Tensor object is already associated to a python object ", - "of type ", - mb_obj.value()->ob_type->tp_name); - // Even if we allow pre-existing PyObject, we don't allow completely - // ignoring the requested type. Check that we fulfilled a subtype - // relation here. In the common case the requested type is Tensor and - // this always succeeds. - PyObject* obj = *mb_obj; - // Check if it's OK to just directly return the Python object without - // allocating a new variable. We just check that the existing Python - // object is a subclass of the requested type. 
- PyTypeObject* obj_type = Py_TYPE(obj); - TORCH_CHECK( - obj_type == type || PyType_IsSubtype(obj_type, type), - "Creating a new Tensor subclass ", - type->tp_name, - " but the raw Tensor object is already associated to a python object ", - "of type ", - mb_obj.value()->ob_type->tp_name, - " which is not a subclass of the " - "requested type"); - // We may (in fact, we typically will) need to resurrect this - return THPVariable_Wrap(_var); - } - - PyObject* obj = type->tp_alloc(type, 0); - if (obj) { - auto v = (THPVariable*)obj; - // TODO: named constructor to avoid default initialization - new (&v->cdata) MaybeOwned(); - if (c10::impl::HermeticPyObjectTLS::get_state()) { - // Do NOT initialize pyobj field on the tensor, you own the C++ - v->cdata = MaybeOwned::owned(Variable(_var)); - TORCH_INTERNAL_ASSERT( - !check_has_torch_dispatch(obj), - "While HermeticPyObject was enabled, we attempted to create a tensor " - "subclass with __torch_dispatch__. This violates the invariant that " - "operations in HermeticPyObject have equivalent C++ implementations. " - "If your operator registered from Python operator registration isn't " - "doing anything strange, there may be an internal PyTorch bug involving " - "not appropriately disabling TorchDispatchMode before executing " - "Python op registration."); - } else { - // Normal codepath - v->cdata = MaybeOwned::owned(Variable(_var)); - const auto& var = THPVariable_Unpack(v); - var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); - if (has_torch_dispatch_if_known.has_value() - ? *has_torch_dispatch_if_known - : check_has_torch_dispatch(obj)) { - var.unsafeGetTensorImpl()->set_python_dispatch(true); - } - } - } - return obj; +static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) { + TORCH_CHECK_TYPE( + PyType_Check(cls), + "cls must be a type (got ", + Py_TYPE(cls)->tp_name, + ")"); + PyTypeObject* type = reinterpret_cast(cls); + TORCH_CHECK_TYPE( + type == &THPVariableType || cls == THPVariableClass || + PyType_IsSubtype(type, &THPVariableType), + "Creating a Tensor subclass from a class that does not inherit from " + "Tensor is not possible. Make sure your class inherits from Tensor."); } /// NOTE [ PyObject Traversal ] @@ -2796,7 +2428,7 @@ static PyObject* THPVariable_NewWithVar( /// into account these C++ ownership links. /// /// The main danger here comes from the fact that, while all python-related code -/// is thread safe wrt the GC execution (thanks to the GIL), other threads might +/// is thread safe wrt the GC execution, other threads might /// be using our C++ objects arbitrarily which can lead to shared_ptr ref count /// going up or down in between the different traverse/clear invocations. 
The /// one constraint we add here that is not explicitly mentioned in the GC @@ -2826,124 +2458,46 @@ static PyObject* THPVariable_NewWithVar( /// https://github.com/pytorch/pytorch/issues/7343 /// -static int traverse_slots( - PyTypeObject* type, - PyObject* self, - visitproc visit, - void* arg) { - auto n = Py_SIZE(type); - auto mp = type->tp_members; - for (Py_ssize_t i = 0; i < n; i++, mp++) { - if (mp->type == T_OBJECT_EX) { - char* addr = (char*)self + mp->offset; - PyObject* obj = *(PyObject**)addr; - if (obj != nullptr) { - int err = visit(obj, arg); - if (err) - return err; - } - } - } - return 0; -} - -static int THPVariable_subclass_traverse( - PyObject* self, - visitproc visit, - void* arg) { - // If the tensor is eligible to be resurrected, don't traverse it; instead - // treat all of its references as a root (as they WOULD be a root since we - // can treat the inbound C++ references as root owners). - // - // This works because unlike conventional GCs, Python's GC operates in two - // phases: first it uses traverse to discover roots, and then it uses traverse - // to do reachability. Bypassing traverse during root discovery forces Python - // to treat self as a root for everything it refers to. For a full - // explanation of the algorithm see - // https://devguide.python.org/garbage_collector/ - // - // NB: if we don't hold an owning reference to the underlying Tensor, it is - // possible that the underlying Tensor has already gone dead. In that case, - // it's not safe to access it. But it's also safe to traverse, because if - // the underlying Tensor *is* live, then root discovery will determine that - // self is live, and nothing will get GC'ed anyway (resurrection cannot happen - // if the C++ objects owns the PyObject) +static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg) { THPVariable* var = reinterpret_cast(self); - if (isResurrectable(var)) { - return 0; - } - - // Crappy version of subtype_traverse; same deal as - // THPVariable_subclass_dealloc - - PyTypeObject* type = Py_TYPE(self); - // Traverse slots until we get to base class THPVariableType - { - PyTypeObject* base = type; - while (base != &THPVariableType) { - if (Py_SIZE(base)) { - int err = traverse_slots(base, self, visit, arg); - if (err) - return err; - } - base = base->tp_base; - TORCH_INTERNAL_ASSERT(base); - } - } - - // All Python defined classes have __dict__ - if (C10_LIKELY(type->tp_dictoffset)) { - PyObject** dictptr = _PyObject_GetDictPtr(self); - if (dictptr && *dictptr) - Py_VISIT(*dictptr); - } - - TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE); - Py_VISIT(type); - - // Finally traverse THPVariable special stuff Py_VISIT(var->backward_hooks); Py_VISIT(var->post_accumulate_grad_hooks); - if (!var->cdata.unsafeIsBorrowed()) { - const auto& tensor = THPVariable_Unpack(var); - if (tensor.defined()) { - // WARNING: The grad_fn traversal logic is very subtle, if you change - // this, be very careful not to re-introduce this bug: - // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c - - // We ensure that we follow NOTE [ PyObject Traversal ] he by checking - // that this python object is the sole owner of the underlying Tensor and - // that this Tensor is the sole owner of its grad_fn. In this case, the - // only way to get a new reference to the grad_fn is by using this python - // object, which requires the GIL to be accessed. 
Note that this is only - // valid as long as user don't share non-owning references across - // different threads (which is crazy and should never be done). - auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor); - if (tensor.use_count() == 1) { - if (autograd_meta) { - // Do NOT call grad_fn() here as that might trigger a recompute - const auto& grad_fn = autograd_meta->grad_fn_; - if (grad_fn && grad_fn.use_count() == 1) { - // All Node can have a pyobj (stored in "pyobj_") - Py_VISIT(grad_fn->pyobj()); - // PyNode are special as they also have an "obj" field - if (auto py_node_fn = dynamic_cast(grad_fn.get())) { - Py_VISIT(py_node_fn->obj); - } + const auto& tensor = THPVariable_Unpack(var); + if (tensor.defined()) { + // WARNING: The grad_fn traversal logic is very subtle, if you change + // this, be very careful not to re-introduce this bug: + // https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c + + // We ensure that we follow NOTE [ PyObject Traversal ] he by checking + // that this python object is the sole owner of the underlying Tensor and + // that this Tensor is the sole owner of its grad_fn. In this case, the + // only way to get a new reference to the grad_fn is by using this python + // object, which requires the GIL to be accessed. Note that this is only + // valid as long as user don't share non-owning references across + // different threads (which is crazy and should never be done). + auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor); + if (tensor.use_count() == 1) { + if (autograd_meta) { + // Do NOT call grad_fn() here as that might trigger a recompute + const auto& grad_fn = autograd_meta->grad_fn_; + if (grad_fn && grad_fn.use_count() == 1) { + // All Node can have a pyobj (stored in "pyobj_") + Py_VISIT(grad_fn->pyobj()); + // PyNode are special as they also have an "obj" field + if (auto py_node_fn = dynamic_cast(grad_fn.get())) { + Py_VISIT(py_node_fn->obj); } } } - if (autograd_meta) { - for (const auto& hook : torch::autograd::impl::hooks(tensor)) { - if (auto pyhook = - dynamic_cast(hook.get())) { - Py_VISIT(pyhook->dict); - } + } + if (autograd_meta) { + for (const auto& hook : torch::autograd::impl::hooks(tensor)) { + if (auto pyhook = dynamic_cast(hook.get())) { + Py_VISIT(pyhook->dict); } } } } - return 0; } @@ -2951,17 +2505,6 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { if (PyType_Type.tp_init(cls, args, kwargs) < 0) { return -1; } - // It is important for all three of these to be overridden correctly for the - // resurrection checks to properly happen. In particular, an older version - // was not overriding tp_clear here. This lead to the default subtype_clear - // running on the Tensor object (as only TensorBase tp_clear was custom), - // clearing the __dict__ field, before the TensorBase custom clear was called - // and would properly detect the resurrect. 
- // See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior - ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc; - ((PyTypeObject*)cls)->tp_traverse = - (traverseproc)THPVariable_subclass_traverse; - ((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear; // Don't do anything for the base Tensor class if (!THPVariableClass) { diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 82939211eb50a..4abb080b4160c 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -17,7 +17,7 @@ namespace py = pybind11; struct THPVariable { PyObject_HEAD // Payload - c10::MaybeOwned cdata; + at::Tensor cdata; // Hooks to be run on backwards pass (corresponds to Python attr // '_backwards_hooks', set by 'register_hook') PyObject* backward_hooks = nullptr; @@ -37,6 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass; TORCH_PYTHON_API extern PyObject* ParameterClass; bool THPVariable_initModule(PyObject* module); +TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase&& var); TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var); inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { @@ -69,7 +70,7 @@ inline bool THPVariable_Check(PyObject* obj) { } inline const at::Tensor& THPVariable_Unpack(THPVariable* var) { - return *var->cdata; + return var->cdata; } inline const at::Tensor& THPVariable_Unpack(PyObject* obj) { diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index ed97dc4530eb4..00bdb91c36867 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -65,7 +65,9 @@ inline at::Tensor clone_obey_contract( .new_empty_strided_symint( variable.sym_sizes(), variable.sym_strides(), - variable.options().memory_format(std::nullopt)) + variable.options() + .memory_format(std::nullopt) + .dtype(new_grad.dtype())) .copy_(new_grad)); } else { // (2) diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index 6e0494df5cf47..616b0fa0331bc 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -70,6 +70,10 @@ inline PyObject* wrap(const at::Tensor& tensor) { return THPVariable_Wrap(tensor); } +inline PyObject* wrap(at::Tensor&& tensor) { + return THPVariable_Wrap(std::move(tensor)); +} + inline PyObject* wrap(const at::Scalar& scalar) { return wrap(scalar_to_tensor(scalar)); } diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index a297a9f5ef425..05dbfdaa44325 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -197,6 +197,22 @@ TORCH_API std::unique_ptr& post_acc_grad_hooks( TORCH_API void create_cpp_hook( const at::TensorBase& /*self*/, bool is_retains_grad_hooks = false); + +inline bool is_tensor_stealable( + const at::Tensor& new_grad, + size_t num_expected_refs = 1) { + size_t use_count = new_grad.use_count(); + if (use_count <= num_expected_refs) { + return true; + } + if (use_count >= 2 && + new_grad.unsafeGetTensorImpl()->pyobj_slot()->has_unique_reference()) { + // The Python wrapper, if it exists, also has a reference to the Tensor. 
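    // (Editorial sketch of the accounting, with the default
    //  num_expected_refs == 1: a Tensor whose use_count is 1 is stealable
    //  outright; a Tensor whose use_count is 2 where the only extra owner is
    //  its Python wrapper is also treated as stealable, because the wrapper's
    //  reference is added to num_expected_refs before the final comparison.)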
+ num_expected_refs++; + } + return use_count <= num_expected_refs; +} + } // namespace impl //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -894,7 +910,7 @@ inline Variable make_variable( bool requires_grad = false, bool allow_tensor_metadata_change = true) { if (data.defined()) { - if (data.getIntrusivePtr().use_count() == 1 && + if (impl::is_tensor_stealable(data) && data.getIntrusivePtr()->unique_version()) { auto data_impl = data.unsafeReleaseIntrusivePtr(); data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); diff --git a/torch/csrc/utils/pyobject_preservation.cpp b/torch/csrc/utils/pyobject_preservation.cpp index 4f2d0a2507011..a652cbdb7aefd 100644 --- a/torch/csrc/utils/pyobject_preservation.cpp +++ b/torch/csrc/utils/pyobject_preservation.cpp @@ -1,19 +1,67 @@ #include -#include - -void clear_slots(PyTypeObject* type, PyObject* self) { - Py_ssize_t n = Py_SIZE(type); - PyMemberDef* mp = type->tp_members; - - for (Py_ssize_t i = 0; i < n; i++, mp++) { - if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) { - char* addr = (char*)self + mp->offset; - PyObject* obj = *(PyObject**)addr; - if (obj != nullptr) { - *(PyObject**)addr = nullptr; - Py_DECREF(obj); - } +#include +#include + +namespace torch::utils { + +using c10::intrusive_ptr_target; +using c10::impl::PyObjectSlot; + +void PyObjectPreservation::init_fresh_nonatomic( + intrusive_ptr_target* target, + PyObjectSlot* slot, + PyObject* pyobj) { + TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr); + TORCH_INTERNAL_ASSERT( + target->combined_refcount_.load(std::memory_order_relaxed) == + c10::detail::kUniqueRef); + + slot->pyobj_.store(pyobj, std::memory_order_relaxed); + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed); + target->combined_refcount_.store( + c10::detail::kHasPyObject | c10::detail::kUniqueRef, + std::memory_order_relaxed); +} + +PyObject* PyObjectPreservation::init_once( + intrusive_ptr_target* target, + PyObjectSlot* slot, + PyObject* pyobj) { + PyObject* expected = nullptr; + if (!slot->pyobj_.compare_exchange_strong( + expected, pyobj, std::memory_order_acq_rel)) { + TORCH_INTERNAL_ASSERT(expected != nullptr); + return expected; + } + + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_release); + + bool increfed = false; + auto combined = target->combined_refcount_.load(std::memory_order_relaxed); + do { + TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined)); + if (c10::detail::refcount(combined) > 1 && !increfed) { + // We need to incref the object to preserve the invariant that + // if refcount > 1, the c10 object holds a reference to the PyObject. + // This must happen before we set the kHasPyObject bit. 
+ Py_INCREF(pyobj); + increfed = true; } + } while (!target->combined_refcount_.compare_exchange_weak( + combined, + combined | c10::detail::kHasPyObject, + std::memory_order_acq_rel, + std::memory_order_relaxed)); + + if (increfed && c10::detail::refcount(combined) == 1) { + // Fix up if refcount if we did the incref in a failed compare-exchange + Py_DECREF(pyobj); } + + return pyobj; } + +} // namespace torch::utils diff --git a/torch/csrc/utils/pyobject_preservation.h b/torch/csrc/utils/pyobject_preservation.h index 456095d7b7037..b060bc034b2c3 100644 --- a/torch/csrc/utils/pyobject_preservation.h +++ b/torch/csrc/utils/pyobject_preservation.h @@ -4,4 +4,28 @@ // This file contains utilities used for handling PyObject preservation -void clear_slots(PyTypeObject* type, PyObject* self); +namespace c10 { +class intrusive_ptr_target; +namespace impl { +struct PyObjectSlot; +} // namespace impl +} // namespace c10 + +namespace torch::utils { + +class PyObjectPreservation { + public: + // Store a PyObject wrapper on a fresh c10 wrapper. The caller must hold + // a unique reference to `target`. + static void init_fresh_nonatomic( + c10::intrusive_ptr_target* target, + c10::impl::PyObjectSlot* slot, + PyObject* pyobj); + + static PyObject* init_once( + c10::intrusive_ptr_target* target, + c10::impl::PyObjectSlot* slot, + PyObject* pyobj); +}; + +} // namespace torch::utils From cda7604434c8fbfe849b1ff612cc286b7f1e5ff1 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Mon, 10 Nov 2025 11:16:31 -0800 Subject: [PATCH 304/651] [ez] Remove spammy deprecation log (#167470) " /packages/pytorch_latest_sixlib_conda/conda/lib/python3.12/site-packages/torch/_dynamo/variables/user_defined.py:1815: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead. 
return ctor(*args, **kwargs)" is too spammy Pull Request resolved: https://github.com/pytorch/pytorch/pull/167470 Approved by: https://github.com/tugsbayasgalan --- torch/_dynamo/variables/user_defined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 7709850d22d8b..ec378a5512a01 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1838,7 +1838,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # Handle specific pytree classes import torch.utils._pytree as pytree - if self.value_type is pytree.LeafSpec: + if isinstance(self.value, pytree.TreeSpec) and self.value.is_leaf(): # Create a new LeafSpec instance by calling the constructor codegen.add_push_null( lambda: codegen.load_import_from("torch.utils._pytree", "LeafSpec") From 7da82b84e28d44fd5d7aad7ec4a9213be0bb6dc7 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Nov 2025 10:22:03 -0800 Subject: [PATCH 305/651] Add torch::stable::Device (#166579) Prior to this PR, the IValue <-> StableIValue conversion for `DeviceObjType` (aka c10::Device) was to pack it into the leading bits of the StableIValue (which is a uint64_t) After this PR, the IValue <-> StableIValue conversion for `DeviceObjType` expects DeviceType to be packed into the upper 32 bits of StableIValue and DeviceIndex to be packed into the lower 32 bits Pull Request resolved: https://github.com/pytorch/pytorch/pull/166579 Approved by: https://github.com/janeyx99 --- .../libtorch_agnostic/csrc/kernel.cpp | 126 ++++++++++++++++++ .../libtorch_agnostic/ops.py | 78 +++++++++++ .../test/test_libtorch_agnostic.py | 55 ++++++++ torch/csrc/shim_common.cpp | 45 ++++++- torch/csrc/stable/c/shim.h | 7 + torch/csrc/stable/device.h | 4 + torch/csrc/stable/device_inl.h | 41 ++++++ torch/csrc/stable/device_struct.h | 107 +++++++++++++++ torch/csrc/stable/stableivalue_conversions.h | 103 ++++++++++++++ torch/headeronly/core/DeviceType.h | 2 + 10 files changed, 564 insertions(+), 4 deletions(-) create mode 100644 torch/csrc/stable/device.h create mode 100644 torch/csrc/stable/device_inl.h create mode 100644 torch/csrc/stable/device_struct.h diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 92a4af8b72733..b1c74a4b0f988 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -528,6 +529,131 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach); } +// Test functions for torch::stable::Device + +torch::stable::Device test_device_constructor( + bool is_cuda, + torch::stable::DeviceIndex index, + bool use_str) { + using torch::stable::Device; + using torch::stable::DeviceType; + + if (use_str) { + std::string device_str; + if (is_cuda) { + device_str = "cuda:" + std::to_string(index); + } else { + device_str = "cpu"; + } + return Device(device_str); + } else { + if (is_cuda) { + return Device(DeviceType::CUDA, index); + } else { + return Device(DeviceType::CPU); + } + } +} + +void boxed_test_device_constructor( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + 
torch::stable::Device res = test_device_constructor( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1]), + torch::stable::detail::to(stack[2])); + stack[0] = torch::stable::detail::from(res); +} + +bool test_device_equality(torch::stable::Device d1, torch::stable::Device d2) { + return d1 == d2; +} + +void boxed_test_device_equality( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + bool res = test_device_equality( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1])); + stack[0] = torch::stable::detail::from(res); +} + +torch::stable::Device test_device_set_index( + torch::stable::Device device, + torch::stable::DeviceIndex index) { + device.set_index(index); + return device; +} + +void boxed_test_device_set_index( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + torch::stable::Device res = test_device_set_index( + torch::stable::detail::to(stack[0]), + torch::stable::detail::to(stack[1])); + stack[0] = torch::stable::detail::from(res); +} + +torch::stable::DeviceIndex test_device_index(torch::stable::Device device) { + return device.index(); +} + +void boxed_test_device_index( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + torch::stable::DeviceIndex res = test_device_index( + torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); +} + +bool test_device_is_cuda(torch::stable::Device device) { + return device.is_cuda(); +} + +void boxed_test_device_is_cuda( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + bool res = test_device_is_cuda( + torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); +} + +bool test_device_is_cpu(torch::stable::Device device) { + return device.is_cpu(); +} + +void boxed_test_device_is_cpu( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + bool res = test_device_is_cpu( + torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device"); + m.def("test_device_equality(Device d1, Device d2) -> bool"); + m.def("test_device_set_index(Device device, DeviceIndex index) -> Device"); + m.def("test_device_index(Device device) -> DeviceIndex"); + m.def("test_device_is_cuda(Device device) -> bool"); + m.def("test_device_is_cpu(Device device) -> bool"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("test_device_constructor", &boxed_test_device_constructor); + m.impl("test_device_equality", &boxed_test_device_equality); + m.impl("test_device_set_index", &boxed_test_device_set_index); + m.impl("test_device_index", &boxed_test_device_index); + m.impl("test_device_is_cuda", &boxed_test_device_is_cuda); + m.impl("test_device_is_cpu", &boxed_test_device_is_cpu); +} + // Test functions for torch::stable::accelerator APIs #ifdef LAE_USE_CUDA diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index e0e5cef216375..38eba12bb4690 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -375,3 +375,81 @@ def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]: return 
torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default( t1, t2 ) + + +def test_device_constructor(is_cuda, index, use_str): + """ + Tests creating a Device from DeviceType and index, or from a string. + + Args: + is_cuda: bool - if True, creates CUDA device; if False, creates CPU device + index: int - device index + use_str: bool - if True, constructs from string; if False, constructs from DeviceType + + Returns: Device - A device with the specified type and index + """ + return torch.ops.libtorch_agnostic.test_device_constructor.default( + is_cuda, index, use_str + ) + + +def test_device_equality(d1, d2) -> bool: + """ + Tests Device equality operator. + + Args: + d1: Device - first device + d2: Device - second device + + Returns: bool - True if devices are equal + """ + return torch.ops.libtorch_agnostic.test_device_equality.default(d1, d2) + + +def test_device_set_index(device, index): + """ + Tests Device set_index() method. + + Args: + device: Device - device to modify + index: int - new device index + + Returns: Device - device with updated index + """ + return torch.ops.libtorch_agnostic.test_device_set_index.default(device, index) + + +def test_device_index(device) -> int: + """ + Tests Device index() method. + + Args: + device: Device - device to query + + Returns: int - device index + """ + return torch.ops.libtorch_agnostic.test_device_index.default(device) + + +def test_device_is_cuda(device) -> bool: + """ + Tests Device is_cuda() method. + + Args: + device: Device - device to check + + Returns: bool - True if device is CUDA + """ + return torch.ops.libtorch_agnostic.test_device_is_cuda.default(device) + + +def test_device_is_cpu(device) -> bool: + """ + Tests Device is_cpu() method. + + Args: + device: Device - device to check + + Returns: bool - True if device is CPU + """ + return torch.ops.libtorch_agnostic.test_device_is_cpu.default(device) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index e94c740861a11..ce5072b968d19 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -418,6 +418,61 @@ def test_make_tensor_clones_and_call_foreach(self, device): self.assertEqual(result[0], t1 * t1) self.assertEqual(result[1], t2 * t2) + @onlyCUDA + def test_device(self, device): + import libtorch_agnostic + + cuda_device = libtorch_agnostic.ops.test_device_constructor( + is_cuda=True, index=1, use_str=False + ) + self.assertEqual(cuda_device, torch.device("cuda:1")) + cuda_device = libtorch_agnostic.ops.test_device_constructor( + is_cuda=True, index=1, use_str=True + ) + self.assertEqual(cuda_device, torch.device("cuda:1")) + + self.assertEqual(libtorch_agnostic.ops.test_device_index(cuda_device), 1) + self.assertTrue( + libtorch_agnostic.ops.test_device_equality( + cuda_device, torch.device("cuda:1") + ) + ) + self.assertFalse( + libtorch_agnostic.ops.test_device_equality( + cuda_device, torch.device("cuda:0") + ) + ) + self.assertFalse(libtorch_agnostic.ops.test_device_is_cpu(cuda_device)) + self.assertTrue(libtorch_agnostic.ops.test_device_is_cuda(cuda_device)) + + cuda_0_device = libtorch_agnostic.ops.test_device_set_index(cuda_device, 0) + self.assertEqual(cuda_0_device, torch.device("cuda:0")) + + cpu_device = libtorch_agnostic.ops.test_device_constructor(False, 0, False) + self.assertEqual(cpu_device, 
torch.device("cpu")) + self.assertTrue( + libtorch_agnostic.ops.test_device_equality( + cpu_device, torch.device("cpu") + ) + ) + self.assertTrue(libtorch_agnostic.ops.test_device_is_cpu(cpu_device)) + self.assertFalse(libtorch_agnostic.ops.test_device_is_cuda(cpu_device)) + self.assertFalse( + libtorch_agnostic.ops.test_device_equality(cpu_device, cuda_device) + ) + + with self.assertRaisesRegex( + RuntimeError, "Device index 129 is out of range for int8_t" + ): + libtorch_agnostic.ops.test_device_constructor( + is_cuda=True, index=129, use_str=False + ) + + with self.assertRaisesRegex( + RuntimeError, "Device index 129 is out of range for int8_t" + ): + libtorch_agnostic.ops.test_device_set_index(cuda_device, 129) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index 15b9b986a3463..302678192d9aa 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -90,8 +91,14 @@ static StableIValue from_ivalue( ivalue.toScalarType(), extension_build_version); } case c10::TypeKind::DeviceObjType: { - return torch::stable::detail::_from( - ivalue.toDevice(), extension_build_version); + // Pack device type and index into StableIValue in platform-independent + // format Lower 32 bits = device index, upper 32 bits = device type + const auto& device = ivalue.toDevice(); + uint64_t device_index_bits = + static_cast(static_cast(device.index())); + uint64_t device_type_bits = + static_cast(static_cast(device.type())) << 32; + return device_index_bits | device_type_bits; } case c10::TypeKind::LayoutType: { return torch::stable::detail::_from( @@ -175,8 +182,25 @@ static c10::IValue to_ivalue( stable_ivalue, extension_build_version)); } case c10::TypeKind::DeviceObjType: { - return c10::IValue(torch::stable::detail::_to( - stable_ivalue, extension_build_version)); + // Unpack device type and index from StableIValue + // Lower 32 bits = device index, upper 32 bits = device type + int32_t device_index = static_cast( + static_cast(stable_ivalue & 0xFFFFFFFF)); + c10::DeviceType device_type = + static_cast(static_cast( + static_cast((stable_ivalue >> 32) & 0xFFFFFFFF))); + TORCH_CHECK( + device_index >= std::numeric_limits::min() && + device_index <= std::numeric_limits::max(), + "Device index ", + device_index, + " is out of range for int8_t [", + static_cast(std::numeric_limits::min()), + ", ", + static_cast(std::numeric_limits::max()), + "]"); + return c10::IValue( + c10::Device(device_type, static_cast(device_index))); } case c10::TypeKind::LayoutType: { return c10::IValue(torch::stable::detail::_to( @@ -290,6 +314,19 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_impl( }); } +// Helper function to parse device string using c10::Device +// Returns device type and index +AOTI_TORCH_EXPORT AOTITorchError torch_parse_device_string( + const char* device_string, + uint32_t* out_device_type, + int32_t* out_device_index) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::Device device{std::string(device_string)}; + *out_device_type = static_cast(device.type()); + *out_device_index = static_cast(device.index()); + }); +} + // Version-aware variant of aoti_torch_library_impl that takes an // extension_build_version parameter for backward compatibility AOTI_TORCH_EXPORT AOTITorchError torch_library_impl( diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index ea6cea0726659..2545b36a5640c 
100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -65,6 +65,13 @@ torch_list_push_back(StableListHandle list_handle, StableIValue element); AOTI_TORCH_EXPORT AOTITorchError torch_delete_list(StableListHandle list_handle); +// Helper function to parse device string using c10::Device +// Returns device type and index via output parameters +AOTI_TORCH_EXPORT AOTITorchError torch_parse_device_string( + const char* device_string, + uint32_t* out_device_type, + int32_t* out_device_index); + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 #ifdef __cplusplus diff --git a/torch/csrc/stable/device.h b/torch/csrc/stable/device.h new file mode 100644 index 0000000000000..223e3320a4fd3 --- /dev/null +++ b/torch/csrc/stable/device.h @@ -0,0 +1,4 @@ +#pragma once + +#include +#include diff --git a/torch/csrc/stable/device_inl.h b/torch/csrc/stable/device_inl.h new file mode 100644 index 0000000000000..8c9685f0d7da7 --- /dev/null +++ b/torch/csrc/stable/device_inl.h @@ -0,0 +1,41 @@ +#pragma once + +// This file implements device.h. We separated out the Device struct so that +// other files can depend on the Device struct (like stableivalue_conversions.h) +// and the implementations of the Device methods can depend on APIs in +// stableivalue_conversions.h without circular dependencies. + +#include +#include +#include +#include +#include +#include +#include + +#include + +HIDDEN_NAMESPACE_BEGIN(torch, stable) + +using DeviceType = torch::headeronly::DeviceType; +using DeviceIndex = torch::stable::accelerator::DeviceIndex; + +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + +inline Device::Device(const std::string& device_string) { + uint32_t device_type; + int32_t device_index; + + TORCH_ERROR_CODE_CHECK(torch_parse_device_string( + device_string.c_str(), &device_type, &device_index)); + + DeviceType dt = torch::stable::detail::to( + torch::stable::detail::from(device_type)); + DeviceIndex di = static_cast(device_index); + + *this = Device(dt, di); +} + +#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + +HIDDEN_NAMESPACE_END(torch, stable) diff --git a/torch/csrc/stable/device_struct.h b/torch/csrc/stable/device_struct.h new file mode 100644 index 0000000000000..b422d62e30c58 --- /dev/null +++ b/torch/csrc/stable/device_struct.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +HIDDEN_NAMESPACE_BEGIN(torch, stable) + +using DeviceType = torch::headeronly::DeviceType; +using DeviceIndex = torch::stable::accelerator::DeviceIndex; + +// The torch::stable::Device class is an approximate copy of c10::Device. +// It has some slight modifications: +// 1. TORCH_INTERNAL_ASSERT_DEBUG_ONLY -> STD_TORCH_CHECK +// 2. Has a string constructor that uses a shim function +// 3. does not include some is_{device} variants that we can add later +// +// We chose to copy it rather than moving it to headeronly as +// 1. Device is < 8 bytes so the *Handle approach used for tensor doesn't make +// sense +// 2. c10::Device is not header-only due to its string constructor. 
+// +// StableIValue conversions handle conversion between c10::Device (in libtorch) +// and torch::stable::Device (in stable user extensions) + +class Device { + private: + DeviceType type_; + DeviceIndex index_ = -1; + + void validate() { + STD_TORCH_CHECK( + index_ >= -1, + "Device index must be -1 or non-negative, got ", + static_cast(index_)); + STD_TORCH_CHECK( + type_ != DeviceType::CPU || index_ <= 0, + "CPU device index must be -1 or zero, got ", + static_cast(index_)); + } + + public: + // Construct a stable::Device from a DeviceType and optional device index + // Default index is -1 (current device) + /* implicit */ Device(DeviceType type, DeviceIndex index = -1) + : type_(type), index_(index) { + validate(); + } + +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + // Construct a stable::Device from a string description + // The string must follow the schema: (cpu|cuda|...)[:] + // Defined in device_inl.h to avoid circular dependencies + /* implicit */ Device(const std::string& device_string); +#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 + + // Copy and move constructors can be default + Device(const Device& other) = default; + Device(Device&& other) noexcept = default; + + // Copy and move assignment operators can be default + Device& operator=(const Device& other) = default; + Device& operator=(Device&& other) noexcept = default; + + // Destructor can be default + ~Device() = default; + + bool operator==(const Device& other) const noexcept { + return type() == other.type() && index() == other.index(); + } + + bool operator!=(const Device& other) const noexcept { + return !(*this == other); + } + + void set_index(DeviceIndex index) { + index_ = index; + } + + DeviceType type() const noexcept { + return type_; + } + + DeviceIndex index() const noexcept { + return index_; + } + + bool has_index() const noexcept { + return index_ != -1; + } + + bool is_cuda() const noexcept { + return type_ == DeviceType::CUDA; + } + + bool is_cpu() const noexcept { + return type_ == DeviceType::CPU; + } +}; + +HIDDEN_NAMESPACE_END(torch, stable) diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index 6885b1e4bdfeb..d69f41861ae94 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -3,7 +3,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -125,6 +127,37 @@ struct FromImpl { } }; +// Specialization for torch::headeronly::DeviceType => StableIValue +// Note that we call into the shim to translate between the user's +// DeviceType and libtorch's DeviceType, which can be different! 
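// (Editorial sketch, not part of the patch: the Device <-> StableIValue
//  packing used in this file and in shim_common.cpp is a plain 64-bit split,
//  with the DeviceIndex in the lower 32 bits and the DeviceType (in its shim
//  int32_t form) in the upper 32 bits. A standalone round trip looks like:
//
//    uint64_t pack(int32_t type, int32_t index) {
//      uint64_t index_bits = static_cast<uint64_t>(static_cast<uint32_t>(index));
//      uint64_t type_bits = static_cast<uint64_t>(static_cast<uint32_t>(type)) << 32;
//      return index_bits | type_bits;
//    }
//    int32_t unpack_index(uint64_t v) { return static_cast<int32_t>(v & 0xFFFFFFFF); }
//    int32_t unpack_type(uint64_t v) { return static_cast<int32_t>((v >> 32) & 0xFFFFFFFF); }
//
//  Because the index goes through a uint32_t truncation on the way in and a
//  signed cast on the way out, the default index of -1 round-trips intact.)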
+using torch::headeronly::DeviceType; +template <> +struct FromImpl { + static StableIValue call( + DeviceType val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + switch (val) { + case DeviceType::CPU: + return from(aoti_torch_device_type_cpu()); + case DeviceType::CUDA: + return from(aoti_torch_device_type_cuda()); + case DeviceType::Meta: + return from(aoti_torch_device_type_meta()); + case DeviceType::XPU: + return from(aoti_torch_device_type_xpu()); + case DeviceType::MPS: + return from(aoti_torch_device_type_mps()); + case DeviceType::PrivateUse1: + return from(aoti_torch_device_type_privateuse1()); + default: + TORCH_CHECK( + false, + "Not yet supported DeviceType, please file an issue describing your use case."); + } + } +}; + // Specialization for std::nullopt_t => StableIValue template <> struct FromImpl { @@ -233,6 +266,28 @@ struct FromImpl> { } }; +// Specialization for torch::stable::Device => StableIValue +// Pack the device type and index into a StableIValue in a platform-independent +// format. We use the shim representation for DeviceType (int32_t) for ABI +// stability. StableIValue layout: DeviceIndex in lower 32 bits, +// DeviceType (shim int32_t) in upper 32 bits +template <> +struct FromImpl { + static StableIValue call( + const torch::stable::Device& val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + // Convert DeviceType to shim representation (int32_t) + StableIValue device_type_shim = from(val.type()); + // Pack: lower 32 bits = device index, upper 32 bits = device type (shim) + uint64_t device_index_bits = + static_cast(static_cast(val.index())); + uint64_t device_type_bits = + static_cast(static_cast(device_type_shim)) << 32; + return device_index_bits | device_type_bits; + } +}; + // ============================================================================= // TO CONVERSIONS (StableIValue -> T) // ============================================================================= @@ -333,6 +388,36 @@ struct ToImpl { } }; +// Specialization for StableIValue => torch::headeronly::DeviceType +template <> +struct ToImpl { + static DeviceType call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + int32_t shim_devicetype = to(val); + if (shim_devicetype == aoti_torch_device_type_cpu()) { + return DeviceType::CPU; + } else if (shim_devicetype == aoti_torch_device_type_cuda()) { + return DeviceType::CUDA; + } else if (shim_devicetype == aoti_torch_device_type_meta()) { + return DeviceType::Meta; + } else if (shim_devicetype == aoti_torch_device_type_xpu()) { + return DeviceType::XPU; + } else if (shim_devicetype == aoti_torch_device_type_mps()) { + return DeviceType::MPS; + } else if (shim_devicetype == aoti_torch_device_type_privateuse1()) { + return DeviceType::PrivateUse1; + } else { + TORCH_CHECK( + false, + "Not yet supported DeviceType ", + std::to_string(shim_devicetype), + ", please file an issue describing your use case."); + } + } +}; + // Specialization for StableIValue => std::nullopt_t template <> struct ToImpl { @@ -415,6 +500,24 @@ struct ToImpl> { } }; +// Specialization for StableIValue => torch::stable::Device +// Unpack device type and index from StableIValue in platform-independent +// format. 
StableIValue layout: DeviceIndex in lower 32 bits, +// DeviceType (shim int32_t) in upper 32 bits +template <> +struct ToImpl { + static torch::stable::Device call( + StableIValue val, + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { + // Unpack: lower 32 bits = device index, upper 32 bits = device type (shim) + int32_t device_index = static_cast(val & 0xFFFFFFFF); + StableIValue device_type_shim = (val >> 32) & 0xFFFFFFFF; + DeviceType device_type = to(device_type_shim); + return torch::stable::Device(device_type, device_index); + } +}; + // ============================================================================= // end to helpers for converting between StableIValue and T // ============================================================================= diff --git a/torch/headeronly/core/DeviceType.h b/torch/headeronly/core/DeviceType.h index 980052b79c713..9db3ef2568d34 100644 --- a/torch/headeronly/core/DeviceType.h +++ b/torch/headeronly/core/DeviceType.h @@ -1,3 +1,5 @@ +#pragma once + // This is directly synchronized with caffe2/proto/caffe2.proto, but // doesn't require me to figure out how to get Protobuf headers into // ATen/core (which would require a lot more build system hacking.) From 6f0182495fc9a8ac36fa027659da82518a8efc15 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Nov 2025 10:22:04 -0800 Subject: [PATCH 306/651] Add stable::Tensor.device() (#166694) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166694 Approved by: https://github.com/janeyx99 ghstack dependencies: #166579 --- .../libtorch_agnostic/csrc/kernel.cpp | 20 ++++++++++++++++++- .../libtorch_agnostic/ops.py | 12 +++++++++++ .../test/test_libtorch_agnostic.py | 18 +++++++++++++++++ torch/csrc/stable/tensor_inl.h | 11 ++++++++++ torch/csrc/stable/tensor_struct.h | 4 ++++ 5 files changed, 64 insertions(+), 1 deletion(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index b1c74a4b0f988..e31482dbd9386 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -529,6 +529,21 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach); } +// Test functions for torch::stable::Tensor device method + +torch::stable::Device test_tensor_device(torch::stable::Tensor tensor) { + return tensor.device(); +} + +void boxed_test_tensor_device( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + torch::stable::Device res = test_tensor_device( + torch::stable::detail::to(stack[0])); + stack[0] = torch::stable::detail::from(res); +} + // Test functions for torch::stable::Device torch::stable::Device test_device_constructor( @@ -637,7 +652,9 @@ void boxed_test_device_is_cpu( } STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { - m.def("test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device"); + m.def("test_tensor_device(Tensor t) -> Device"); + m.def( + "test_device_constructor(bool is_cuda, DeviceIndex index, bool use_str) -> Device"); m.def("test_device_equality(Device d1, Device d2) -> bool"); m.def("test_device_set_index(Device device, DeviceIndex index) -> Device"); m.def("test_device_index(Device device) -> DeviceIndex"); @@ -646,6 +663,7 
@@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("test_tensor_device", &boxed_test_tensor_device); m.impl("test_device_constructor", &boxed_test_device_constructor); m.impl("test_device_equality", &boxed_test_device_equality); m.impl("test_device_set_index", &boxed_test_device_set_index); diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index 38eba12bb4690..c1fe842f9e8a0 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -215,6 +215,18 @@ def test_default_constructor(defined) -> bool: return torch.ops.libtorch_agnostic.test_default_constructor.default(defined) +def test_tensor_device(t): + """ + Tests Tensor device() method. + + Args: + t: Tensor - tensor to get device from + + Returns: Device - device of the tensor + """ + return torch.ops.libtorch_agnostic.test_tensor_device.default(t) + + def my_pad(t) -> Tensor: """ Pads the input tensor with hardcoded padding parameters. diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index ce5072b968d19..4c95011dce85d 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -473,6 +473,24 @@ def test_device(self, device): ): libtorch_agnostic.ops.test_device_set_index(cuda_device, 129) + @onlyCUDA + @deviceCountAtLeast(2) + def test_tensor_device(self, device): + import libtorch_agnostic + + t = torch.randn(2, 3) + self.assertEqual(libtorch_agnostic.ops.test_tensor_device(t), t.device) + + t_cuda = torch.randn(2, 3, device="cuda") + self.assertEqual( + libtorch_agnostic.ops.test_tensor_device(t_cuda), t_cuda.device + ) + + t_cuda_1 = torch.randn(2, 3, device="cuda:1") + self.assertEqual( + libtorch_agnostic.ops.test_tensor_device(t_cuda_1), t_cuda_1.device + ) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/stable/tensor_inl.h b/torch/csrc/stable/tensor_inl.h index 37582de201840..8eb69f1a63b74 100644 --- a/torch/csrc/stable/tensor_inl.h +++ b/torch/csrc/stable/tensor_inl.h @@ -22,4 +22,15 @@ inline ScalarType Tensor::scalar_type() const { torch::stable::detail::from(dtype)); } +inline Device Tensor::device() const { + int32_t device_type; + int32_t device_index; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(ath_.get(), &device_type)); + TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(ath_.get(), &device_index)); + DeviceType extension_device_type = torch::stable::detail::to( + torch::stable::detail::from(device_type)); + return Device(extension_device_type, static_cast(device_index)); +} + HIDDEN_NAMESPACE_END(torch, stable) diff --git a/torch/csrc/stable/tensor_struct.h b/torch/csrc/stable/tensor_struct.h index 0d44ffd075170..e3f50ad26781c 100644 --- a/torch/csrc/stable/tensor_struct.h +++ b/torch/csrc/stable/tensor_struct.h @@ -10,6 +10,7 @@ #include #include +#include HIDDEN_NAMESPACE_BEGIN(torch, stable) @@ -192,6 +193,9 @@ class Tensor { // defined in tensor-inl.h to avoid circular dependencies ScalarType scalar_type() const; + // defined in tensor-inl.h to avoid circular dependencies + 
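  // (Editorial note: as shown in the tensor_inl.h hunk above, device() asks
  //  the shim for the device type and index via aoti_torch_get_device_type /
  //  aoti_torch_get_device_index and repackages them as a
  //  torch::stable::Device, so a stable kernel can branch on e.g.
  //  `t.device().is_cuda()` without touching c10::Device.)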
Device device() const; + // ============================================================================= // END of C-shimified TensorBase APIs // ============================================================================= From 232baa33b34b09c6cd4548e56910f2ab5b671e28 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Nov 2025 10:22:04 -0800 Subject: [PATCH 307/651] Redo add parallel_for to torch/csrc/stable (#166695) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166695 Approved by: https://github.com/malfet ghstack dependencies: #166579, #166694 --- .../libtorch_agnostic/csrc/kernel.cpp | 48 +++++++++++++++++++ .../libtorch_agnostic/ops.py | 12 +++++ .../test/test_libtorch_agnostic.py | 26 ++++++++++ torch/csrc/shim_common.cpp | 19 ++++++++ torch/csrc/stable/c/shim.h | 16 +++++++ torch/csrc/stable/ops.h | 25 ++++++++++ 6 files changed, 146 insertions(+) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index e31482dbd9386..85bc3c421ed16 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -761,3 +761,51 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { } #endif // LAE_USE_CUDA + +Tensor test_parallel_for(int64_t size, int64_t grain_size) { + AtenTensorHandle tensor_handle; + int64_t stride = 1; + + aoti_torch_empty_strided( + 1, + &size, + &stride, + aoti_torch_dtype_int64(), + aoti_torch_device_type_cpu(), + 0, + &tensor_handle); + + Tensor tensor(tensor_handle); + int64_t* data_ptr = reinterpret_cast(tensor.data_ptr()); + + torch::stable::zero_(tensor); + + // Use parallel_for to fill each element with its index + // If using a parallel path, the thread id is encoded in the upper 32 bits + torch::stable::parallel_for( + 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { + for (auto i = begin; i < end; i++) { + STD_TORCH_CHECK(i <= UINT32_MAX); + int thread_id = torch_get_thread_idx(); + data_ptr[i] = i | (static_cast(thread_id) << 32); + } + }); + + return tensor; +} + +void boxed_test_parallel_for( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + Tensor res = test_parallel_for(to(stack[0]), to(stack[1])); + stack[0] = from(res); +} + +STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { + m.def("test_parallel_for(int size, int grain_size) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { + m.impl("test_parallel_for", &boxed_test_parallel_for); +} diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index c1fe842f9e8a0..b9e53c1598f94 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -465,3 +465,15 @@ def test_device_is_cpu(device) -> bool: Returns: bool - True if device is CPU """ return torch.ops.libtorch_agnostic.test_device_is_cpu.default(device) + + +def test_parallel_for(size, grain_size) -> Tensor: + """ + Tests the parallel_for functionality by using it to fill a tensor with indices. 
+ Args: + size: int - size of the tensor to create + grain_size: int - grain size for parallel_for + Returns: Tensor - a 1D int64 tensor where each element contains its index + (if multiple threads are used the threadid will be encoded in the upper 32 bits) + """ + return torch.ops.libtorch_agnostic.test_parallel_for.default(size, grain_size) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index 4c95011dce85d..bfa71e0b13682 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -491,6 +491,32 @@ def test_tensor_device(self, device): libtorch_agnostic.ops.test_tensor_device(t_cuda_1), t_cuda_1.device ) + @onlyCPU + # TODO: Debug this: + # Dynamo failed to run FX node with fake tensors: + # call_function libtorch_agnostic.test_parallel_for.default(*(100, 10), **{}): + # got RuntimeError('libtorch_agnostic::test_parallel_for() expected at most + # 2 argument(s) but received 3 argument(s). + # Declaration: libtorch_agnostic::test_parallel_for(int size, int grain_size) -> Tensor') + @xfailIfTorchDynamo + def test_parallel_for(self, device): + import libtorch_agnostic + + num_threads = torch.get_num_threads() + size = 100 + grain_size = 10 + expected_num_threads_used = min( + (size + grain_size - 1) // grain_size, num_threads + ) + + result = libtorch_agnostic.ops.test_parallel_for(size, grain_size) + result_thread_ids = torch.unique(torch.bitwise_right_shift(result, 32)) + result_values = torch.bitwise_and(result, 0xFFFFFFFF) + expected = torch.arange(size, dtype=torch.int64) + + self.assertEqual(result_values, expected) + self.assertEqual(result_thread_ids, torch.arange(expected_num_threads_used)) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index 302678192d9aa..9c8b6e5e72bf5 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -533,3 +534,21 @@ AOTI_TORCH_EXPORT AOTITorchError torch_call_dispatcher( } }); } + +AOTI_TORCH_EXPORT AOTITorchError torch_parallel_for( + int64_t begin, + int64_t end, + int64_t grain_size, + ParallelFunc func, + void* ctx) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + at::parallel_for( + begin, end, grain_size, [func, ctx](int64_t begin, int64_t end) { + func(begin, end, ctx); + }); + }); +} + +AOTI_TORCH_EXPORT uint32_t torch_get_thread_idx() { + return static_cast(at::get_thread_num()); +} diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index 2545b36a5640c..cf0b6d35f06ad 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -72,6 +72,22 @@ AOTI_TORCH_EXPORT AOTITorchError torch_parse_device_string( uint32_t* out_device_type, int32_t* out_device_index); +// Parallel utility APIs for stable ABI +// Function pointer type for parallel_for callback +// The callback receives begin and end indices for a range to process +typedef void (*ParallelFunc)(int64_t begin, int64_t end, void* ctx); + +AOTI_TORCH_EXPORT AOTITorchError torch_parallel_for( + int64_t begin, + int64_t end, + int64_t grain_size, + ParallelFunc func, + void* ctx); + +// Get the current thread index in a parallel region +// Returns 0 if not in a parallel region +AOTI_TORCH_EXPORT uint32_t 
torch_get_thread_idx(); + #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 #ifdef __cplusplus diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index d5fbba9fbbfd7..73442c75e61ea 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -273,6 +273,31 @@ inline torch::stable::Tensor clone(const torch::stable::Tensor& self) { // New ops should be added here if they use a brand new shim API +// Parallel utility wrapper that provides a stable interface to at::parallel_for +// This function has the same signature as at::parallel_for and allows stable +// ABI code to leverage PyTorch's parallel execution capabilities. +// +// The function f will be called with (begin, end) ranges to process in +// parallel. grain_size controls the minimum work size per thread for efficient +// parallelization. +template +inline void parallel_for( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f) { + auto callback = [](int64_t cb_begin, int64_t cb_end, void* ctx) { + const F* func = static_cast(ctx); + (*func)(cb_begin, cb_end); + }; + TORCH_ERROR_CODE_CHECK(torch_parallel_for( + begin, + end, + grain_size, + callback, + const_cast(static_cast(&f)))); +} + #endif HIDDEN_NAMESPACE_END(torch, stable) From 69ab1f93e43fc346b7c93e0c5ec578b04d4cea89 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Mon, 10 Nov 2025 10:22:05 -0800 Subject: [PATCH 308/651] Add shim for at::get_num_threads (#167362) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167362 Approved by: https://github.com/janeyx99 ghstack dependencies: #166579, #166694, #166695 --- .../libtorch_agnostic/csrc/kernel.cpp | 17 ++++++++++++++++- .../libtorch_agnostic/ops.py | 10 ++++++++++ .../test/test_libtorch_agnostic.py | 8 ++++++++ torch/csrc/shim_common.cpp | 12 ++++++++++-- torch/csrc/stable/c/shim.h | 6 +++++- torch/csrc/stable/ops.h | 8 ++++++++ 6 files changed, 57 insertions(+), 4 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 85bc3c421ed16..96b6a17cf9187 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -786,7 +786,8 @@ Tensor test_parallel_for(int64_t size, int64_t grain_size) { 0, size, grain_size, [data_ptr](int64_t begin, int64_t end) { for (auto i = begin; i < end; i++) { STD_TORCH_CHECK(i <= UINT32_MAX); - int thread_id = torch_get_thread_idx(); + uint32_t thread_id; + torch_get_thread_idx(&thread_id); data_ptr[i] = i | (static_cast(thread_id) << 32); } }); @@ -802,10 +803,24 @@ void boxed_test_parallel_for( stack[0] = from(res); } +uint32_t test_get_num_threads() { + return torch::stable::get_num_threads(); +} + +void boxed_test_get_num_threads( + StableIValue* stack, + uint64_t num_args, + uint64_t num_outputs) { + uint32_t res = test_get_num_threads(); + stack[0] = from(res); +} + STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) { m.def("test_parallel_for(int size, int grain_size) -> Tensor"); + m.def("test_get_num_threads() -> int"); } STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) { m.impl("test_parallel_for", &boxed_test_parallel_for); + m.impl("test_get_num_threads", &boxed_test_get_num_threads); } diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py 
b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py index b9e53c1598f94..59d8c17b68d77 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py @@ -477,3 +477,13 @@ def test_parallel_for(size, grain_size) -> Tensor: (if multiple threads are used the threadid will be encoded in the upper 32 bits) """ return torch.ops.libtorch_agnostic.test_parallel_for.default(size, grain_size) + + +def test_get_num_threads() -> int: + """ + Tests the get_num_threads functionality by returning the number of threads + for the parallel backend. + + Returns: int - the number of threads for the parallel backend + """ + return torch.ops.libtorch_agnostic.test_get_num_threads.default() diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py index bfa71e0b13682..1149be388795a 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py +++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py @@ -517,6 +517,14 @@ def test_parallel_for(self, device): self.assertEqual(result_values, expected) self.assertEqual(result_thread_ids, torch.arange(expected_num_threads_used)) + @onlyCPU + def test_get_num_threads(self, device): + import libtorch_agnostic + + num_threads = libtorch_agnostic.ops.test_get_num_threads() + expected_num_threads = torch.get_num_threads() + self.assertEqual(num_threads, expected_num_threads) + instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None) if __name__ == "__main__": diff --git a/torch/csrc/shim_common.cpp b/torch/csrc/shim_common.cpp index 9c8b6e5e72bf5..1c4d9ce295a84 100644 --- a/torch/csrc/shim_common.cpp +++ b/torch/csrc/shim_common.cpp @@ -549,6 +549,14 @@ AOTI_TORCH_EXPORT AOTITorchError torch_parallel_for( }); } -AOTI_TORCH_EXPORT uint32_t torch_get_thread_idx() { - return static_cast(at::get_thread_num()); +AOTI_TORCH_EXPORT AOTITorchError +torch_get_thread_idx(uint32_t* out_thread_idx) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( + { *out_thread_idx = static_cast(at::get_thread_num()); }); +} + +AOTI_TORCH_EXPORT AOTITorchError +torch_get_num_threads(uint32_t* out_num_threads) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE( + { *out_num_threads = static_cast(at::get_num_threads()); }); } diff --git a/torch/csrc/stable/c/shim.h b/torch/csrc/stable/c/shim.h index cf0b6d35f06ad..0afa650fe2d7c 100644 --- a/torch/csrc/stable/c/shim.h +++ b/torch/csrc/stable/c/shim.h @@ -86,7 +86,11 @@ AOTI_TORCH_EXPORT AOTITorchError torch_parallel_for( // Get the current thread index in a parallel region // Returns 0 if not in a parallel region -AOTI_TORCH_EXPORT uint32_t torch_get_thread_idx(); +AOTI_TORCH_EXPORT AOTITorchError torch_get_thread_idx(uint32_t* out_thread_idx); + +// Get the number of threads for the parallel backend +AOTI_TORCH_EXPORT AOTITorchError +torch_get_num_threads(uint32_t* out_num_threads); #endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 73442c75e61ea..1a14cf9765094 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -298,6 +298,14 @@ inline void parallel_for( const_cast(static_cast(&f)))); } +// Get the number of threads for the parallel backend +// This provides a stable interface to at::get_num_threads +inline uint32_t get_num_threads() { + uint32_t 
num_threads; + TORCH_ERROR_CODE_CHECK(torch_get_num_threads(&num_threads)); + return num_threads; +} + #endif HIDDEN_NAMESPACE_END(torch, stable) From 47da714b8bf673caa0db43b295d70c8c434537b8 Mon Sep 17 00:00:00 2001 From: Nicolas Macchioni Date: Mon, 10 Nov 2025 22:51:03 +0000 Subject: [PATCH 309/651] [inductor][determinism] type errors + use odc to dump imc on exit (#167136) Summary: fix some type errors + instead of manually creating a filelock when dumping dcache's imc to file we simply use an odc (since this is the intended behavior of odc, anyways) Test Plan: ``` buck test fbcode//mode/opt caffe2/test/inductor:caching ``` Differential Revision: D86345594 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167136 Approved by: https://github.com/aorenste --- torch/_inductor/runtime/caching/interfaces.py | 43 +++++++++++-------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 03d2957493679..d0c1011200e43 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -7,7 +7,7 @@ from ast import literal_eval from enum import Enum from functools import partial, wraps -from logging import DEBUG, getLogger, Logger +from logging import DEBUG, getLogger, INFO, Logger from os import PathLike from pathlib import Path from threading import Lock @@ -15,8 +15,6 @@ from typing import Any, TYPE_CHECKING, TypeAlias from typing_extensions import override -from filelock import FileLock - from . import config, context, exceptions, implementations as impls, locks @@ -329,10 +327,10 @@ def insert( def record( self, ischema: context.IsolationSchema | None = None, - custom_params_encoder: Callable[P, Any] | None = None, - custom_result_encoder: Callable[[R], Any] | None = None, - custom_result_decoder: Callable[[Any], R] | None = None, - ) -> Callable[[Callable[P, R]], Callable[P, R]]: + custom_params_encoder: Callable[..., Any] | None = None, + custom_result_encoder: Callable[..., Any] | None = None, + custom_result_decoder: Callable[..., ...] | None = None, + ) -> Callable[[Callable[..., ...]], Callable[..., ...]]: if custom_result_encoder and not custom_result_decoder: raise exceptions.CustomResultDecoderRequiredError( "Custom result encoder provided without custom result decoder." 
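A minimal, standalone sketch of the JSON round-trip that the pre-population path in the next hunk appears to rely on: byte keys and values are stored as Python byte literals (the form repr() produces) and recovered with ast.literal_eval. The `memory` dict below is a hypothetical stand-in for the in-memory cache contents, not the real implementation.

```
import json
from ast import literal_eval

memory = {b"key-hash": b"cached-value"}   # stand-in for the in-memory cache's bytes -> bytes mapping

# Dump side: each bytes object is serialized as its Python literal string.
text = json.dumps({repr(k): repr(v) for k, v in memory.items()}, indent=4)

# Pre-population side: parse the literals back into bytes, as the loader below does.
restored = {literal_eval(k): literal_eval(v) for k, v in json.loads(text).items()}
assert restored == memory
```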
@@ -506,16 +504,22 @@ def __init__(self) -> None: super().__init__() self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl() - if fpath := os.environ.get("TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE"): - # pyrefly: ignore [bad-assignment] - flock: FileLock = FileLock(str(fpath) + ".lock") - with locks._acquire_flock_with_timeout(flock): - with open(fpath) as fp: - dump_for_pre_population: dict[str, str] = json.load(fp) - for key_r, value_r in dump_for_pre_population.items(): - key: bytes = literal_eval(key_r) - value: bytes = literal_eval(value_r) - self._imc._memory[key] = value + if fpath_str := os.environ.get( + "TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE" + ): + fpath: Path = Path(fpath_str) + fpath_parent: PathLike[str] = fpath.parent + if fpath.is_file(): + odc: impls._OnDiskCacheImpl = impls._OnDiskCacheImpl( + sub_dir=fpath_parent + ) + with odc.lock(): + with open(fpath) as fp: + dump_for_pre_population: dict[str, str] = json.load(fp) + for key_r, value_r in dump_for_pre_population.items(): + key: bytes = literal_eval(key_r) + value: bytes = literal_eval(value_r) + self._imc._memory[key] = value if config.STRICTLY_PRE_POPULATED_DETERMINISM: # we'll never need a synchronization cache if we're in strictly pre-populated mode, @@ -578,7 +582,7 @@ def _dump_imc_to_disk(self) -> Path | None: for key, value in existing_dump.items(): if key not in to_dump: to_dump[key] = value - else: + elif to_dump[key] != value: raise exceptions.DeterministicCachingIMCDumpConflictError from None w_fp = open(fpath, "w") @@ -586,6 +590,9 @@ def _dump_imc_to_disk(self) -> Path | None: assert w_fp is not None try: json.dump(to_dump, w_fp, indent=4) + logger.log( + INFO, "Dumped deterministic cache memoization to %s", fpath + ) finally: w_fp.close() From de773364be041ca7fd2dcaf35ca15c093fc9370b Mon Sep 17 00:00:00 2001 From: soulitzer Date: Fri, 7 Nov 2025 11:56:50 -0800 Subject: [PATCH 310/651] Support AC in default partitioner when functionalization is enabled (#166610) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166610 Approved by: https://github.com/SherlockNoMad ghstack dependencies: #166536 --- .../distributed/tensor/test_dtensor_export.py | 2 - test/dynamo/test_activation_checkpointing.py | 267 +++++++++++++--- test/functorch/test_aotdispatch.py | 15 +- test/higher_order_ops/test_local_map.py | 4 +- .../_aot_autograd/graph_capture_wrappers.py | 5 + torch/_functorch/partitioners.py | 300 ++++++++++-------- 6 files changed, 399 insertions(+), 194 deletions(-) diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index b9749e3bc4e23..2a8e00709dd6c 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import contextlib -import unittest import torch import torch.distributed as dist @@ -372,7 +371,6 @@ def test_export_parallelize_module_with_dtensor_input( # aot_export_joint_with_descriptors on strict-exported exported_program.module() # is producing a joint graph with backward region missing - @unittest.expectedFailure def test_strict_export_parallelize_module_with_dtensor_input(self): self._run_test(strict_export_and_aot_export_joint_with_descriptors) diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index 6e1f45c166984..252ddb204f15a 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -15,7 
+15,7 @@ import torch.distributed as dist import torch.nn as nn import torch.utils.checkpoint -from functorch.compile import min_cut_rematerialization_partition +from functorch.compile import default_partition, min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch._dynamo.testing import ( AotEagerAndRecordGraphs, @@ -24,7 +24,7 @@ ) from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.testing._internal.common_device_type import instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu +from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, skipIfHpu from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON from torch.testing._internal.triton_utils import requires_cuda_and_triton from torch.testing._internal.two_tensor import TwoTensor @@ -281,7 +281,14 @@ def runtime_wrapper(*runtime_args): run(export_compiler) - def test_tags_function(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -297,11 +304,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_via_global_checkpoint(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_via_global_checkpoint(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -316,17 +334,28 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_function_with_kwargs(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_function_with_kwargs(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) def fn(x, y): return torch.utils.checkpoint.checkpoint( - gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False + gn, torch.sin(x), y, use_reentrant=False ) x = torch.randn(4, 4, device=device, requires_grad=True) @@ -336,11 +365,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=3, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_sequential_layers(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_sequential_layers(self, device, partition_fn): def gn(x): x = x.cos() for _ in range(3): @@ -361,11 +401,22 @@ def fn(x): freqs=[2, 18], 
ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default], ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_multiple_checkpoints(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_multiple_checkpoints(self, device, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) @@ -383,11 +434,22 @@ def fn(x, y): bw_compiler = functools.partial( count_ops, freq=6, op=torch.ops.aten.mm.default ) # mm recomputed in the bwd - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x, y) @requires_cuda_and_triton - def test_tags_module(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_module(self, device, partition_fn): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -411,11 +473,22 @@ def fn(x): bw_compiler = functools.partial( count_ops, freq=1, op=torch.ops.aten.sigmoid.default ) - backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler) + backend = aot_autograd( + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + partition_fn=partition_fn, + ) self._validate(fn, backend, x) @requires_cuda_and_triton - def test_tags_decomps(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_tags_decomps(self, device, partition_fn): # Ensures that tags are passed on through decompositions as well class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -443,6 +516,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, + partition_fn=partition_fn, decompositions=lambda: import_module( "torch._inductor.compile_fx" ).select_decomp_table(), @@ -702,7 +776,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_recompute(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_recompute(self, device, partition_fn): def context_fn_must_recompute_mm(): must_recompute_list = [ torch.ops.aten.mm.default, @@ -723,9 +804,9 @@ def context_fn_no_recompute_mm(): ), ) - def _test(context_fn, bw_compiler): + def _test(context_fn, bw_compiler, partition_fn): def gn(x): - return torch.sigmoid(torch.matmul(x, x)) + return torch.cos(torch.sin(torch.matmul(x, x) @ x)) def fn(x): return torch.utils.checkpoint.checkpoint( @@ -739,14 +820,14 @@ def fn(x): fw_compiler = functools.partial( count_ops, - freq=1, + freq=2, op=torch.ops.aten.mm.default, ) backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x) @@ -754,17 +835,19 @@ def fn(x): context_fn=context_fn_must_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=3, # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3) + freq=6, # 1 matmul recompute and 2 bwd 
mm ops per fwd matmul, so 2 + 2 * 2 = 6) op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) _test( context_fn=context_fn_no_recompute_mm, bw_compiler=functools.partial( count_ops, - freq=2, # 2 bwd mm ops per fwd matmul + freq=4, # 2 bwd mm ops per fwd matmul op=torch.ops.aten.mm.default, ), + partition_fn=partition_fn, ) def test_sac_with_partial_context_fn(self): @@ -801,7 +884,16 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_must_not_recompute_gemm(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_must_not_recompute_gemm( + self, device, partition_fn + ): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -841,15 +933,22 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) def test_compile_selective_checkpoint_must_not_recompute_gemm_no_functionalization( - self, device + self, device, partition_fn ): def selective_checkpointing_context_fn(): no_recompute_list = [ @@ -889,7 +988,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, disable_functionalization=True, ) self._validate(fn, backend, x, y) @@ -897,7 +996,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_triton_kernel(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_triton_kernel(self, device, partition_fn): # Copy of the above test, but make sure that having a triton kernel in the # region does not error. 
def add_one(x): @@ -957,14 +1063,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_tensor_subclass(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_tensor_subclass(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1007,14 +1120,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_custom_rule(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_custom_rule(self, device, partition_fn): def _get_custom_policy(meta): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1072,14 +1192,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_partial_ctx_fn(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_partial_ctx_fn(self, device, partition_fn): def selective_checkpointing_context_fn(no_recompute_list): return create_selective_checkpoint_contexts( _get_custom_policy(no_recompute_list=no_recompute_list) @@ -1118,14 +1245,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_outplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_outplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1163,14 +1297,21 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_list_ops(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def 
test_compile_selective_checkpoint_list_ops(self, device, partition_fn): def selective_checkpointing_context_fn(): # recompute everything no_recompute_list = [] @@ -1206,7 +1347,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1217,7 +1358,14 @@ def fn(x, y): "requires TorchDispatchMode + torch.compile work to complete" ) @requires_cuda_and_triton - def test_compile_selective_checkpoint_inplace_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_inplace_op(self, device, partition_fn): def selective_checkpointing_context_fn(): no_recompute_list = [ torch.ops.aten.mm.default, @@ -1257,7 +1405,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) self._validate(fn, backend, x, y) self._compare_orig_and_checkpointed_fns(gn, fn, x, y) @@ -1265,7 +1413,14 @@ def fn(x, y): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") @torch._inductor.config.patch(fallback_random=True) - def test_compile_selective_checkpoint_random_op(self, device): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_random_op(self, device, partition_fn): for preserve_rng_state in [True, False]: def selective_checkpointing_context_fn(): @@ -1312,7 +1467,7 @@ def fn(x): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) # NOTE: when `preserve_rng_state` is False, gradient will mismatch between torch.compile and eager, @@ -1324,7 +1479,14 @@ def fn(x): @requires_cuda_and_triton @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows") - def test_compile_selective_checkpoint_invalid_context(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_invalid_context(self, partition_fn): def gn(x, y): return torch.sigmoid(torch.matmul(x, y)) * y @@ -1353,7 +1515,7 @@ def fn(x, y): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) with self.assertRaisesRegex( Exception, "must generate a tuple of two `TorchDispatchMode`s" @@ -1362,7 +1524,14 @@ def fn(x, y): @requires_cuda_and_triton @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) - def test_compile_selective_checkpoint_parametrization(self): + @parametrize( + "partition_fn", + [ + min_cut_rematerialization_partition, + default_partition, + ], + ) + def test_compile_selective_checkpoint_parametrization(self, partition_fn): def sac_policy(): def _recomp_policy(): def _custom_policy(ctx, func, *args, **kwargs): @@ -1425,7 +1594,9 @@ def reset_parameters(self): bw_compiler = functools.partial( count_ops, freqs=[ - 2, # 1 from mul recompute, 1 from mul backward + # 1 from mul recompute, 1 from mul backward + # w/o CSE, we have one extra mul + 3 if partition_fn is default_partition else 2, 1, ], ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default], @@ -1434,7 
+1605,7 @@ def reset_parameters(self): backend = aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, - partition_fn=min_cut_rematerialization_partition, + partition_fn=partition_fn, ) model = MLPModule() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 6cae42d8929da..c452f18e95d75 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2640,7 +2640,7 @@ def backward(ctx, grad_output): return grad_output * x, grad_output * x def f(a, b): - return FwBwMutation.apply(a, b) + return FwBwMutation.apply(a, b).sin_().clone() inps = [ torch.ones(3, 3, requires_grad=True), @@ -2689,17 +2689,22 @@ def forward(self, primals_1, primals_2): add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None _foreach_mul__1 = torch.ops.aten._foreach_mul_.ScalarList([add], [3]); _foreach_mul__1 = None mul = torch.ops.aten.mul.Tensor(add, primals_1); primals_1 = None - return (mul, add)""", + clone = torch.ops.aten.clone.default(mul) + sin_ = torch.ops.aten.sin_.default(mul); mul = None + clone_1 = torch.ops.aten.clone.default(sin_); sin_ = None + return (clone_1, add, clone)""", ) # important bit: there is 1 mutation in the bw self.assertExpectedInline( bw_graph[0].code.strip(), """\ -def forward(self, add, tangents_1): +def forward(self, add, clone, tangents_1): + cos = torch.ops.aten.cos.default(clone); clone = None + mul_1 = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None _foreach_mul__2 = torch.ops.aten._foreach_mul_.ScalarList([add], [4]); _foreach_mul__2 = None - mul_1 = torch.ops.aten.mul.Tensor(tangents_1, add); tangents_1 = add = None - return (mul_1, None)""", + mul_2 = torch.ops.aten.mul.Tensor(mul_1, add); mul_1 = add = None + return (mul_2, None)""", ) def test_fw_bw_mutation_no_functionalization2(self): diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index fbb21633260e7..10ec8e7444b6e 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -927,8 +927,8 @@ def inputs_fn(): op="call_function", target=torch.ops.aten.mm.default ) self.assertEqual(len(mm_nodes), 4) - self.assertNotIn("partitioner_tag", mm_nodes[0].meta) - self.assertNotIn("partitioner_tag", mm_nodes[1].meta) + self.assertEqual(mm_nodes[0].meta["partitioner_tag"], "is_forward") + self.assertEqual(mm_nodes[1].meta["partitioner_tag"], "is_forward") self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward") self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0) diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index bc4dc87ddeced..4a1e865930d41 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -27,6 +27,7 @@ from torch._prims_common import CUDARngStateHelper from torch.fx.experimental.proxy_tensor import ( _proxy_tensor_disable_update_tensor_tracker, + get_proxy_mode, maybe_disable_thunkify, maybe_enable_thunkify, ) @@ -295,6 +296,10 @@ def inner_fn( (outs, tangent_mask), (outs_descs, _) = call_and_expect_output_descs( fn, primals ) + mode = get_proxy_mode() + assert mode is not None + for node in mode.tracer.graph.nodes: + node.meta["partitioner_tag"] = "is_forward" # TODO: I think this hook can also be eliminated now if joint_fn_handle and joint_fn_handle.post_forward: diff --git 
a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index e7f8075b0281e..30374d85b5faa 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -51,6 +51,7 @@ ) from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput +from ._aot_autograd.functional_utils import assert_functional_graph from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems @@ -297,6 +298,10 @@ def _has_tag_is_backward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "is_backward" +def _has_tag_is_forward(node: fx.Node) -> bool: + return node.meta.get("partitioner_tag", None) == "is_forward" + + def _has_tag_must_be_in_forward(node: fx.Node) -> bool: return node.meta.get("partitioner_tag", None) == "must_be_in_forward" @@ -1021,105 +1026,95 @@ def default_partition( Returns: Returns the generated forward and backward Fx graph modules. """ - if has_recomputable_ops(joint_module): - return min_cut_rematerialization_partition( - joint_module, - _joint_inputs, - num_fwd_outputs=num_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) + # Respect the original placement of ops rather than rely on dataflow. 
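# How the walk below finds the forward region: earlier in this patch,
# graph_capture_wrappers.py stamps every node traced during the forward with
# partitioner_tag == "is_forward", so node placement in the joint graph can be
# trusted directly. The scan keeps nodes in order up to the last
# primal / fwd_seed_offset / is_forward node, skipping tangent placeholders;
# everything after that point belongs to the backward. The previous version of
# default_partition instead re-derived the forward by dataflow extraction and
# deferred any graph with recompute (AC) tags to the min-cut partitioner.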
+ forward_nodes = [] + last_node = None + for node in joint_module.graph.nodes: + if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node): + last_node = node + assert last_node is not None + for node in joint_module.graph.nodes: + if not _is_tangent(node): + forward_nodes.append(node) + if node is last_node: + break forward_node_names = OrderedSet( - node.name for node in forward_only_graph.nodes if node.op != "output" + node.name for node in forward_nodes if node.op != "output" ) - order = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + assert_functional_graph(joint_module.graph) + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True) + + if not config.unsafe_allow_optimization_of_collectives: + force_save_collectives(joint_module) + + force_save_bw_mutation_src(joint_module) + + if static_lifetime_input_indices is None: + static_lifetime_input_indices = [] + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs + ) + saved_values = [] saved_sym_nodes = [] - def is_mutated_later_in_fw(node): - if _has_tag_is_backward(node): - return False - tensor_arg_aliases = [ - x - for x in node.args - if isinstance(x, fx.Node) - and "val" in x.meta - and isinstance(x.meta["val"], torch.Tensor) - ] - while len(tensor_arg_aliases) > 0: - a = tensor_arg_aliases.pop() - for u in a.users: - if not isinstance(u.target, torch._ops.OpOverload): - continue - # If we witness a mutation on our node later, and that mutation is not "must be in backward", - # then our node needs to be computed in the forward (otherwise we will compute it on the mutated values) - if ( - # one of the args was mutated - u.target._schema.is_mutable - # and the mutation happens "later" - and order[u] > order[node] - # and the mutation happened during the forward - and not (_has_tag_is_backward(u) or _has_tag_must_be_in_backward(u)) - ): - for idx, alias_info in enumerate(u.target._schema.arguments): - if alias_info.is_write and u.args[idx] is a: - return True - elif u.target.is_view: - tensor_arg_aliases.append(u) - return False - for node in joint_module.graph.nodes: if node.name not in forward_node_names: - # if a node isn't "required" to be in the forward, but any of its arguments - # are later mutated in the forward, then it must have been run in the forward - # (if not, and the node's arg was saved for backward, we would have mutated a saved value) - # NB: doesn't handle nodes where the input is a list of tensors and one of those tensors is later mutated - if is_mutated_later_in_fw(node): - saved_values.append(node) continue if is_sym_node(node): # Symints must be kept separate from tensors so that PythonFunction only calls # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes.append(node) - elif ( + continue + if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE: + saved_values.append(node) + continue + if node.is_impure(impure_random=False) and node.op not in ( + "placeholder", + "output", + ): + # See is_impure in torch/fx/node.py + assert not graph_has_recomputable_ops, ( + "Trying to apply AC on a graph with impure op", + node, + node.target, + ) + saved_values.append(node) + continue + backward_usages = [n for n in node.users if n.name not in forward_node_names] + if "tensor_meta" in node.meta and all(is_sym_node(n) 
for n in backward_usages): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). + saved_sym_nodes.extend(backward_usages) + continue + if ( "tensor_meta" not in node.meta and node.op == "call_function" and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) ): - # Since we can't save tuple of tensor values, we need to flatten out what we're saving - users = node.users - assert all(user.target is operator.getitem for user in users) - saved_values.extend(users) - else: - backward_usages = [ - n for n in node.users if n.name not in forward_node_names - ] - if "tensor_meta" in node.meta and all( - is_sym_node(n) for n in backward_usages - ): - # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, - # and not the actual tensor data, - # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. - # - # Note that saving the tensor could also cause compilation problems: - # If the user mutated an input in the forward and uses its sizes/strides in the backward, - # then we would be obligated to clone the input before saving it to appease autograd. - # (This is how we originally found this bug). - saved_sym_nodes.extend(backward_usages) - else: - saved_values.append(node) + assert all(user.target == operator.getitem for user in node.users) + continue + if not must_recompute(node): + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) - return _extract_fwd_bwd_modules( + if config._sync_decision_cross_ranks: + saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values) + + if static_lifetime_input_nodes is None: + static_lifetime_input_nodes = node_info.static_lifetime_input_nodes + fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, @@ -1127,6 +1122,24 @@ def is_mutated_later_in_fw(node): static_lifetime_input_nodes=static_lifetime_input_nodes, ) + if graph_has_recomputable_ops: + if graph_has_recomputable_rng_ops: + fw_module, bw_module = functionalize_rng_ops( + joint_module, fw_module, bw_module, len(saved_sym_nodes) + ) + bw_module = reordering_to_mimic_autograd_engine(bw_module) + + # raise all getitem ops to as early as possible + # this is helpful for memory, especially in the case of aot_eager backend + fw_module = raise_getitems(fw_module) + bw_module = raise_getitems(bw_module) + + fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False) + if len(node_info.required_bw_nodes) > 0: + bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True) + + return fw_module, bw_module + INT_INF = int(1e6) @@ -1621,7 +1634,9 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: break -def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: +def cleanup_recompute_tags( + joint_module: fx.GraphModule, *, is_default_partition: bool +) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator 
in between, we would still want to stash the tensor at the boundary of @@ -1658,6 +1673,16 @@ def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: # Solution: check whether `out` has a backward hook, and if so, intentionally save `out` # in forward graph outputs. With this, we can break the above circular dependency. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + elif ( + "ac_graph_id" not in node.meta + and any(must_recompute(user) for user in node.users) + and is_default_partition + ): + # This node is not part of the AC region and a user is marked as recompute. + # This means it's an input to the AC region and we should save it. + # For ease of landing, gate this to default partitioner only, but we should think + # about flipping the switch in general as well. + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE return joint_module @@ -2765,6 +2790,59 @@ def thread_graphsafe_rng_from_hops(module, is_backward): return module +def classify_nodes(joint_module, static_lifetime_input_indices, num_fwd_outputs): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + elif _must_be_in_backward(node): + required_bw_nodes.add(node) + + if node in required_bw_nodes: + required_bw_nodes.update(node.users) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( + _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" + ) + required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + ) + unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + ) + static_lifetime_input_nodes = OrderedSet( + p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices + ) + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, + required_fw_nodes, + required_bw_nodes, + unclaimed_nodes, + fw_order, + static_lifetime_input_nodes, + ) + + def min_cut_rematerialization_partition( joint_module: fx.GraphModule, _joint_inputs, @@ -2813,68 +2891,16 @@ def min_cut_rematerialization_partition( graph_has_recomputable_ops = has_recomputable_ops(joint_module) graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) + joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False) if not config.unsafe_allow_optimization_of_collectives: force_save_collectives(joint_module) force_save_bw_mutation_src(joint_module) - def classify_nodes(joint_module, static_lifetime_input_indices): - name_to_node = get_name_to_node(joint_module.graph) - required_bw_nodes: OrderedSet[fx.Node] = OrderedSet() - for node in joint_module.graph.nodes: 
- if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - elif _must_be_in_backward(node): - required_bw_nodes.add(node) - - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) - ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = ( - _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward" - ) - required_fw_nodes: OrderedSet[fx.Node] = OrderedSet( - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - ) - unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet( - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - ) - static_lifetime_input_nodes = OrderedSet( - p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices - ) - fw_cnt = 0 - fw_order = {} - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - fw_order[node] = fw_cnt - fw_cnt += 1 - return NodeInfo( - inputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - fw_order, - static_lifetime_input_nodes, - ) - if static_lifetime_input_indices is None: static_lifetime_input_indices = [] - node_info = classify_nodes(joint_module, static_lifetime_input_indices) + node_info = classify_nodes( + joint_module, static_lifetime_input_indices, num_fwd_outputs + ) # networkx blows up on graphs with no required backward nodes # Since there's nothing to partition anyway, and the default partitioner can "handle" From 8ef4099313f8594cebda9bbeba0f10874f5b0e1f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 10 Nov 2025 23:25:22 +0000 Subject: [PATCH 311/651] Revert "Add min/max support for barebones uint types (#166813)" This reverts commit 9ffc480c5a928eaccb4ac0e1755a1c596674d884. 
Reverted https://github.com/pytorch/pytorch/pull/166813 on behalf of https://github.com/jeanschmidt due to It was reverted internally 6 days ago, but not reverted on OSS, this is causing conflicts ([comment](https://github.com/pytorch/pytorch/pull/166813#issuecomment-3514328895)) --- aten/src/ATen/cuda/NumericLimits.cuh | 31 ------------------- .../ATen/native/cpu/ReduceAllOpsKernel.cpp | 13 ++++---- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 28 ++++++++--------- .../ATen/native/cpu/TensorCompareKernel.cpp | 13 ++++---- .../ATen/native/cuda/ReduceAMinMaxKernel.cu | 13 ++++---- .../ATen/native/cuda/ReduceMaxValuesKernel.cu | 17 +++++----- .../ATen/native/cuda/ReduceMinValuesKernel.cu | 13 ++++---- .../_internal/common_methods_invocations.py | 14 ++++----- 8 files changed, 52 insertions(+), 90 deletions(-) diff --git a/aten/src/ATen/cuda/NumericLimits.cuh b/aten/src/ATen/cuda/NumericLimits.cuh index ebbc004382380..7081e94837caa 100644 --- a/aten/src/ATen/cuda/NumericLimits.cuh +++ b/aten/src/ATen/cuda/NumericLimits.cuh @@ -55,14 +55,6 @@ struct numeric_limits { static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } }; -template <> -struct numeric_limits { - static inline __host__ __device__ uint16_t lowest() { return 0; } - static inline __host__ __device__ uint16_t max() { return UINT16_MAX; } - static inline __host__ __device__ uint16_t lower_bound() { return 0; } - static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; } -}; - template <> struct numeric_limits { static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } @@ -71,14 +63,6 @@ struct numeric_limits { static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } }; -template <> -struct numeric_limits { - static inline __host__ __device__ uint32_t lowest() { return 0; } - static inline __host__ __device__ uint32_t max() { return UINT32_MAX; } - static inline __host__ __device__ uint32_t lower_bound() { return 0; } - static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; } -}; - template <> struct numeric_limits { static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } @@ -87,21 +71,6 @@ struct numeric_limits { static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } }; -template <> -struct numeric_limits { -#ifdef _MSC_VER - static inline __host__ __device__ uint64_t lowest() { return 0; } - static inline __host__ __device__ uint64_t max() { return _UI64_MAX; } - static inline __host__ __device__ uint64_t lower_bound() { return 0; } - static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; } -#else - static inline __host__ __device__ uint64_t lowest() { return 0; } - static inline __host__ __device__ uint64_t max() { return UINT64_MAX; } - static inline __host__ __device__ uint64_t lower_bound() { return 0; } - static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; } -#endif -}; - template <> struct numeric_limits { #ifdef _MSC_VER diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp index c5dbf05039eb1..c7eaa802af125 100644 --- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp @@ -5,7 +5,6 @@ #include #include -#include #include #include #include @@ -79,12 +78,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, upper_bound(), [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); }); 
} else { - AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, upper_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return minimum(a, b); }); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); + }); } } @@ -104,12 +103,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, lower_bound(), [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); }); } else { - AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, lower_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); }); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); + }); } } @@ -200,7 +199,7 @@ void aminmax_allreduce_kernel( } ); } else { - AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] { using Vec = Vectorized>; using scalar_t_pair = std::pair; reduce_all_impl_vec_two_outputs( @@ -215,7 +214,7 @@ void aminmax_allreduce_kernel( [=](Vec a, Vec b) -> Vec { return minimum(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); } ); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf); + }); } } diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 3bad49a32d98c..2e62936501948 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -3,7 +3,6 @@ #include #include -#include #include #include #include @@ -348,35 +347,34 @@ struct MinValuesOps: public at::native::MinOps { }; void min_values_kernel_impl(TensorIterator& iter) { - // This case is special because of Vectorized does not - // handle upper_bound(). - // See: https://github.com/pytorch/pytorch/issues/43254 - if (iter.dtype() == kLong || iter.dtype() == kUInt64) { - AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { - binary_kernel_reduce( - iter, - MinValuesOps{}, - std::pair(upper_bound(), -1)); - }), kLong, kUInt64); + if (iter.dtype() == kLong) { + // This case is special because of Vectorized does not + // handle upper_bound(). 
+ // See: https://github.com/pytorch/pytorch/issues/43254 + using scalar_t = int64_t; + binary_kernel_reduce( + iter, + MinValuesOps{}, + std::pair(upper_bound(), -1)); return; } - AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, static_cast(upper_bound())); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void max_values_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); }, [](Vectorized a, Vectorized b) { return maximum(a, b); }, lower_bound()); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void argmax_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index 22c85735ad6ab..c479e1610cbeb 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include #include @@ -107,7 +106,7 @@ void min_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -129,7 +128,7 @@ void min_kernel_impl( *indice_data = index; } ); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); + }); } void max_kernel_impl( @@ -140,7 +139,7 @@ void max_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -162,7 +161,7 @@ void max_kernel_impl( *indice_data = index; } ); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); + }); } void aminmax_kernel( @@ -187,7 +186,7 @@ void aminmax_kernel( return; } - AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] { compare_base_kernel(min_result, max_result, self, wrap_dim, keepdim, [&] ( scalar_t* min_result_data, scalar_t* max_result_data, const scalar_t* self_data, auto self_dim_stride) { @@ -210,7 +209,7 @@ void aminmax_kernel( *max_result_data = max_number; } ); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half); + }); } void 
where_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu index 0b7823863047a..cdd5daab2d983 100644 --- a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu @@ -1,6 +1,5 @@ #define TORCH_ASSERT_NO_OPERATORS #include -#include #include #include #include @@ -29,22 +28,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { } void aminmax_allreduce_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_V2( - iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { _min_max_values_kernel_cuda_impl(iter); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void aminmax_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_V2( - iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { gpu_reduce_kernel( iter, MinMaxOps{}, thrust::pair( at::numeric_limits::upper_bound(), at::numeric_limits::lower_bound())); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu index bcbc4c0359943..e8d1e88ebb3ec 100644 --- a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu @@ -1,6 +1,5 @@ #define TORCH_ASSERT_NO_OPERATORS #include -#include #include #include #include @@ -34,27 +33,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) { } void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_V2( - iter.dtype(), "max_values_cuda", AT_WRAP([&]() { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { max_values_kernel_cuda_impl(iter); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void max_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_V2( - iter.input_dtype(), "max_cuda", AT_WRAP([&]() { + AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { gpu_reduce_kernel( iter, MaxOps{}, thrust::pair( at::numeric_limits::lower_bound(), 0)); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void max_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { max_values_kernel_cuda_impl(iter); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda) diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu index 0006a24dbc466..e01ca6c88ebc8 100644 --- a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu @@ -12,7 +12,6 @@ #include #include -#include #include #include @@ -34,24 +33,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) { } void min_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, 
kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { min_values_kernel_cuda_impl(iter); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void min_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { gpu_reduce_kernel( iter, MinOps{}, thrust::pair(at::numeric_limits::upper_bound(), 0)); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } void min_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] { + AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { min_values_kernel_cuda_impl(iter); - }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); + }); } REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index ecd2235b1445f..825a54a2ae4c4 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14311,7 +14311,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('max', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14320,7 +14320,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True), OpInfo('max', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), supports_out=True, supports_forward_ad=True, @@ -14465,7 +14465,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False,), OpInfo('min', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14474,7 +14474,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('min', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14784,7 +14784,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_fwgrad_bwgrad=True), OpInfo('aminmax', ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), - dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, 
torch.uint32, torch.uint64), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), decorators=(onlyNativeDeviceTypes,), supports_autograd=False, @@ -21127,7 +21127,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), ref=reference_reduction_numpy(np.amax), skips=( # FIXME: reduces all dimensions when dim=[] @@ -21142,7 +21142,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), ref=reference_reduction_numpy(np.amin), skips=( # FIXME: reduces all dimensions when dim=[] From 4c3721fe70931027d3ded6fc6d9279a7f4127e7d Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 7 Nov 2025 10:54:12 -0800 Subject: [PATCH 312/651] allow sym_stride, and sym_size lowering in inductor to return ints (#167345) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167345 Approved by: https://github.com/eellison --- torch/_inductor/lowering.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 7946f9ae67ad8..4016390c1b9e3 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -7099,29 +7099,12 @@ def sym_constrain_range(a, min=None, max=None): @register_lowering(aten.sym_size.int) def sym_size(a, dim): val = V.graph.current_node.meta["val"] - # Note [Can val be an int?] - # ~~~~~~~~~~~~~~~~~~~~~~~~~ - # In principle, someone could construct an FX graph where - # a call to size/stride has a val that is a plain int (not - # SymInt). However, we will maintain the invariant that - # this is not possible: if you are constructing an FX graph - # where there is a call to size/stride that returns an - # int, but you KNOW that int must always be a constant, - # then you do not need trace that call at all (and just - # constant propagate the integer as is.) - assert isinstance(val, torch.SymInt), ( - f"Expect val to be torch.SymInt but got val={val}" - ) return val.node.expr @register_lowering(aten.sym_stride.int) def sym_stride(a, dim): val = V.graph.current_node.meta["val"] - # See Note [Can val be an int?] 
- assert isinstance(val, torch.SymInt), ( - f"Expect val to be torch.SymInt but got val={val}" - ) return val.node.expr From 9d9e7c7b1c69d951a373ece5c33df9a1fe18d769 Mon Sep 17 00:00:00 2001 From: Nicolas De Carli Date: Mon, 10 Nov 2025 23:36:57 +0000 Subject: [PATCH 313/651] [Pytorch] Extend OSS conversion benchmarks (#167099) Summary: We are extending OSS conversion benchmarks, to include all combinations between types Test Plan: CI Differential Revision: D86315975 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167099 Approved by: https://github.com/mcfi --- ...i_operator_benchmark_eager_float32_cpu.csv | 118 +++++++++++++++--- .../operator_benchmark/pt/tensor_to_test.py | 110 ++++++++-------- ...i_operator_benchmark_eager_float32_cpu.csv | 118 +++++++++++++++--- 3 files changed, 260 insertions(+), 86 deletions(-) diff --git a/benchmarks/operator_benchmark/aarch64_expected_ci_operator_benchmark_eager_float32_cpu.csv b/benchmarks/operator_benchmark/aarch64_expected_ci_operator_benchmark_eager_float32_cpu.csv index dc8b240ce570f..f3d8c7e65af04 100644 --- a/benchmarks/operator_benchmark/aarch64_expected_ci_operator_benchmark_eager_float32_cpu.csv +++ b/benchmarks/operator_benchmark/aarch64_expected_ci_operator_benchmark_eager_float32_cpu.csv @@ -484,24 +484,106 @@ PyTorch,sum,sum_R256_V512_dim0_contiguousTrue_cpu,short,False,50.954394,0.000000 PyTorch,sum,sum_R256_V512_dim0_contiguousFalse_cpu,short,False,57.957757,0.000000 PyTorch,sum,sum_R256_V512_dim1_contiguousTrue_cpu,short,False,53.592068,0.000000 PyTorch,sum,sum_R256_V512_dim1_contiguousFalse_cpu,short,False,51.339726,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N16_cpu,short,False,7.040985,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N64_cpu,short,False,7.168604,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N128_cpu,short,False,7.434442,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N16_cpu,short,False,7.078318,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N64_cpu,short,False,7.426670,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N128_cpu,short,False,7.679027,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N16_cpu,short,False,7.281365,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N64_cpu,short,False,7.682783,0.000000 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N128_cpu,short,False,8.381938,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N16_cpu,short,False,7.039854,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N64_cpu,short,False,7.399855,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N128_cpu,short,False,7.715193,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N16_cpu,short,False,7.255140,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N64_cpu,short,False,7.753522,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N128_cpu,short,False,8.364281,0.000000 
-PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N16_cpu,short,False,7.476377,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N64_cpu,short,False,8.458564,0.000000 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N128_cpu,short,False,9.391939,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,0.927,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.uint8,short,False,6.261,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int8,short,False,6.351,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int16,short,False,6.177,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int32,short,False,6.333,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int64,short,False,6.588,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float16,short,False,8.117,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bfloat16,short,False,9.358,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float32,short,False,7.844,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float64,short,False,8.097,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bool,short,False,6.159,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,0.926,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int8,short,False,6.192,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int16,short,False,6.276,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,6.461,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int64,short,False,6.524,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float16,short,False,8.136,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bfloat16,short,False,6.854,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float32,short,False,6.446,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float64,short,False,6.829,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bool,short,False,6.088,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.uint8,short,False,6.059,0.000000 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int8,short,False,0.922,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int16,short,False,6.263,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int32,short,False,6.330,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int64,short,False,6.688,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float16,short,False,8.176,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bfloat16,short,False,6.959,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float32,short,False,6.430,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float64,short,False,6.818,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bool,short,False,6.350,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.uint8,short,False,6.221,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int8,short,False,6.193,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int16,short,False,0.922,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int32,short,False,6.263,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int64,short,False,6.525,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float16,short,False,7.960,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bfloat16,short,False,6.801,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float32,short,False,6.594,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float64,short,False,7.089,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bool,short,False,6.498,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,6.358,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int8,short,False,6.390,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int16,short,False,6.415,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,0.925,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int64,short,False,6.657,0.000000 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float16,short,False,7.954,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bfloat16,short,False,6.930,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float32,short,False,6.737,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float64,short,False,6.948,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bool,short,False,6.757,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.uint8,short,False,6.402,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int8,short,False,6.550,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int16,short,False,6.518,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int32,short,False,6.766,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int64,short,False,0.929,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float16,short,False,8.557,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bfloat16,short,False,9.045,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float32,short,False,7.672,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float64,short,False,7.276,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bool,short,False,6.414,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.uint8,short,False,7.736,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int8,short,False,7.889,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int16,short,False,8.170,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int32,short,False,7.783,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int64,short,False,7.743,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float16,short,False,0.927,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bfloat16,short,False,7.018,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float32,short,False,8.428,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float64,short,False,6.767,0.000000 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bool,short,False,6.479,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.uint8,short,False,7.827,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int8,short,False,6.450,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int16,short,False,6.320,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int32,short,False,6.385,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int64,short,False,8.119,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float16,short,False,8.063,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bfloat16,short,False,0.925,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float32,short,False,8.629,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float64,short,False,6.638,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bool,short,False,6.425,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.uint8,short,False,7.803,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int8,short,False,6.502,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int16,short,False,6.429,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int32,short,False,6.549,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int64,short,False,7.749,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float16,short,False,7.301,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bfloat16,short,False,7.682,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,0.930,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float64,short,False,6.738,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bool,short,False,6.798,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.uint8,short,False,6.506,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int8,short,False,6.494,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int16,short,False,6.668,0.000000 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int32,short,False,6.696,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int64,short,False,7.115,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float16,short,False,7.910,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bfloat16,short,False,7.410,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float32,short,False,6.868,0.000000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float64,short,False,0.924,0.000000 PyTorch,addcmul,addcmul_M1_N2_cpu_dtypetorch.float32,short,False,4.461410,0.000000 PyTorch,addcmul,addcmul_M1_N2_cpu_dtypetorch.bfloat16,short,False,4.560082,0.000000 PyTorch,addcmul,addcmul_M32_N64_cpu_dtypetorch.float32,short,False,5.141248,0.000000 diff --git a/benchmarks/operator_benchmark/pt/tensor_to_test.py b/benchmarks/operator_benchmark/pt/tensor_to_test.py index 621e58212cba2..9354c8c52eaa8 100644 --- a/benchmarks/operator_benchmark/pt/tensor_to_test.py +++ b/benchmarks/operator_benchmark/pt/tensor_to_test.py @@ -4,74 +4,84 @@ tensor_conversion_short_configs = op_bench.cross_product_configs( - M=( - 8, - 16, - 32, - ), - N=( - 16, - 64, - 128, - ), + M=[32], + N=[128], device=["cpu", "cuda"], + dtype_one=[ + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.half, + torch.bfloat16, + torch.float, + torch.double, + ], + dtype_two=[ + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.half, + torch.bfloat16, + torch.float, + torch.double, + ], tags=["short"], ) tensor_conversion_long_configs = op_bench.cross_product_configs( - M=( - 64, - 128, - 256, - 512, - ), - N=( - 256, - 512, - 1024, - 2048, - ), + M=[1024], + N=[1024], device=["cpu", "cuda"], + dtype_one=[ + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.half, + torch.bfloat16, + torch.float, + torch.double, + ], + dtype_two=[ + torch.bool, + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.half, + torch.bfloat16, + torch.float, + torch.double, + ], tags=["long"], ) -class FloatToHalfTensorConversionBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, device): +class TensorConversionBenchmark(op_bench.TorchBenchmarkBase): + def init(self, M, N, dtype_one, dtype_two, device): self.inputs = { "input": torch.rand( M, N, device=device, requires_grad=False, dtype=torch.float - ) + ).to(dtype=dtype_one) } + self.dtype_one = dtype_one + self.dtype_two = dtype_two def forward(self, input): - return input.to(torch.half) + return input.to(dtype=self.dtype_two) -class HalfToFloatTensorConversionBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, device): - self.inputs = { - "input": torch.rand( - M, N, device=device, requires_grad=False, dtype=torch.half - ) - } - - def forward(self, input): - return input.to(torch.float) - - -op_bench.generate_pt_test( - tensor_conversion_short_configs, FloatToHalfTensorConversionBenchmark -) -op_bench.generate_pt_test( - tensor_conversion_long_configs, FloatToHalfTensorConversionBenchmark -) -op_bench.generate_pt_test( - tensor_conversion_short_configs, 
HalfToFloatTensorConversionBenchmark -) -op_bench.generate_pt_test( - tensor_conversion_long_configs, HalfToFloatTensorConversionBenchmark -) +op_bench.generate_pt_test(tensor_conversion_short_configs, TensorConversionBenchmark) +op_bench.generate_pt_test(tensor_conversion_long_configs, TensorConversionBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() diff --git a/benchmarks/operator_benchmark/x86_64_expected_ci_operator_benchmark_eager_float32_cpu.csv b/benchmarks/operator_benchmark/x86_64_expected_ci_operator_benchmark_eager_float32_cpu.csv index d7a8e65aa85af..71a5930a01a3f 100644 --- a/benchmarks/operator_benchmark/x86_64_expected_ci_operator_benchmark_eager_float32_cpu.csv +++ b/benchmarks/operator_benchmark/x86_64_expected_ci_operator_benchmark_eager_float32_cpu.csv @@ -349,24 +349,106 @@ PyTorch,sum,sum_R256_V512_dim0_contiguousTrue_cpu,short,FALSE,12.5841 PyTorch,sum,sum_R256_V512_dim0_contiguousFALSE_cpu,short,FALSE,20.8765 PyTorch,sum,sum_R256_V512_dim1_contiguousTrue_cpu,short,FALSE,15.4414 PyTorch,sum,sum_R256_V512_dim1_contiguousFALSE_cpu,short,FALSE,15.3287 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N16_cpu,short,FALSE,5.0499 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N64_cpu,short,FALSE,5.3229 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M8_N128_cpu,short,FALSE,5.4418 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N16_cpu,short,FALSE,5.0868 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N64_cpu,short,FALSE,5.4495 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M16_N128_cpu,short,FALSE,5.5578 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N16_cpu,short,FALSE,5.2631 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N64_cpu,short,FALSE,5.5646 -PyTorch,FloatToHalfTensorConversionBenchmark,FloatToHalfTensorConversionBenchmark_M32_N128_cpu,short,FALSE,5.7898 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N16_cpu,short,FALSE,5.0228 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N64_cpu,short,FALSE,5.3692 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M8_N128_cpu,short,FALSE,5.4006 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N16_cpu,short,FALSE,5.1107 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N64_cpu,short,FALSE,5.4119 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M16_N128_cpu,short,FALSE,5.5583 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N16_cpu,short,FALSE,5.3818 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N64_cpu,short,FALSE,5.5742 -PyTorch,HalfToFloatTensorConversionBenchmark,HalfToFloatTensorConversionBenchmark_M32_N128_cpu,short,FALSE,6.8414 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bool,short,False,0.797 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.uint8,short,False,6.071 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int8,short,False,6.031 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int16,short,False,6.243 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int32,short,False,7.231 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.int64,short,False,7.791 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float16,short,False,12.661 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.bfloat16,short,False,11.225 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float32,short,False,9.772 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bool_dtype_twotorch.float64,short,False,9.872 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bool,short,False,6.033 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.uint8,short,False,0.781 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int8,short,False,6.060 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int16,short,False,6.180 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int32,short,False,7.258 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.int64,short,False,7.758 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float16,short,False,10.504 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.bfloat16,short,False,6.749 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float32,short,False,7.679 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.uint8_dtype_twotorch.float64,short,False,7.797 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bool,short,False,6.019 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.uint8,short,False,6.079 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int8,short,False,0.785 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int16,short,False,6.188 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int32,short,False,7.288 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.int64,short,False,7.770 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float16,short,False,10.466 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.bfloat16,short,False,6.676 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float32,short,False,7.736 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int8_dtype_twotorch.float64,short,False,7.780 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bool,short,False,6.130 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.uint8,short,False,6.221 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int8,short,False,6.101 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int16,short,False,0.791 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int32,short,False,6.254 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.int64,short,False,7.733 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float16,short,False,10.562 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.bfloat16,short,False,6.704 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float32,short,False,7.819 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int16_dtype_twotorch.float64,short,False,8.276 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bool,short,False,6.361 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.uint8,short,False,6.364 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int8,short,False,6.309 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int16,short,False,6.362 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int32,short,False,0.791 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.int64,short,False,7.746 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float16,short,False,9.462 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.bfloat16,short,False,6.678 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float32,short,False,7.827 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int32_dtype_twotorch.float64,short,False,8.200 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bool,short,False,6.925 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.uint8,short,False,6.947 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int8,short,False,6.962 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int16,short,False,6.906 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int32,short,False,7.664 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.int64,short,False,0.782 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float16,short,False,10.528 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.bfloat16,short,False,10.123 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float32,short,False,9.234 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.int64_dtype_twotorch.float64,short,False,8.694 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bool,short,False,12.653 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.uint8,short,False,9.348 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int8,short,False,8.774 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int16,short,False,9.063 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int32,short,False,10.012 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.int64,short,False,13.641 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float16,short,False,0.788 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.bfloat16,short,False,13.757 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float32,short,False,7.170 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float16_dtype_twotorch.float64,short,False,12.511 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bool,short,False,6.516 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.uint8,short,False,8.539 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int8,short,False,6.483 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int16,short,False,6.468 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int32,short,False,7.752 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.int64,short,False,9.868 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float16,short,False,10.556 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.bfloat16,short,False,0.792 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float32,short,False,7.577 
+PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.bfloat16_dtype_twotorch.float64,short,False,8.267 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bool,short,False,6.819 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.uint8,short,False,7.715 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int8,short,False,6.754 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int16,short,False,6.825 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int32,short,False,7.790 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.int64,short,False,9.219 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float16,short,False,5.977 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.bfloat16,short,False,7.069 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float32,short,False,0.794 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float32_dtype_twotorch.float64,short,False,8.301 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bool,short,False,7.401 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.uint8,short,False,7.843 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int8,short,False,7.117 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int16,short,False,7.170 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int32,short,False,8.000 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.int64,short,False,9.284 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float16,short,False,7.179 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.bfloat16,short,False,7.645 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float32,short,False,7.988 +PyTorch,TensorConversionBenchmark,TensorConversionBenchmark_M32_N128_cpu_dtype_onetorch.float64_dtype_twotorch.float64,short,False,0.792 PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.quint8",short,FALSE,9.4657 PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.qint8",short,FALSE,9.4625 PyTorch,relu,"relu_dims(3,4,5)_contigFALSE_inplaceFALSE_dtypetorch.qint32",short,FALSE,9.4165 From e3d6896d08018d159920c363d1222db309bca71b Mon Sep 17 00:00:00 2001 From: anwang Date: Mon, 10 Nov 2025 09:50:55 -0800 Subject: [PATCH 314/651] [MTIAGraph][Pytorch][3/n] Implement mtia_graph python wrapper in pytorch (#166964) - Add python module `mtia_graph.py`, which is a wrapper on top of the c++ logic implemented in previous PRs/diffs - Add python level integration tests 
[Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb) Differential Revision: [D84673488](https://our.internmc.facebook.com/intern/diff/D84673488/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166964 Approved by: https://github.com/malfet --- docs/source/mtia.mtia_graph.md | 21 ++++++++ docs/source/pytorch-api.md | 1 + torch/_C/__init__.pyi.in | 10 ++++ torch/mtia/__init__.py | 3 ++ torch/mtia/mtia_graph.py | 96 ++++++++++++++++++++++++++++++++++ 5 files changed, 131 insertions(+) create mode 100644 docs/source/mtia.mtia_graph.md create mode 100644 torch/mtia/mtia_graph.py diff --git a/docs/source/mtia.mtia_graph.md b/docs/source/mtia.mtia_graph.md new file mode 100644 index 0000000000000..1d1560960792c --- /dev/null +++ b/docs/source/mtia.mtia_graph.md @@ -0,0 +1,21 @@ +# torch.mtia.mtia_graph + +The MTIA backend is implemented out of the tree, only interfaces are defined here. + +```{eval-rst} +.. automodule:: torch.mtia.mtia_graph +``` + +```{eval-rst} +.. currentmodule:: torch.mtia.mtia_graph +``` + +```{eval-rst} +.. autoclass:: MTIAGraph + :members: +``` + +```{eval-rst} +.. autoclass:: graph + :members: +``` diff --git a/docs/source/pytorch-api.md b/docs/source/pytorch-api.md index 5f99e4334bb69..c0f1302b8e8ed 100644 --- a/docs/source/pytorch-api.md +++ b/docs/source/pytorch-api.md @@ -29,6 +29,7 @@ mps xpu mtia mtia.memory +mtia.mtia_graph meta torch.backends torch.export diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 3fdf6302115b6..1af6df5e7664a 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2008,6 +2008,16 @@ def _mtia_attachOutOfMemoryObserver( def _mtia_getDeviceCount() -> _int: ... def _mtia_resetPeakMemoryStats(device: _int) -> None: ... +# Defined in torch/csrc/mtia/Module.cpp +class _MTIAGraph: + def __new__(cls, keep_graph: _bool = ...) -> Self: ... + def capture_begin(self, pool: tuple[_int, _int]) -> None: ... + def capture_end(self) -> None: ... + def instantiate(self) -> None: ... + def replay(self) -> None: ... + def reset(self) -> None: ... + def pool(self) -> tuple[_int, _int]: ... + # Defined in torch/csrc/mps/Module.cpp def _mps_deviceSynchronize() -> None: ... def _mps_get_core_count() -> _int: ... diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index c381d99747c0a..35ef04a67319d 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -396,6 +396,7 @@ def set_rng_state( from .memory import * # noqa: F403 +from .mtia_graph import * # noqa: F403 __all__ = [ @@ -424,4 +425,6 @@ def set_rng_state( "set_rng_state", "get_rng_state", "is_bf16_supported", + "MTIAGraph", + "graph", ] diff --git a/torch/mtia/mtia_graph.py b/torch/mtia/mtia_graph.py new file mode 100644 index 0000000000000..bc5a8ea49dfea --- /dev/null +++ b/torch/mtia/mtia_graph.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import Optional, Union +from typing_extensions import Self + +import torch + + +_POOL_HANDLE = tuple[int, int] + + +class MTIAGraph(torch._C._MTIAGraph): + """ + Wrapper around a MTIA graph. + """ + + def __new__(cls, keep_graph: bool = False) -> Self: + return super().__new__(cls, keep_graph) + + def capture_begin(self, pool: _POOL_HANDLE) -> None: + """ + Begin capturing a MTIA graph. + """ + super().capture_begin(pool) + + def capture_end(self) -> None: + """ + End the capture of a MTIA graph. + """ + super().capture_end() + + def instantiate(self) -> None: + """ + Instantiate the captured MTIA graph. 
+ """ + super().instantiate() + + def replay(self) -> None: + """ + Replay the captured MTIA graph. + """ + super().replay() + + def reset(self) -> None: + """ + Destroy the captured graph and reset the states. + """ + super().reset() + + def pool(self) -> _POOL_HANDLE: + """ + Return an opaque token representing the id of this graph's memory pool + """ + return super().pool() + + +class graph: + default_capture_stream: Optional[torch.mtia.Stream] = None + + def __init__( + self, + mtia_graph: MTIAGraph, + pool: Optional[_POOL_HANDLE] = None, + stream: Optional[torch.mtia.Stream] = None, + ): + if self.__class__.default_capture_stream is None: + self.__class__.default_capture_stream = torch.mtia.current_stream() + + self.pool: Union[tuple[()], tuple[_POOL_HANDLE]] = ( + () if pool is None else (pool,) + ) + self.capture_stream = ( + stream if stream is not None else self.__class__.default_capture_stream + ) + assert self.capture_stream is not None + self.stream_ctx = torch.mtia.stream(self.capture_stream) + self.mtia_graph = mtia_graph + + def __enter__(self) -> None: + torch.mtia.synchronize() + torch.mtia.empty_cache() + + self.stream_ctx.__enter__() + + pool_arg = self.pool[0] if self.pool else (0, 0) + self.mtia_graph.capture_begin(pool_arg) + + def __exit__(self, *args: object) -> None: + self.mtia_graph.capture_end() + self.stream_ctx.__exit__(*args) + + +__all__ = [ + "MTIAGraph", + "graph", +] From 5a85b6eaf839834bdc56477bd5d4b21279a9f503 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Mon, 10 Nov 2025 12:15:41 -0800 Subject: [PATCH 315/651] Migrate TypeTraits, TypeList, Metaprogramming to torch:: headeronly (#167386) Taking over #163634; adding tests/headeronly APIs Pull Request resolved: https://github.com/pytorch/pytorch/pull/167386 Approved by: https://github.com/albanD, https://github.com/mikaylagawarecki --- c10/util/Metaprogramming.cpp | 1 - c10/util/Metaprogramming.h | 225 +------ c10/util/TypeList.h | 516 +---------------- c10/util/TypeTraits.h | 152 +---- test/cpp/aoti_abi_check/CMakeLists.txt | 3 + .../aoti_abi_check/test_metaprogramming.cpp | 14 +- .../cpp/aoti_abi_check/test_typelist.cpp | 4 +- .../cpp/aoti_abi_check/test_typetraits.cpp | 4 +- torch/header_only_apis.txt | 38 ++ torch/headeronly/util/Metaprogramming.h | 237 ++++++++ torch/headeronly/util/TypeList.h | 548 ++++++++++++++++++ torch/headeronly/util/TypeTraits.h | 164 ++++++ 12 files changed, 1008 insertions(+), 898 deletions(-) delete mode 100644 c10/util/Metaprogramming.cpp rename c10/test/util/Metaprogramming_test.cpp => test/cpp/aoti_abi_check/test_metaprogramming.cpp (96%) rename c10/test/util/TypeList_test.cpp => test/cpp/aoti_abi_check/test_typelist.cpp (99%) rename c10/test/util/TypeTraits_test.cpp => test/cpp/aoti_abi_check/test_typetraits.cpp (98%) create mode 100644 torch/headeronly/util/Metaprogramming.h create mode 100644 torch/headeronly/util/TypeList.h create mode 100644 torch/headeronly/util/TypeTraits.h diff --git a/c10/util/Metaprogramming.cpp b/c10/util/Metaprogramming.cpp deleted file mode 100644 index f6ee24a79bcd8..0000000000000 --- a/c10/util/Metaprogramming.cpp +++ /dev/null @@ -1 +0,0 @@ -#include diff --git a/c10/util/Metaprogramming.h b/c10/util/Metaprogramming.h index d504706f3283a..a5912706e1ed1 100644 --- a/c10/util/Metaprogramming.h +++ b/c10/util/Metaprogramming.h @@ -1,224 +1 @@ -#pragma once - -#include -#include - -namespace c10::guts { - -/** - * Access information about result type or arguments from a function type. 
- * Example: - * using A = function_traits::return_type // A == int - * using A = function_traits::parameter_types::tuple_type - * // A == tuple - */ -template -struct function_traits { - static_assert( - !std::is_same_v, - "In function_traits, Func must be a plain function type."); -}; -template -struct function_traits { - using func_type = Result(Args...); - using return_type = Result; - using parameter_types = typelist::typelist; - static constexpr auto number_of_parameters = sizeof...(Args); -}; - -/** - * infer_function_traits: creates a `function_traits` type for a simple - * function (pointer) or functor (lambda/struct). Currently does not support - * class methods. - */ - -template -struct infer_function_traits { - using type = function_traits< - c10::guts::detail::strip_class_t>; -}; - -template -struct infer_function_traits { - using type = function_traits; -}; - -template -struct infer_function_traits { - using type = function_traits; -}; - -template -using infer_function_traits_t = typename infer_function_traits::type; - -/** - * make_function_traits: creates a `function_traits` type given a Return type - * and a typelist of Argument types - * - * Example: - * bool f(int, int); - * - * infer_function_traits_t == make_function_traits_t> - */ -template -struct make_function_traits { - static_assert( - false_t::value, - "In guts::make_function_traits, the ArgList argument must be typelist<...>."); -}; - -template -struct make_function_traits> { - using type = function_traits; -}; - -template -using make_function_traits_t = - typename make_function_traits::type; - -/** - * make_offset_index_sequence - * Like make_index_sequence, but starting from Start instead of 0. - * - * Example: - * make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12> - */ -template -struct make_offset_index_sequence_impl - : make_offset_index_sequence_impl { - static_assert( - static_cast(Start) >= 0, - "make_offset_index_sequence: Start < 0"); - static_assert(static_cast(N) >= 0, "make_offset_index_sequence: N < 0"); -}; - -template -struct make_offset_index_sequence_impl { - typedef std::index_sequence type; -}; - -template -using make_offset_index_sequence = - typename make_offset_index_sequence_impl::type; - -/** - * Use tuple_elements to extract a position-indexed subset of elements - * from the argument tuple into a result tuple. - * - * Example: - * std::tuple t = std::make_tuple(0, "HEY", 2.0); - * std::tuple result = tuple_elements(t, std::index_sequence<0, - * 2>()); - */ -template -constexpr auto tuple_elements(Tuple t, std::index_sequence /*unused*/) { - return std::tuple...>(std::get(t)...); -} - -/** - * Use tuple_take to extract the first or last n elements from the argument - * tuple into a result tuple. 
- * - * Example: - * std::tuple t = std::make_tuple(0, "HEY", 2.0); - * std::tuple first_two = tuple_take(t); - * std::tuple last_two = tuple_take(t); - */ -template -struct TupleTake {}; - -template -struct TupleTake= 0, void>> { - static auto call(Tuple t) { - constexpr size_t size = std::tuple_size(); - static_assert(N <= size, "tuple_take: N > size"); - return tuple_elements(t, std::make_index_sequence{}); - } -}; - -template - struct TupleTake < Tuple, - N, std::enable_if_t> { - static auto call(Tuple t) { - constexpr size_t size = std::tuple_size(); - static_assert(-N <= size, "tuple_take: -N > size"); - return tuple_elements(t, make_offset_index_sequence{}); - } -}; - -template -auto tuple_take(Tuple t) { - return TupleTake::call(t); -} - -/** - * Use tuple_slice to extract a contiguous subtuple from the argument. - * - * Example: - * std::tuple t = std::make_tuple(0, - * "HEY", 2.0, false); std::tuple middle_two = - * tuple_slice(t); - */ -template -constexpr auto tuple_slice(Tuple t) { - constexpr size_t size = std::tuple_size(); - static_assert(Start + N <= size, "tuple_slice: Start + N > size"); - return tuple_elements(t, make_offset_index_sequence{}); -} - -/** - * Use tuple_map to run a mapping function over a tuple to get a new tuple. - * - * Example 1: - * auto result = tuple_map(std::tuple(3, 4, 5), [] - * (int32_t a) -> int16_t {return a+1;}); - * // result == std::tuple(4, 5, 6) - * - * Example 2: - * struct Mapper { - * std::string operator()(int32_t a) const { - * return std::to_string(a); - * } - * int64_t operator()(const std::string& a) const { - * return atoi(a.c_str()); - * } - * }; - * auto result = tuple_map(std::tuple(3, "4"), - * Mapper()); - * // result == std::tuple("3", 4) - * - * Example 3: - * struct A final { - * int32_t func() { - * return 5; - * } - * }; - * struct B final { - * std::string func() { - * return "5"; - * } - * }; - * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return - * a.func(); }); - * // result == std::tuple(5, "5"); - */ -namespace detail { -template -auto tuple_map( - // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) - std::tuple&& tuple, - const Mapper& mapper, - std::index_sequence /*unused*/) { - return std::tuple(std::get( - tuple))))...>(mapper(std::forward(std::get(tuple)))...); -} -} // namespace detail - -template -auto tuple_map(std::tuple&& tuple, const Mapper& mapper) { - return detail::tuple_map( - std::move(tuple), mapper, std::index_sequence_for()); -} - -} // namespace c10::guts +#include diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h index 244e5bb141cd7..9f79099710d71 100644 --- a/c10/util/TypeList.h +++ b/c10/util/TypeList.h @@ -1,515 +1 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace c10::guts { - -template -struct false_t : std::false_type {}; -template